Spaces:
Build error
Build error
add quick_gelu
Browse files- stable_diffusion.py +7 -3
stable_diffusion.py
CHANGED
|
@@ -28,6 +28,9 @@ def apply_seq(seqs, x):
|
|
| 28 |
def gelu(self):
|
| 29 |
return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
class Normalize(Module):
|
| 32 |
def __init__(self, in_channels, num_groups=32, name="normalize"):
|
| 33 |
super(Normalize, self).__init__()
|
|
@@ -275,7 +278,7 @@ class GEGLU(Module):
|
|
| 275 |
|
| 276 |
def forward(self, x):
|
| 277 |
x, gate = self.proj(x).chunk(2, dim=-1)
|
| 278 |
-
return x *
|
| 279 |
|
| 280 |
class FeedForward(Module):
|
| 281 |
def __init__(self, dim, mult=4, name="FeedForward"):
|
|
@@ -523,7 +526,7 @@ class CLIPMLP(Module):
|
|
| 523 |
|
| 524 |
def forward(self, hidden_states):
|
| 525 |
hidden_states = self.fc1(hidden_states)
|
| 526 |
-
hidden_states =
|
| 527 |
hidden_states = self.fc2(hidden_states)
|
| 528 |
return hidden_states
|
| 529 |
|
|
@@ -926,6 +929,7 @@ def text2img(phrase, steps, model_file, guidance_scale, img_width, img_height, s
|
|
| 926 |
try:
|
| 927 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
|
| 928 |
im = Text2img.instance(args).forward(args.phrase)
|
|
|
|
| 929 |
finally:
|
| 930 |
pass
|
| 931 |
return im
|
|
@@ -954,4 +958,4 @@ if __name__ == "__main__":
|
|
| 954 |
|
| 955 |
im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
|
| 956 |
print(f"saving {args.out}")
|
| 957 |
-
im.save(args.out)
|
|
|
|
| 28 |
def gelu(self):
|
| 29 |
return 0.5 * self * (1 + torch.tanh(self * 0.7978845608 * (1 + 0.044715 * self * self)))
|
| 30 |
|
| 31 |
+
def quick_gelu(x):
|
| 32 |
+
return x * torch.sigmoid(x * 1.702)
|
| 33 |
+
|
| 34 |
class Normalize(Module):
|
| 35 |
def __init__(self, in_channels, num_groups=32, name="normalize"):
|
| 36 |
super(Normalize, self).__init__()
|
|
|
|
| 278 |
|
| 279 |
def forward(self, x):
|
| 280 |
x, gate = self.proj(x).chunk(2, dim=-1)
|
| 281 |
+
return x * quick_gelu(gate)
|
| 282 |
|
| 283 |
class FeedForward(Module):
|
| 284 |
def __init__(self, dim, mult=4, name="FeedForward"):
|
|
|
|
| 526 |
|
| 527 |
def forward(self, hidden_states):
|
| 528 |
hidden_states = self.fc1(hidden_states)
|
| 529 |
+
hidden_states = quick_gelu(hidden_states)
|
| 530 |
hidden_states = self.fc2(hidden_states)
|
| 531 |
return hidden_states
|
| 532 |
|
|
|
|
| 929 |
try:
|
| 930 |
args = Args(phrase, steps, None, guidance_scale, img_width, img_height, seed, device, model_file)
|
| 931 |
im = Text2img.instance(args).forward(args.phrase)
|
| 932 |
+
im = Text2img.instance(args).decode_latent2img(im)
|
| 933 |
finally:
|
| 934 |
pass
|
| 935 |
return im
|
|
|
|
| 958 |
|
| 959 |
im = text2img(args.phrase, args.steps, args.model_file, args.scale, args.img_width, args.img_height, args.seed, args.device_type)
|
| 960 |
print(f"saving {args.out}")
|
| 961 |
+
im.save(args.out)
|