Commit: modify phrase

Files changed:
- app.py (+19 -4)
- stable_diffusion.py (+103 -30)

app.py CHANGED
@@ -1,14 +1,14 @@
-from stable_diffusion import
+from stable_diffusion import Generate2img, Args
 import gradio as gr
 args = Args("", 5, None, 7.5, 512, 512, 443, "cpu", "./mdjrny-v4.pt")
-model =
+model = Generate2img.instance(args)
 def text2img_output(phrase):
     return model(phrase)

 readme = open("me.md","rb+").read().decode("utf-8")

 phrase = gr.components.Textbox(
-    value="
+    value="anthropomorphic cat portrait art")
 text2img_out = gr.components.Image(type="numpy")

 instance = gr.Blocks()
@@ -22,4 +22,19 @@ with instance:
     gr.Markdown(readme)


-instance.queue(concurrency_count=20).launch(share=False)
+instance.queue(concurrency_count=20).launch(share=False)
+#
+#
+# 1) anthropomorphic cat portrait art
+#
+# 
+#
+# 2) anthropomorphic cat portrait art(mdjrny-v4.pt)
+#
+# 
+#
+# 3) Kung Fu Panda(weight: wd-1-3-penultimate-ucg-cont.pt, steps:50)
+#
+# 
+# 
+#
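Note: the hunks above do not show lines 15-21 of app.py (the body of the "with instance:" block that lays out the UI). A minimal sketch, assuming the Textbox and Image defined above are wired to text2img_output through a button; the layout and the "Generate" label are guesses, not the Space's actual code:

    with instance:
        with gr.Row():
            phrase.render()                  # Textbox defined above
            text2img_out.render()            # Image output defined above
        run_btn = gr.Button("Generate")      # hypothetical button
        run_btn.click(fn=text2img_output, inputs=phrase, outputs=text2img_out)
        gr.Markdown(readme)                  # this line is shown in the hunk above

    instance.queue(concurrency_count=20).launch(share=False)
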
stable_diffusion.py CHANGED
@@ -12,13 +12,13 @@ from collections import namedtuple
 import numpy as np
 from tqdm import tqdm

-
-from torch.nn import Conv2d, Linear, Module,SiLU, UpsamplingNearest2d,ModuleList
+# ,
+from torch.nn import Conv2d, Linear, Module, SiLU, UpsamplingNearest2d,ModuleList,ZeroPad2d
 from torch import Tensor
 from torch.nn import functional as F
 from torch.nn.parameter import Parameter

-device = "
+device = "mps"

 def apply_seq(seqs, x):
     for seq in seqs:
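Note: device is hard-coded to "mps" (Apple's Metal backend) at module level here, while the command-line interface further down still takes --device_type separately. A small sketch, not part of this commit, of choosing the device with fallbacks instead:

    import torch

    # prefer Apple's Metal backend, then CUDA, then CPU
    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
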
@@ -31,6 +31,12 @@ def gelu(self):
 def quick_gelu(x):
     return x * torch.sigmoid(x * 1.702)

+# class SiLU(Module):
+#     def __init__(self):
+#         super(SiLU, self).__init__()
+#         self.gelu = quick_gelu
+#     def forward(self,x ):
+#         return self.gelu(x)
 class Normalize(Module):
     def __init__(self, in_channels, num_groups=32, name="normalize"):
         super(Normalize, self).__init__()
@@ -166,13 +172,13 @@ class Encoder(Module):
         self.down = ModuleList([
             ResnetBlock(128, 128, name=name + "_down_block_0_0_ResnetBlock"),
             ResnetBlock(128, 128, name=name + "_down_block_0_1_ResnetBlock"),
-            Conv2d(128, 128, 3, stride=2, padding=(0,
+            Conv2d(128, 128, 3, stride=2, padding=(0, 0)),
             ResnetBlock(128, 256, name=name + "_down_block_1_0_ResnetBlock"),
             ResnetBlock(256, 256, name=name + "_down_block_1_1_ResnetBlock"),
-            Conv2d(256, 256, 3, stride=2, padding=(0,
+            Conv2d(256, 256, 3, stride=2, padding=(0, 0)),
             ResnetBlock(256, 512, name=name + "_down_block_2_0_ResnetBlock"),
             ResnetBlock(512, 512, name=name + "_down_block_2_1_ResnetBlock"),
-            Conv2d(512, 512, 3, stride=2, padding=(0,
+            Conv2d(512, 512, 3, stride=2, padding=(0, 0)),
             ResnetBlock(512, 512, name=name + "_down_block_3_0_ResnetBlock"),
             ResnetBlock(512, 512, name=name + "_down_block_3_1_ResnetBlock"),
         ])
@@ -181,12 +187,17 @@ class Encoder(Module):
         self.norm_out = Normalize(512, name=name+"_norm_out_Normalize")
         self.conv_out = Conv2d(512, 8, 3, padding=1)
         self.name = name
+        self.zero_pad2d_0_1 = ZeroPad2d((0,1,0,1))

     def forward(self, x):
         x = self.conv_in(x)

         for l in self.down:
-            x = l(x)
+            # x = l(x)
+            if isinstance(l, Conv2d):
+                x = l(self.zero_pad2d_0_1(x))
+            else:
+                x = l(x)
         x = self.mid(x)
         return self.conv_out(F.silu(self.norm_out(x)))

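Note: the two hunks above switch the Encoder's stride-2 downsampling convolutions from symmetric padding to an explicit ZeroPad2d((0,1,0,1)), i.e. padding only the right and bottom edges before the conv, which matches how the reference Stable Diffusion autoencoder downsamples. A quick shape check with an illustrative feature map (the sizes are made up for the example):

    import torch
    from torch.nn import Conv2d, ZeroPad2d

    x = torch.zeros(1, 128, 64, 64)                      # illustrative input
    pad = ZeroPad2d((0, 1, 0, 1))                        # pad right and bottom only
    down = Conv2d(128, 128, 3, stride=2, padding=(0, 0))
    print(down(pad(x)).shape)                            # torch.Size([1, 128, 32, 32])
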
@@ -637,7 +648,8 @@ class CLIPTextTransformer(Module):
         self.encoder = CLIPEncoder(name=name+"_CLIPEncoder_0")
         self.final_layer_norm = Normalize(768, num_groups=None, name=name+"_CLIPTextTransformer_normalizer_0")
         # the upper triangle is all -inf values
-
+        triu = np.triu(np.ones((1, 1, 77, 77), dtype=np.float32) * -np.inf, k=1)
+        self.causal_attention_mask = Tensor(triu).to(device)
         self.name = name

     def forward(self, input_ids):
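Note: the new buffer is the usual causal attention mask for CLIP's 77-token text encoder: -inf above the diagonal so that, after softmax, each position attends only to itself and earlier tokens. The same construction on a 4-token toy example:

    import numpy as np
    import torch

    mask = np.triu(np.ones((1, 1, 4, 4), dtype=np.float32) * -np.inf, k=1)
    scores = torch.zeros(1, 1, 4, 4) + torch.from_numpy(mask)   # mask added to attention logits
    print(torch.softmax(scores, dim=-1)[0, 0])
    # row i puts non-zero weight only on positions 0..i
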
@@ -804,7 +816,7 @@ class StableDiffusion(Module):


 class Args(object):
-    def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file):
+    def __init__(self, phrase, steps, model_type, guidance_scale, img_width, img_height, seed, device, model_file, input_image:str="", input_mask:str="", input_image_strength=0.5, unphrase=""):
         self.phrase = phrase
         self.steps = steps
         self.model_type = model_type
@@ -814,22 +826,41 @@ class Args(object):
         self.seed = seed
         self.device = device
         self.model_file = model_file
+        self.input_image = input_image
+        self.input_mask = input_mask
+        self.input_image_strength = input_image_strength
+        self.unphrase = unphrase

+from PIL import Image

-class
+class Generate2img(Module):
     _instance_lock = threading.Lock()
     def __init__(self, args: Args):
-        super(
+        super(Generate2img, self).__init__()
         self.is_load_model=False
         self.args = args
         self.model = StableDiffusion().instance()
+        self.get_input_image_tensor()
+        # self.get_input_mask_tensor()
+
+
+    def get_input_image_tensor(self):
+        if self.args.input_image!="":
+            input_image = Image.open(args.input_image).convert("RGB").resize((self.args.img_width, self.args.img_height), resample=Image.Resampling.LANCZOS)
+            self.input_image_array = torch.from_numpy(np.array(input_image)).to(device)
+            self.input_image_tensor = torch.from_numpy((np.array(input_image, dtype=np.float32)[None, ..., :3]/ 255.0*2.0-1))
+            self.input_image_tensor = self.input_image_tensor.permute(0, 3, 1, 2) # bs, channel, height, width
+        else:
+            self.input_image_tensor = None
+        return self.input_image_tensor
+

     @classmethod
     def instance(cls, *args, **kwargs):
-        with
-            if not hasattr(
-
-        return
+        with Generate2img._instance_lock:
+            if not hasattr(Generate2img, "_instance"):
+                Generate2img._instance = Generate2img(*args, **kwargs)
+        return Generate2img._instance

     def load_model(self):
         if self.args.model_file != "" and self.is_load_model==False:
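Note: instance() is a lock-guarded lazy singleton: the first call constructs Generate2img, and every later call returns that same object and ignores its arguments, so the Args passed first wins for the lifetime of the process (app.py relies on this by calling Generate2img.instance(args) once at import time). A stripped-down sketch of the same pattern, using a toy class so it runs without model weights:

    import threading

    class Singleton:
        _instance_lock = threading.Lock()

        def __init__(self, value):
            self.value = value

        @classmethod
        def instance(cls, *args, **kwargs):
            with cls._instance_lock:
                if not hasattr(cls, "_instance"):
                    cls._instance = cls(*args, **kwargs)
            return cls._instance

    a = Singleton.instance(1)
    b = Singleton.instance(2)   # arguments ignored, cached object returned
    print(a is b, a.value)      # True 1
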
@@ -841,6 +872,7 @@ class Text2img(Module):
     def get_token_encode(self, phrase):
         tokenizer = ClipTokenizer().instance()
         phrase = tokenizer.encode(phrase)
+        # phrase = phrase + [49407] * (77 - len(phrase))
         with torch.no_grad():
             context = self.model.text_decoder(phrase)
         return context.to(self.args.device)
@@ -848,7 +880,7 @@ class Text2img(Module):
         self.set_seeds(True)
         self.load_model()
         context = self.get_token_encode(phrase)
-        unconditional_context = self.get_token_encode(
+        unconditional_context = self.get_token_encode(self.args.unphrase)

         timesteps = list(np.arange(1, 1000, 1000 // self.args.steps))
         print(f"running for {timesteps} timesteps")
@@ -857,9 +889,26 @@ class Text2img(Module):

         latent_width = int(self.args.img_width) // 8
         latent_height = int(self.args.img_height) // 8
+
+
+        input_image_latent = None
+        input_img_noise_t = None
+        if self.input_image_tensor!=None:
+            noise_index = int(len(timesteps) * self.args.input_image_strength)
+            if noise_index >= len(timesteps):
+                noise_index = noise_index - 1
+            input_img_noise_t = timesteps[noise_index]
+            with torch.no_grad():
+                filter = lambda x:x[:,:4,:,:] * 0.18215
+                input_image_latent = self.model.first_stage_model.encoder(self.input_image_tensor.to(device))
+                input_image_latent = self.model.first_stage_model.quant_conv(input_image_latent)
+                input_image_latent = filter(input_image_latent) # only the means
+
         # start with random noise
-        latent =
+        latent = self.get_noise_latent( 1, latent_height, latent_width, input_image_latent, input_img_noise_t, None)
+
         latent = latent.to(self.args.device)
+
         with torch.no_grad():
             # this is diffusion
             for index, timestep in (t := tqdm(list(enumerate(timesteps))[::-1])):
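Note: input_image_strength decides how far into the schedule the img2img path starts. With the defaults (steps=50 gives timesteps 1, 21, 41, ..., 981), a strength of 0.5 yields noise_index = int(50 * 0.5) = 25, so denoising starts from timestep 501 instead of from pure noise at 981. The filter lambda keeps the first 4 of the 8 channels produced by quant_conv (the means of the latent distribution, dropping the log-variances, as the "only the means" comment says) and scales them by 0.18215, the latent scaling factor Stable Diffusion uses.
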
@@ -867,11 +916,14 @@ class Text2img(Module):
                 e_t = self.get_model_latent_output(latent.clone(), timestep, self.model.unet, context.clone(),
                                                    unconditional_context.clone())
                 x_prev, pred_x0 = self.get_x_prev_and_pred_x0(latent, e_t, index, alphas, alphas_prev)
+                latent = x_prev
                 # e_t_next = get_model_output(x_prev)
                 # e_t_prime = (e_t + e_t_next) / 2
                 # x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
-
-
+        decode = self.latent_decode(latent, latent_height, latent_width)
+
+        return decode
+

     def get_x_prev_and_pred_x0(self, x, e_t, index, alphas, alphas_prev):
         temperature = 1
@@ -900,6 +952,27 @@ class Text2img(Module):
         del unconditional_latent, latent, timesteps, context
         return e_t

+
+    def add_noise(self, x , t , noise=None ):
+        # batch_size, channel, h, w = x.shape
+        if noise is None:
+            noise = torch.normal(0,1, size=(x.shape))
+        # sqrt_alpha_prod = _ALPHAS_CUMPROD[t] ** 0.5
+        sqrt_alpha_prod = self.model.sqrt_alphas_cumprod[t]
+        sqrt_one_minus_alpha_prod = self.model.sqrt_one_minus_alphas_cumprod[t]
+        # sqrt_one_minus_alpha_prod = (1 - _ALPHAS_CUMPROD[t]) ** 0.5
+
+        return sqrt_alpha_prod * x + sqrt_one_minus_alpha_prod * noise.to(device)
+
+    def get_noise_latent(self, batch_size, latent_height, latent_width, input_image_latent=None, input_img_noise_t=None, noise=None):
+
+        if input_image_latent is None:
+            latent = torch.normal(0,1, size=(batch_size, 4, latent_height, latent_width))
+            # latent = torch.randn((batch_size, 4, latent_height, latent_width))
+        else:
+            latent = self.add_noise(input_image_latent, input_img_noise_t, noise)
+        return latent.to(device)
+
     def latent_decode(self, latent, latent_height, latent_width):
         # upsample latent space to image with autoencoder
         # x = model.first_stage_model.post_quant_conv( 8* latent)
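Note: add_noise is the standard DDPM forward process, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, read from the model's precomputed sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod tables; get_noise_latent then either draws pure Gaussian noise (text-to-image) or noises the encoded input image up to input_img_noise_t (img2img).
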
@@ -915,8 +988,7 @@ class Text2img(Module):
         return decode_latent
     def decode_latent2img(self, decode_latent):
         # save image
-
-        img = Image.fromarray(decode_latent)
+        img = Image.fromarray(decode_latent, mode="RGB")
         return img

     def set_seeds(self, cuda):
@@ -925,11 +997,11 @@ class Text2img(Module):
         if cuda:
             torch.cuda.manual_seed_all(self.args.seed)
 @lru_cache()
-def
+def generate2img(phrase, steps, model_file, guidance_scale, img_width, img_height, seed, device, input_image, input_mask, input_image_strength=0.5, unphrase=""):
     try:
-        args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
-        im =
-        im =
+        args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file, input_image, input_mask, input_image_strength, unphrase)
+        im = Generate2img.instance(args).forward(args.phrase)
+        im = Generate2img.instance(args).decode_latent2img(im)
     finally:
         pass
     return im
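Note: @lru_cache() memoizes generate2img on its full argument tuple (all parameters here are hashable strings and numbers), so a repeated identical request returns the cached PIL image instead of re-running diffusion. Because Generate2img.instance() ignores arguments after the first call, later calls with different settings change only the cache key; apart from the phrase, which is passed to forward() explicitly, they do not reconfigure the already-constructed model.
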
@@ -943,19 +1015,20 @@ if __name__ == "__main__":

     parser = argparse.ArgumentParser(description='Run Stable Diffusion',
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--steps', type=int, default=
+    parser.add_argument('--steps', type=int, default=50, help="Number of steps in diffusion")
     parser.add_argument('--phrase', type=str, default="anthropomorphic cat portrait art ", help="Phrase to render")
+    parser.add_argument('--unphrase', type=str, default="", help="unconditional Phrase to render")
     parser.add_argument('--out', type=str, default="/tmp/rendered.png", help="Output filename")
     parser.add_argument('--scale', type=float, default=7.5, help="unconditional guidance scale")
-    parser.add_argument('--model_file', type=str, default="/
+    parser.add_argument('--model_file', type=str, default="../min-stable-diffusion-pt/mdjrny-v4.pt", help="model weight file")
     parser.add_argument('--img_width', type=int, default=512, help="output image width")
     parser.add_argument('--img_height', type=int, default=512, help="output image height")
     parser.add_argument('--seed', type=int, default=443, help="random seed")
-    parser.add_argument('--device_type', type=str, default="cpu", help="
+    parser.add_argument('--device_type', type=str, default="cpu", help="device type, support: cpu;cuda;mps")
+    parser.add_argument('--input_image', type=str, default="", help="input image file")
     args = parser.parse_args()
-
     device = args.device_type

-    im =
+    im = generate2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type, args.input_image, "", 1, args.unphrase)
     print(f"saving {args.out}")
     im.save(args.out)
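A typical invocation with the defaults added above (assuming the weight file actually exists at the --model_file path):

    python stable_diffusion.py --phrase "anthropomorphic cat portrait art" --steps 50 --device_type cpu --out /tmp/rendered.png

Passing --input_image with a path to an existing image routes get_noise_latent through the img2img branch added above.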