|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from argparse import ArgumentParser |
|
|
import os |
|
|
import json |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
import torch |
|
|
from models.maskdit import Precond_models |
|
|
|
|
|
from sample import generate_with_net |
|
|
from utils import parse_float_none, parse_int_list, init_processes |
|
|
|
|
|
|
|
|
def generate(args): |
|
|
rank = args.global_rank |
|
|
size = args.global_size |
|
|
config = OmegaConf.load(args.config) |
|
|
label_dict = json.load(open(args.label_dict, 'r')) |
|
|
class_label = label_dict[str(args.class_idx)][1] |
|
|
print(f'start sampling class {class_label}...') |
|
|
device = torch.device('cuda') |
|
|
|
|
|
sample_dir = os.path.join(args.results_dir, class_label) |
|
|
os.makedirs(sample_dir, exist_ok=True) |
|
|
args.outdir = sample_dir |
|
|
|
|
|
model = Precond_models[config.model.precond]( |
|
|
img_resolution=config.model.in_size, |
|
|
img_channels=config.model.in_channels, |
|
|
num_classes=config.model.num_classes, |
|
|
model_type=config.model.model_type, |
|
|
use_decoder=config.model.use_decoder, |
|
|
mae_loss_coef=config.model.mae_loss_coef, |
|
|
pad_cls_token=config.model.pad_cls_token, |
|
|
use_encoder_feat=config.model.self_cond, |
|
|
).to(device) |
|
|
|
|
|
model.eval() |
|
|
print(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}") |
|
|
print(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}') |
|
|
|
|
|
model = torch.compile(model) |
|
|
ckpt = torch.load(args.ckpt_path, map_location=device) |
|
|
model.load_state_dict(ckpt['ema']) |
|
|
generate_with_net(args, model, device, rank, size) |
|
|
|
|
|
print(f'sampling class {class_label} done!') |
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
parser = ArgumentParser('Sample from a trained model') |
|
|
|
|
|
parser.add_argument('--config', type=str, required=True, help='path to config file') |
|
|
parser.add_argument('--label_dict', type=str, default='assets/imagenet_label.json', help='path to label dict') |
|
|
parser.add_argument("--results_dir", type=str, default="samples", help='path to save samples') |
|
|
parser.add_argument('--ckpt_path', type=str, default=None, help='path to ckpt') |
|
|
|
|
|
|
|
|
parser.add_argument('--seeds', type=parse_int_list, default='100-131', help='Random seeds (e.g. 1,2,5-10)') |
|
|
parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds') |
|
|
parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]') |
|
|
parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0') |
|
|
|
|
|
parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps') |
|
|
parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength') |
|
|
parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver') |
|
|
parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate ODE solver') |
|
|
parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)') |
|
|
parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)') |
|
|
parser.add_argument('--pretrained_path', type=str, default='assets/autoencoder_kl.pth', help='Autoencoder ckpt') |
|
|
|
|
|
parser.add_argument('--max_batch_size', type=int, default=32, help='Maximum batch size per GPU during sampling') |
|
|
parser.add_argument('--num_expected', type=int, default=32, help='Number of images to use') |
|
|
parser.add_argument("--global_seed", type=int, default=0) |
|
|
parser.add_argument('--fid_batch_size', type=int, default=32, help='Maximum batch size') |
|
|
|
|
|
|
|
|
parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.') |
|
|
parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus') |
|
|
parser.add_argument('--node_rank', type=int, default=0, help='The index of node.') |
|
|
parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node') |
|
|
parser.add_argument('--master_address', type=str, default='localhost', help='address for master') |
|
|
args = parser.parse_args() |
|
|
args.global_rank = 0 |
|
|
args.local_rank = 0 |
|
|
args.global_size = 1 |
|
|
init_processes(generate, args) |