phi35-moe-demo / preinstall.py
ianshank's picture
🚀 Final fix v20250913_220639: Comprehensive solution for dependency and configuration issues
3eeba36 verified
#!/usr/bin/env python3
"""
Pre-installation script for Phi-3.5-MoE Space
Installs required dependencies and selects CPU-safe model revision if needed
"""
import os
import sys
import subprocess
import torch
import re
from pathlib import Path
from huggingface_hub import HfApi
def install_dependencies():
    """Install the extra packages the model code needs, using this interpreter's pip.

    ``einops>=0.7.0`` is always installed. A failed install raises
    ``subprocess.CalledProcessError`` (from ``check_call``), which aborts the
    caller.

    ``flash-attn>=2.6.0`` is only attempted when ``torch.cuda.is_available()``
    is true. A failed flash-attn install is reported and ignored, so setup
    continues without it.
    """
    # The leading characters of this message look like a mis-encoded emoji and
    # include a literal '"', so the literal is single-quoted to remain valid syntax.
    print('πŸ"§ Installing required dependencies...')
    # Always install einops
    subprocess.check_call([sys.executable, "-m", "pip", "install", "einops>=0.7.0"])
    print("βœ… Installed einops")
    # Install flash-attn only if CUDA is available
    if torch.cuda.is_available():
        try:
            # --no-build-isolation makes pip build against packages already
            # installed in this environment (e.g. torch) instead of a fresh
            # isolated build environment.
            subprocess.check_call(
                [sys.executable, "-m", "pip", "install", "flash-attn>=2.6.0", "--no-build-isolation"]
            )
            print("βœ… Installed flash-attn for GPU runtime")
        except subprocess.CalledProcessError:
            # Best effort: the model can still be loaded without flash-attn.
            print("⚠️ Failed to install flash-attn, continuing without it")
    else:
        print("ℹ️ CPU runtime detected: skipping flash-attn installation")
def select_cpu_safe_revision():
    """Pin ``HF_REVISION`` to a model revision whose modeling code avoids flash_attn imports.

    Does nothing when CUDA is available or when ``HF_REVISION`` is already set
    to a non-empty value.

    Otherwise it walks the commits of ``HF_MODEL_ID`` (default
    ``microsoft/Phi-3.5-MoE-instruct``) in the order
    ``HfApi.list_repo_commits`` yields them. At each commit it downloads
    ``modeling_phimoe.py`` and picks the first commit whose source has no line
    starting with ``import flash_attn`` or ``from flash_attn`` (leading
    whitespace allowed). The chosen SHA is set in ``os.environ`` and persisted
    to ``.env``, replacing any earlier ``HF_REVISION`` entry there.

    Hub and file errors are reported with ``print`` rather than raised: this is
    a best-effort setup step.
    """
    if torch.cuda.is_available() or os.getenv("HF_REVISION"):
        return
    MODEL_ID = os.getenv("HF_MODEL_ID", "microsoft/Phi-3.5-MoE-instruct")
    TARGET_FILE = "modeling_phimoe.py"
    ENV_FILE = ".env"
    # Compiled once rather than per commit. With re.M, ``^\s*`` also matches
    # indented lines, so imports nested inside functions or if/try blocks
    # disqualify a revision too (stricter than "top-level only"). There is no
    # word boundary, so e.g. ``import flash_attn_xyz`` also matches.
    flash_attn_import = re.compile(r'^\s*import\s+flash_attn|^\s*from\s+flash_attn', flags=re.M)
    # The leading characters look like a mis-encoded emoji containing a literal
    # '"', so the literal is single-quoted to remain valid syntax.
    print(f'πŸ" Selecting CPU-safe revision for {MODEL_ID}...')
    try:
        api = HfApi()
        safe_sha = None
        for commit in api.list_repo_commits(MODEL_ID, repo_type="model"):
            sha = commit.commit_id
            try:
                file_path = api.hf_hub_download(MODEL_ID, TARGET_FILE, revision=sha, repo_type="model")
                with open(file_path, "r", encoding="utf-8") as f:
                    code = f.read()
            except Exception:
                # Best effort: the file may not exist at this revision or the
                # download may fail; move on to the next commit.
                continue
            if not flash_attn_import.search(code):
                safe_sha = sha
                break
        if safe_sha is None:
            print("⚠️ No CPU-safe revision found")
            return
        # Set it for this process first, so it takes effect even if
        # persisting to .env fails below.
        os.environ["HF_REVISION"] = safe_sha
        _write_env_var(ENV_FILE, "HF_REVISION", safe_sha)
        print(f"βœ… Selected CPU-safe revision: {safe_sha}")
    except Exception as e:
        print(f"⚠️ Error selecting CPU-safe revision: {e}")


