cbensimon HF Staff commited on
Commit
8406266
·
1 Parent(s): 0729c7e
Files changed (1) hide show
  1. aoti.py +9 -4
aoti.py CHANGED
@@ -1,6 +1,8 @@
1
  """
2
  """
3
 
 
 
4
  import torch
5
  from huggingface_hub import hf_hub_download
6
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel
@@ -13,14 +15,17 @@ def _shallow_clone_module(module: torch.nn.Module) -> torch.nn.Module:
13
  clone.__dict__ = module.__dict__.copy()
14
  clone._parameters = module._parameters.copy()
15
  clone._buffers = module._buffers.copy()
16
- clone._modules = {k: shallow_clone_module(v) for k, v in module._modules.items()}
17
  return clone
18
 
19
 
20
  def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
21
- repeated_blocks = module._repeated_blocks
22
- subfolder = name if variant is None else f'{name}.{variant}'
23
- aoti_files = {name: hf_hub_download(repo_id, 'package.pt2', subfolder=subfolder) for name in repeated_blocks}
 
 
 
24
  for block_name, aoti_file in aoti_files.items():
25
  for block in module.modules():
26
  if block.__class__.__name__ == block_name:
 
1
  """
2
  """
3
 
4
+ from typing import cast
5
+
6
  import torch
7
  from huggingface_hub import hf_hub_download
8
  from spaces.zero.torch.aoti import ZeroGPUCompiledModel
 
15
  clone.__dict__ = module.__dict__.copy()
16
  clone._parameters = module._parameters.copy()
17
  clone._buffers = module._buffers.copy()
18
+ clone._modules = {k: _shallow_clone_module(v) for k, v in module._modules.items() if v is not None}
19
  return clone
20
 
21
 
22
  def aoti_blocks_load(module: torch.nn.Module, repo_id: str, variant: str | None = None):
23
+ repeated_blocks = cast(list[str], module._repeated_blocks)
24
+ aoti_files = {name: hf_hub_download(
25
+ repo_id=repo_id,
26
+ filename='package.pt2',
27
+ subfolder=name if variant is None else f'{name}.{variant}',
28
+ ) for name in repeated_blocks}
29
  for block_name, aoti_file in aoti_files.items():
30
  for block in module.modules():
31
  if block.__class__.__name__ == block_name: