Spaces:
Paused
Paused
Update hymm_sp/modules/models_audio.py
Browse files- hymm_sp/modules/models_audio.py +22 -21
hymm_sp/modules/models_audio.py
CHANGED
|
@@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|
| 7 |
from diffusers.models import ModelMixin
|
| 8 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
|
| 10 |
-
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
| 11 |
|
| 12 |
|
| 13 |
|
|
@@ -173,29 +173,30 @@ class DoubleStreamBlock(nn.Module):
|
|
| 173 |
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
| 174 |
for x in [q, k, v]
|
| 175 |
]
|
|
|
|
| 176 |
|
| 177 |
-
attn = flash_attn_varlen_func(
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
)
|
| 186 |
attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
|
| 187 |
else:
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
)
|
| 199 |
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
|
| 200 |
|
| 201 |
if CPU_OFFLOAD: torch.cuda.empty_cache()
|
|
|
|
| 7 |
from diffusers.models import ModelMixin
|
| 8 |
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
|
| 10 |
+
#from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
| 11 |
|
| 12 |
|
| 13 |
|
|
|
|
| 173 |
x.view(x.shape[0] * x.shape[1], *x.shape[2:])
|
| 174 |
for x in [q, k, v]
|
| 175 |
]
|
| 176 |
+
attn = None
|
| 177 |
|
| 178 |
+
# attn = flash_attn_varlen_func(
|
| 179 |
+
# q,
|
| 180 |
+
# k,
|
| 181 |
+
# v,
|
| 182 |
+
# cu_seqlens_q,
|
| 183 |
+
# cu_seqlens_kv,
|
| 184 |
+
# max_seqlen_q,
|
| 185 |
+
# max_seqlen_kv,
|
| 186 |
+
# )
|
| 187 |
attn = attn.view(img_k.shape[0], max_seqlen_q, -1).contiguous()
|
| 188 |
else:
|
| 189 |
+
# attn, _ = parallel_attention(
|
| 190 |
+
# (img_q, txt_q),
|
| 191 |
+
# (img_k, txt_k),
|
| 192 |
+
# (img_v, txt_v),
|
| 193 |
+
# img_q_len=img_q.shape[1],
|
| 194 |
+
# img_kv_len=img_k.shape[1],
|
| 195 |
+
# cu_seqlens_q=cu_seqlens_q,
|
| 196 |
+
# cu_seqlens_kv=cu_seqlens_kv,
|
| 197 |
+
# max_seqlen_q=max_seqlen_q,
|
| 198 |
+
# max_seqlen_kv=max_seqlen_kv,
|
| 199 |
+
# )
|
| 200 |
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
|
| 201 |
|
| 202 |
if CPU_OFFLOAD: torch.cuda.empty_cache()
|