# MIT License

# Copyright (c) [2023] [Anima-Lab]
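
"""Sample images for a single ImageNet class from a trained MaskDiT checkpoint.

Usage sketch (script name, paths, and values below are illustrative placeholders):

    python sample_class.py \
        --config path/to/config.yaml \
        --ckpt_path path/to/checkpoint.pt \
        --class_idx 207 \
        --cfg_scale 1.5

Generated samples are written to <results_dir>/<class name>/.
"""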


from argparse import ArgumentParser
import os
import json
from omegaconf import OmegaConf

import torch
from models.maskdit import Precond_models

from sample import generate_with_net
from utils import parse_float_none, parse_int_list, init_processes


def generate(args):
    rank = args.global_rank
    size = args.global_size
    config = OmegaConf.load(args.config)
    with open(args.label_dict, 'r') as f:
        label_dict = json.load(f)
    class_label = label_dict[str(args.class_idx)][1]
    print(f'start sampling class {class_label}...')
    device = torch.device('cuda')
    # setup directory
    sample_dir = os.path.join(args.results_dir, class_label)
    os.makedirs(sample_dir, exist_ok=True)
    args.outdir = sample_dir
    # setup model
    model = Precond_models[config.model.precond](
        img_resolution=config.model.in_size,
        img_channels=config.model.in_channels,
        num_classes=config.model.num_classes,
        model_type=config.model.model_type,
        use_decoder=config.model.use_decoder,
        mae_loss_coef=config.model.mae_loss_coef,
        pad_cls_token=config.model.pad_cls_token,
        use_encoder_feat=config.model.self_cond,
    ).to(device)

    model.eval()
    print(f"{config.model.model_type} ((use_decoder: {config.model.use_decoder})) Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f'extras: {model.model.extras}, cls_token: {model.model.cls_token}')

    # Compile the model, then load the EMA weights from the checkpoint.
    model = torch.compile(model)
    ckpt = torch.load(args.ckpt_path, map_location=device)
    model.load_state_dict(ckpt['ema'])
    generate_with_net(args, model, device, rank, size)

    print(f'sampling class {class_label} done!')


if __name__ == '__main__':
    parser = ArgumentParser(description='Sample from a trained model')
    # basic config
    parser.add_argument('--config', type=str, required=True, help='path to config file')
    parser.add_argument('--label_dict', type=str, default='assets/imagenet_label.json', help='path to label dict')
    parser.add_argument("--results_dir", type=str, default="samples", help='path to save samples')
    parser.add_argument('--ckpt_path', type=str, default=None, help='path to ckpt')

    # sampling
    parser.add_argument('--seeds', type=parse_int_list, default='100-131', help='Random seeds (e.g. 1,2,5-10)')
    parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
    parser.add_argument('--class_idx', type=int, default=None, help='ImageNet class index to sample (required by this script)')
    parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')

    parser.add_argument('--num_steps', type=int, default=40, help='Number of sampling steps')
    parser.add_argument('--S_churn', type=int, default=0, help='Stochasticity strength')
    parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'], help='Ablate time step discretization {t_i}')
    parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'], help='Ablate noise schedule sigma(t)')
    parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
    parser.add_argument('--pretrained_path', type=str, default='assets/autoencoder_kl.pth', help='Autoencoder ckpt')

    parser.add_argument('--max_batch_size', type=int, default=32, help='Maximum batch size per GPU during sampling')
    parser.add_argument('--num_expected', type=int, default=32, help='Number of images to generate')
    parser.add_argument("--global_seed", type=int, default=0)
    parser.add_argument('--fid_batch_size', type=int, default=32, help='Batch size for FID evaluation')

    # ddp 
    parser.add_argument('--num_proc_node', type=int, default=1, help='Number of nodes in a multi-node environment')
    parser.add_argument('--num_process_per_node', type=int, default=1, help='Number of GPUs (processes) per node')
    parser.add_argument('--node_rank', type=int, default=0, help='Index of this node')
    parser.add_argument('--local_rank', type=int, default=0, help='Rank of the process within the node')
    parser.add_argument('--master_address', type=str, default='localhost', help='Address of the master node')
    args = parser.parse_args()
    args.global_rank = 0
    args.local_rank = 0
    args.global_size = 1
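    # The defaults above assume single-process sampling; init_processes (from
    # utils) is expected to set up the environment and then invoke generate().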
    init_processes(generate, args)