Spaces:
Runtime error
Runtime error
Commit
Β·
ff2c5f2
1
Parent(s):
f412da9
Update gligen/ldm/models/diffusion/plms.py
Browse files
gligen/ldm/models/diffusion/plms.py
CHANGED
|
@@ -151,13 +151,14 @@ class PLMSSampler(object):
|
|
| 151 |
object_positions=object_positions, t = index1)*loss_scale
|
| 152 |
loss = loss1 + loss2
|
| 153 |
print('loss', loss, loss1, loss2)
|
| 154 |
-
hh = torch.autograd.backward(loss, retain_graph=True)
|
| 155 |
-
grad_cond =
|
|
|
|
| 156 |
x = x - grad_cond
|
| 157 |
x = x.detach()
|
| 158 |
iteration += 1
|
| 159 |
|
| 160 |
-
|
| 161 |
return x
|
| 162 |
|
| 163 |
def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
|
|
|
|
| 151 |
object_positions=object_positions, t = index1)*loss_scale
|
| 152 |
loss = loss1 + loss2
|
| 153 |
print('loss', loss, loss1, loss2)
|
| 154 |
+
# hh = torch.autograd.backward(loss, retain_graph=True)
|
| 155 |
+
grad_cond = torch.autograd.grad(loss.requires_grad_(True), [x])[0]
|
| 156 |
+
# grad_cond = x.grad
|
| 157 |
x = x - grad_cond
|
| 158 |
x = x.detach()
|
| 159 |
iteration += 1
|
| 160 |
|
| 161 |
+
|
| 162 |
return x
|
| 163 |
|
| 164 |
def update_loss_only_cross(self, input,index1, index, ts,type_loss='self_accross'):
|