def _write_env_var(env_file, key, value):
    """Set ``key=value`` in a dotenv-style file, replacing earlier assignments of *key*.

    Existing lines that assign *key* (optionally prefixed with ``export``) are
    dropped, so re-running the script does not accumulate duplicate entries.
    The file is rewritten with a trailing newline, so the new entry is never
    glued onto an unterminated last line. The file is created if it is missing.
    """
    path = Path(env_file)
    assignment = re.compile(rf"^\s*(?:export\s+)?{re.escape(key)}\s*=")
    kept = []
    if path.exists():
        kept = [
            line
            for line in path.read_text(encoding="utf-8").splitlines()
            if not assignment.match(line)
        ]
    kept.append(f"{key}={value}")
    path.write_text("\n".join(kept) + "\n", encoding="utf-8")
def create_model_patch(patch_file="model_patch.py"):
    """Write a helper module that patches transformers' dynamic-module import check.

    The generated module wraps ``transformers.dynamic_module_utils.check_imports``.
    When the original raises an ``ImportError`` whose message lists "packages
    that were not found in your environment", the wrapper does the following:

    - registers a placeholder ``MockModule`` in ``sys.modules`` for each listed name;
    - retries the check once.

    Importing the generated module applies the patch as a side effect.

    Args:
        patch_file: Destination path. It is overwritten if it already exists.
            Defaults to ``"model_patch.py"`` relative to the working directory.
    """
    # NOTE(review): nothing in this script imports the generated file;
    # presumably the app imports it before loading the model — confirm.
    # NOTE(review): MockModule answers every missing attribute with another
    # MockModule and defines no __call__. It only satisfies the import check:
    # code that actually calls into a mocked package will still fail, and
    # importing its submodules is likely to fail as well.
    # The content below is the literal source of the generated module, so its
    # indentation is significant. The second print uses single quotes because
    # its (mis-encoded) text contains a literal '"'.
    patch_content = """
# Monkey patch for transformers.dynamic_module_utils
import sys
import importlib
from importlib.abc import Loader
from importlib.machinery import ModuleSpec
from transformers.dynamic_module_utils import check_imports
# Create mock modules for missing dependencies
class MockModule:
    def __init__(self, name):
        self.__name__ = name
        self.__spec__ = ModuleSpec(name, None)
    def __getattr__(self, key):
        return MockModule(f"{self.__name__}.{key}")
# Override check_imports to handle missing dependencies
original_check_imports = check_imports
def patched_check_imports(resolved_module_file):
    try:
        return original_check_imports(resolved_module_file)
    except ImportError as e:
        # Extract missing modules
        import re
        missing = re.findall(r'packages that were not found in your environment: ([^.]+)', str(e))
        if missing:
            missing_modules = [m.strip() for m in missing[0].split(',')]
            print(f"⚠️ Missing dependencies: {', '.join(missing_modules)}")
            print('πŸ"§ Creating mock modules to continue loading...')
            # Create mock modules
            for module_name in missing_modules:
                if module_name not in sys.modules:
                    mock_module = MockModule(module_name)
                    sys.modules[module_name] = mock_module
                    print(f"βœ… Created mock for {module_name}")
            # Try again
            return original_check_imports(resolved_module_file)
        else:
            raise
# Apply the patch
from transformers import dynamic_module_utils
dynamic_module_utils.check_imports = patched_check_imports
print("βœ… Applied transformers patch for handling missing dependencies")
"""
    with open(patch_file, "w", encoding="utf-8") as f:
        f.write(patch_content)
    print(f"βœ… Created model patch file: {patch_file}")
if __name__ == "__main__":
    print("πŸš€ Running pre-installation script...")
    # Run the setup steps in order; their return values are not used.
    for _step in (install_dependencies, select_cpu_safe_revision, create_model_patch):
        _step()
    print("βœ… Pre-installation complete!")