# MIT License
# Copyright (c) [2023] [Anima-Lab]
# This code is adapted from https://github.com/NVlabs/edm/blob/main/generate.py.
# The original code is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License, which can be found at licenses/LICENSE_EDM.txt.
import argparse
import os
import random
import PIL.Image
import lmdb
import numpy as np
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
from tqdm import tqdm
from models.maskdit import Precond_models, DiT_models
from utils import *
import autoencoder
# ----------------------------------------------------------------------------
# Proposed EDM sampler (Algorithm 2).
def edm_sampler(
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
# Adjust noise levels based on what's supported by the network.
sigma_min = max(sigma_min, net.sigma_min)
sigma_max = min(sigma_max, net.sigma_max)
# Time step discretization.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
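    # Karras et al. rho-schedule: linear interpolation in sigma**(1/rho) (rho=7 by
    # default), which concentrates steps near sigma_min where finer resolution helps.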
t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0
# Main sampling loop.
x_next = latents.to(torch.float64) * t_steps[0]
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
t_hat = net.round_sigma(t_cur + gamma * t_cur)
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
# Euler step.
denoised = net(x_hat.float(), t_hat, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur
# Apply 2nd order correction.
if i < num_steps - 1:
denoised = net(x_next.float(), t_next, class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
d_prime = (x_next - denoised) / t_next
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
return x_next
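# Minimal usage sketch (assumption: `net` follows the EDM Precond interface used in this
# file, i.e. exposes img_channels/img_resolution/sigma_min/sigma_max/round_sigma and
# returns a dict with key 'x'). Kept for reference; nothing in this script calls it.
def _edm_sampler_demo(net, device='cuda', batch_size=4):
    latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution],
                          device=device)
    # Deterministic 18-step Heun sampling (S_churn=0 disables the stochastic churn).
    return edm_sampler(net, latents, class_labels=None, num_steps=18)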
# ----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.
def ablation_sampler(
net, latents, class_labels=None, cfg_scale=None, feat=None, randn_like=torch.randn_like,
num_steps=18, sigma_min=None, sigma_max=None, rho=7,
solver='heun', discretization='edm', schedule='linear', scaling='none',
epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
):
assert solver in ['euler', 'heun']
assert discretization in ['vp', 've', 'iddpm', 'edm']
assert schedule in ['vp', 've', 'linear']
assert scaling in ['vp', 'none']
# Helper functions for VP & VE noise level schedules.
vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (sigma(t) + 1 / sigma(t))
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (
sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
ve_sigma = lambda t: t.sqrt()
ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
ve_sigma_inv = lambda sigma: sigma ** 2
# Select default noise level range based on the specified time step discretization.
if sigma_min is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
if sigma_max is None:
vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]
# Adjust noise levels based on what's supported by the network.
sigma_min = max(sigma_min, net.sigma_min)
sigma_max = min(sigma_max, net.sigma_max)
# Compute corresponding betas for VP.
vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d
# Define time steps in terms of noise level.
step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
if discretization == 'vp':
orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
elif discretization == 've':
orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
sigma_steps = ve_sigma(orig_t_steps)
elif discretization == 'iddpm':
u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
else:
assert discretization == 'edm'
sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
# Define noise level schedule.
if schedule == 'vp':
sigma = vp_sigma(vp_beta_d, vp_beta_min)
sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
elif schedule == 've':
sigma = ve_sigma
sigma_deriv = ve_sigma_deriv
sigma_inv = ve_sigma_inv
else:
assert schedule == 'linear'
sigma = lambda t: t
sigma_deriv = lambda t: 1
sigma_inv = lambda sigma: sigma
# Define scaling schedule.
if scaling == 'vp':
s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
else:
assert scaling == 'none'
s = lambda t: 1
s_deriv = lambda t: 0
# Compute final time steps based on the corresponding noise levels.
t_steps = sigma_inv(net.round_sigma(sigma_steps))
t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0
# Main sampling loop.
t_next = t_steps[0]
x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next
# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(
t_hat) * S_noise * randn_like(x_cur)
# Euler step.
h = t_next - t_hat
denoised = net(x_hat.float() / s(t_hat), sigma(t_hat), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(
t_hat) / sigma(t_hat) * denoised
x_prime = x_hat + alpha * h * d_cur
t_prime = t_hat + alpha * h
# Apply 2nd order correction.
if solver == 'euler' or i == num_steps - 1:
x_next = x_hat + h * d_cur
else:
assert solver == 'heun'
denoised = net(x_prime.float() / s(t_prime), sigma(t_prime), class_labels, cfg_scale, feat=feat)['x'].to(torch.float64)
d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(
t_prime) * s(t_prime) / sigma(t_prime) * denoised
x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)
return x_next
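# Sketch of one ablation configuration (illustrative, not a repo-endorsed preset): fixing
# (solver, discretization, schedule, scaling) recovers specific prior samplers from the
# EDM paper's ablation table, e.g. a VP-style ODE solved with plain Euler steps.
def _vp_ode_demo(net, latents, class_labels=None):
    return ablation_sampler(net, latents, class_labels=class_labels,
                            solver='euler', discretization='vp',
                            schedule='vp', scaling='vp')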
# ----------------------------------------------------------------------------
def retrieve_n_features(batch_size, feat_path, feat_dim, num_classes, device, split='train', sample_mode='rand_full'):
env = lmdb.open(os.path.join(feat_path, split), readonly=True, lock=False, create=False)
# Start a new read transaction
with env.begin() as txn:
        # Read all records in one transaction (a single lock);
        # this could be split into multiple transactions if needed.
length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
if sample_mode == 'rand_full':
image_ids = random.sample(range(length // 2), batch_size)
image_ids_y = image_ids
elif sample_mode == 'rand_repeat':
image_ids = random.sample(range(length // 2), 1) * batch_size
image_ids_y = image_ids
elif sample_mode == 'rand_y':
image_ids = random.sample(range(length // 2), 1) * batch_size
image_ids_y = random.sample(range(length // 2), batch_size)
else:
raise NotImplementedError
features, labels = [], []
for image_id, image_id_y in zip(image_ids, image_ids_y):
feat_bytes = txn.get(f'feat-{str(image_id)}'.encode('utf-8'))
y_bytes = txn.get(f'y-{str(image_id_y)}'.encode('utf-8'))
feat = np.frombuffer(feat_bytes, dtype=np.float32).reshape([feat_dim]).copy()
y = int(y_bytes.decode('utf-8'))
features.append(feat)
labels.append(y)
features = torch.from_numpy(np.stack(features)).to(device)
labels = torch.from_numpy(np.array(labels)).to(device)
class_labels = torch.zeros([batch_size, num_classes], device=device)
if num_classes > 0:
class_labels = torch.eye(num_classes, device=device)[labels]
assert features.shape[0] == class_labels.shape[0] == batch_size
return features, class_labels
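# Sketch of the LMDB layout retrieve_n_features assumes: a 'length' entry equal to twice
# the number of (feature, label) records, plus per-record 'feat-{i}' (raw float32 bytes)
# and 'y-{i}' (label as text) entries. Illustrative only; the real feature-extraction
# pipeline lives elsewhere in the repo.
def _write_feature_lmdb_demo(path, feats, labels):
    env = lmdb.open(path, map_size=1 << 30)
    with env.begin(write=True) as txn:
        for i, (f, y) in enumerate(zip(feats, labels)):
            txn.put(f'feat-{i}'.encode('utf-8'), np.asarray(f, dtype=np.float32).tobytes())
            txn.put(f'y-{i}'.encode('utf-8'), str(int(y)).encode('utf-8'))
        # 'length' counts both key families, hence 2x (matching the `length // 2` reads above).
        txn.put('length'.encode('utf-8'), str(2 * len(feats)).encode('utf-8'))
    env.close()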
@torch.no_grad()
def generate_with_net(args, net, device, rank, size):
seeds = args.seeds
num_batches = ((len(seeds) - 1) // (args.max_batch_size * size) + 1) * size
all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[rank::size]
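    # Each rank takes every `size`-th batch (interleaved sharding), so every seed is
    # generated exactly once and each image depends only on its seed, not on world size.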
net.eval()
# Setup sampler
sampler_kwargs = dict(num_steps=args.num_steps, S_churn=args.S_churn,
solver=args.solver, discretization=args.discretization,
schedule=args.schedule, scaling=args.scaling)
sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler
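    # Setting any of --solver/--discretization/--schedule/--scaling routes to
    # ablation_sampler; leaving all four unset uses the hard-coded EDM sampler above.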
mprint(f"sampler_kwargs: {sampler_kwargs}, \nsampler fn: {sampler_fn.__name__}")
# Setup autoencoder
vae = autoencoder.get_model(args.pretrained_path).to(device)
# generate images
mprint(f'Generating {len(seeds)} images to "{args.outdir}"...')
for batch_seeds in tqdm(rank_batches, unit='batch', disable=(rank != 0)):
dist.barrier()
batch_size = len(batch_seeds)
if batch_size == 0:
continue
# Pick latents and labels.
rnd = StackedRandomGenerator(device, batch_seeds)
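        # One independent RNG stream per seed, so each image is reproducible from its
        # seed alone regardless of batch composition.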
latents = rnd.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device)
class_labels = torch.zeros([batch_size, net.num_classes], device=device)
if net.num_classes:
class_labels = torch.eye(net.num_classes, device=device)[
rnd.randint(net.num_classes, size=[batch_size], device=device)]
if args.class_idx is not None:
class_labels[:, :] = 0
class_labels[:, args.class_idx] = 1
        # Retrieve external features from the training set (random modes only); disabled here.
feat = None
# Generate images.
        def recur_decode(z):
            try:
                return vae.decode(z)
            except RuntimeError:  # occasional CUDA OOM: halve the batch and decode in two passes
                assert z.shape[0] % 2 == 0  # split along the batch dimension
                z1, z2 = z.tensor_split(2)
                return torch.cat([recur_decode(z1), recur_decode(z2)])
with torch.no_grad():
z = sampler_fn(net, latents.float(), class_labels.float(), randn_like=rnd.randn_like,
cfg_scale=args.cfg_scale, feat=feat, **sampler_kwargs).float()
images = recur_decode(z)
# Save images.
images_np = images.add_(1).mul(127.5).clamp_(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
        # Original EDM pixel-space scaling, kept for reference:
        # images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
for seed, image_np in zip(batch_seeds, images_np):
image_dir = os.path.join(args.outdir, f'{seed - seed % 1000:06d}') if args.subdirs else args.outdir
os.makedirs(image_dir, exist_ok=True)
image_path = os.path.join(image_dir, f'{seed:06d}.png')
if image_np.shape[2] == 1:
PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
else:
PIL.Image.fromarray(image_np, 'RGB').save(image_path)
def generate(args):
device = torch.device("cuda")
    mprint(f'cfg_scale: {args.cfg_scale}')
if args.global_rank == 0:
os.makedirs(args.outdir, exist_ok=True)
logger = Logger(file_name=f'{args.outdir}/log.txt', file_mode="a+", should_flush=True)
# Create model:
net = Precond_models[args.precond](
img_resolution=args.image_size,
img_channels=args.image_channels,
num_classes=args.num_classes,
model_type=args.model_type,
use_decoder=args.use_decoder,
mae_loss_coef=args.mae_loss_coef,
pad_cls_token=args.pad_cls_token,
ext_feature_dim=args.ext_feature_dim
).to(device)
mprint(
f"{args.model_type} (use_decoder: {args.use_decoder}) Model Parameters: {sum(p.numel() for p in net.parameters()):,}")
# Load checkpoints
ckpt = torch.load(args.ckpt_path, map_location=device)
net.load_state_dict(ckpt['ema'])
mprint(f'Load weights from {args.ckpt_path}')
    generate_with_net(args, net, device, args.global_rank, args.global_size)
# Done.
cleanup()
if args.global_rank == 0:
logger.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser('sampling parameters')
# ddp
parser.add_argument('--num_proc_node', type=int, default=1, help='The number of nodes in multi node env.')
parser.add_argument('--num_process_per_node', type=int, default=1, help='number of gpus')
parser.add_argument('--node_rank', type=int, default=0, help='The index of node.')
parser.add_argument('--local_rank', type=int, default=0, help='rank of process in the node')
parser.add_argument('--master_address', type=str, default='localhost', help='address for master')
# sampling
parser.add_argument("--feat_path", type=str, default='')
parser.add_argument("--ext_feature_dim", type=int, default=0)
    parser.add_argument('--ckpt_path', type=str, required=True, help='Path to the model checkpoint (.pt)')
    parser.add_argument('--outdir', type=str, required=True, help='Directory to save sampled images')
parser.add_argument('--seeds', type=parse_int_list, default='0-63', help='Random seeds (e.g. 1,2,5-10)')
parser.add_argument('--subdirs', action='store_true', help='Create subdirectory for every 1000 seeds')
parser.add_argument('--class_idx', type=int, default=None, help='Class label [default: random]')
parser.add_argument('--max_batch_size', type=int, default=64, help='Maximum batch size per GPU')
parser.add_argument("--cfg_scale", type=parse_float_none, default=None, help='None = no guidance, by default = 4.0')
parser.add_argument('--num_steps', type=int, default=18, help='Number of sampling steps')
    parser.add_argument('--S_churn', type=float, default=0, help='Stochasticity strength')
parser.add_argument('--solver', type=str, default=None, choices=['euler', 'heun'], help='Ablate ODE solver')
    parser.add_argument('--discretization', type=str, default=None, choices=['vp', 've', 'iddpm', 'edm'],
                        help='Ablate time step discretization {t_i}')
parser.add_argument('--schedule', type=str, default=None, choices=['vp', 've', 'linear'],
help='Ablate noise schedule sigma(t)')
parser.add_argument('--scaling', type=str, default=None, choices=['vp', 'none'], help='Ablate signal scaling s(t)')
parser.add_argument('--pretrained_path', type=str, default='assets/stable_diffusion/autoencoder_kl.pth',
help='Autoencoder ckpt')
# model
parser.add_argument("--image_size", type=int, default=32)
parser.add_argument("--image_channels", type=int, default=4)
parser.add_argument("--num_classes", type=int, default=1000, help='0 means unconditional')
parser.add_argument("--model_type", type=str, choices=list(DiT_models.keys()), default="DiT-XL/2")
    parser.add_argument('--precond', type=str, choices=['vp', 've', 'edm'], default='edm', help='Preconditioning used for training & loss')
parser.add_argument("--use_decoder", type=str2bool, default=False)
parser.add_argument("--pad_cls_token", type=str2bool, default=False)
parser.add_argument('--mae_loss_coef', type=float, default=0, help='0 means no MAE loss')
    parser.add_argument('--sample_mode', type=str, default='rand_full', help='[rand_full, rand_repeat, rand_y]')
args = parser.parse_args()
args.global_size = args.num_proc_node * args.num_process_per_node
size = args.num_process_per_node
if size > 1:
processes = []
for rank in range(size):
args.local_rank = rank
args.global_rank = rank + args.node_rank * args.num_process_per_node
p = Process(target=init_processes, args=(generate, args))
p.start()
processes.append(p)
for p in processes:
p.join()
else:
print('Single GPU run')
assert args.global_size == 1 and args.local_rank == 0
args.global_rank = 0
init_processes(generate, args)
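# Example launch (sketch; paths and flag values are placeholders, not repo defaults):
#   python sample.py --ckpt_path results/checkpoint.pt --outdir samples \
#       --seeds 0-63 --num_steps 40 --cfg_scale 1.5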