import gradio as gr
import os
import tempfile
import shutil
import re
import json
import datetime
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
import traceback
import math
try:
from modelscope.hub.file_download import model_file_download as ms_file_download
from modelscope.hub.api import HubApi as ModelScopeApi
    MODELSCOPE_AVAILABLE = True
except ImportError:
    MODELSCOPE_AVAILABLE = False
def get_fp8_dtype(fp8_format):
"""Get torch FP8 dtype."""
if fp8_format == "e5m2":
return torch.float8_e5m2
else:
return torch.float8_e4m3fn
def quantize_and_get_error(weight, fp8_dtype):
"""Quantize weight to FP8 and return both quantized weight and error."""
weight_fp8 = weight.to(fp8_dtype)
weight_dequantized = weight_fp8.to(weight.dtype)
error = weight - weight_dequantized
return weight_fp8, error
def low_rank_decomposition_error(error_tensor, rank=32, min_error_threshold=1e-6):
"""Decompose error tensor with proper rank reduction."""
if error_tensor.ndim not in [2, 4]:
return None, None
try:
# Calculate error magnitude
error_norm = torch.norm(error_tensor.float())
if error_norm < min_error_threshold:
return None, None
# For 2D tensors (linear layers)
if error_tensor.ndim == 2:
U, S, Vh = torch.linalg.svd(error_tensor.float(), full_matrices=False)
# Calculate rank based on variance explained (keep 95% of error)
total_variance = torch.sum(S ** 2)
cumulative = torch.cumsum(S ** 2, dim=0)
keep_components = torch.sum(cumulative <= 0.95 * total_variance).item() + 1
            # Cap the rank well below the matrix's full rank
max_rank = min(error_tensor.shape)
actual_rank = min(rank, keep_components, max_rank // 2)
if actual_rank < 2:
return None, None
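            # Factor the residual as error ≈ B @ A: B = U_r @ diag(S_r) has shape (out, rank),
            # A = Vh_r has shape (rank, in), i.e. the usual LoRA up/down projection pair.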
A = Vh[:actual_rank, :].contiguous()
            B = (U[:, :actual_rank] @ torch.diag(S[:actual_rank])).contiguous()
return A, B
# For 4D convolutions
elif error_tensor.ndim == 4:
out_ch, in_ch, kH, kW = error_tensor.shape
# Reshape to 2D for decomposition
error_2d = error_tensor.view(out_ch, in_ch * kH * kW)
U, S, Vh = torch.linalg.svd(error_2d.float(), full_matrices=False)
# Calculate rank based on variance explained (90% for conv)
total_variance = torch.sum(S ** 2)
cumulative = torch.cumsum(S ** 2, dim=0)
keep_components = torch.sum(cumulative <= 0.90 * total_variance).item() + 1
# Use even lower rank for conv
max_rank = min(error_2d.shape)
actual_rank = min(rank // 2, keep_components, max_rank // 4)
if actual_rank < 2:
return None, None
A = Vh[:actual_rank, :].contiguous()
            B = (U[:, :actual_rank] @ torch.diag(S[:actual_rank])).contiguous()
            # Reshape back for convolutional format (1x1 kernels are just the kH = kW = 1 case)
            B = B.view(out_ch, actual_rank, 1, 1)
            A = A.view(actual_rank, in_ch, kH, kW)
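            # Applied at inference, A acts as a kHxkW conv (rank x in) and B as a 1x1 conv
            # (out x rank); stacked on the residual path (assuming the original stride and
            # padding on the A conv) they approximate the quantization error of the full conv.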
return A, B
except Exception as e:
print(f"Error decomposition failed: {e}")
return None, None
def extract_correction_factors(original_weight, fp8_weight):
"""Extract simple correction factors for VAE."""
with torch.no_grad():
orig = original_weight.float()
quant = fp8_weight.float()
error = orig - quant
error_norm = torch.norm(error)
orig_norm = torch.norm(orig)
if orig_norm > 1e-6 and error_norm / orig_norm < 0.001:
return None
# For 4D tensors (VAE), compute per-channel correction
if orig.ndim == 4:
channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
return channel_mean.to(original_weight.dtype)
elif orig.ndim == 2:
row_mean = error.mean(dim=1, keepdim=True)
return row_mean.to(original_weight.dtype)
else:
return error.mean().to(original_weight.dtype)
def get_architecture_settings(architecture, base_rank):
"""Get optimal settings for different architectures."""
settings = {
"text_encoder": {
"rank": base_rank,
"error_threshold": 5e-5,
"min_rank": 8,
"max_rank_factor": 0.4,
"method": "lora"
},
"transformer": {
"rank": base_rank,
"error_threshold": 1e-5,
"min_rank": 12,
"max_rank_factor": 0.35,
"method": "lora"
},
"vae": {
"rank": base_rank // 2,
"error_threshold": 1e-4,
"min_rank": 4,
"max_rank_factor": 0.3,
"method": "correction"
},
"unet_conv": {
"rank": base_rank // 3,
"error_threshold": 2e-5,
"min_rank": 8,
"max_rank_factor": 0.25,
"method": "lora"
},
"auto": {
"rank": base_rank,
"error_threshold": 1e-5,
"min_rank": 8,
"max_rank_factor": 0.3,
"method": "lora"
},
"all": {
"rank": base_rank,
"error_threshold": 1e-5,
"min_rank": 8,
"max_rank_factor": 0.3,
"method": "lora"
}
}
return settings.get(architecture, settings["auto"])
def should_process_layer(key, weight, architecture):
"""Determine if layer should be processed for LoRA/correction."""
lower_key = key.lower()
# Skip biases and normalization layers
    if 'bias' in lower_key or 'norm' in lower_key or 'bn' in lower_key:
return False
if weight.numel() < 100:
return False
# Architecture-specific filtering
if architecture == "text_encoder":
return ('text' in lower_key or 'emb' in lower_key or
'encoder' in lower_key or 'attn' in lower_key)
elif architecture == "transformer":
return ('attn' in lower_key or 'transformer' in lower_key or
'mlp' in lower_key or 'to_out' in lower_key)
elif architecture == "vae":
return ('vae' in lower_key or 'encoder' in lower_key or
'decoder' in lower_key or 'conv' in lower_key)
elif architecture == "unet_conv":
return ('conv' in lower_key or 'resnet' in lower_key or
'downsample' in lower_key or 'upsample' in lower_key)
elif architecture in ["all", "auto"]:
return True
return False
def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format, lora_rank=128, architecture="auto", progress=gr.Progress()):
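    """Quantize a safetensors checkpoint to FP8 and extract precision-recovery weights.

    Depending on the architecture, the per-layer quantization error is either factorized
    into low-rank LoRA pairs or reduced to per-channel correction factors.
    """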
progress(0.1, desc="Starting FP8 conversion with error recovery...")
try:
def read_safetensors_metadata(path):
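            # safetensors layout: an 8-byte little-endian header length, then a JSON header
            # whose optional "__metadata__" entry carries user-defined key/value strings.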
with open(path, 'rb') as f:
header_size = int.from_bytes(f.read(8), 'little')
header_json = f.read(header_size).decode('utf-8')
header = json.loads(header_json)
return header.get('__metadata__', {})
metadata = read_safetensors_metadata(safetensors_path)
progress(0.2, desc="Loaded metadata.")
state_dict = load_file(safetensors_path)
progress(0.4, desc="Loaded weights.")
# Auto-detect architecture if needed
if architecture == "auto":
model_keys = " ".join(state_dict.keys()).lower()
if "vae" in model_keys or ("encoder" in model_keys and "decoder" in model_keys):
architecture = "vae"
elif "text" in model_keys or "emb" in model_keys:
architecture = "text_encoder"
elif "attn" in model_keys or "transformer" in model_keys:
architecture = "transformer"
elif "conv" in model_keys or "resnet" in model_keys:
architecture = "unet_conv"
else:
architecture = "all"
settings = get_architecture_settings(architecture, lora_rank)
fp8_dtype = get_fp8_dtype(fp8_format)
sd_fp8 = {}
lora_weights = {}
correction_factors = {}
stats = {
"total_layers": len(state_dict),
"eligible_layers": 0,
"layers_with_error": 0,
"processed_layers": 0,
"correction_layers": 0,
"skipped_layers": [],
"architecture": architecture,
"method": settings["method"],
"error_magnitudes": []
}
total = len(state_dict)
for i, key in enumerate(state_dict):
progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
weight = state_dict[key]
if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
# Quantize to FP8 and calculate error
weight_fp8, error = quantize_and_get_error(weight, fp8_dtype)
sd_fp8[key] = weight_fp8
# Calculate error magnitude
error_norm = torch.norm(error.float())
weight_norm = torch.norm(weight.float())
relative_error = (error_norm / weight_norm).item() if weight_norm > 0 else 0
stats["error_magnitudes"].append({
"key": key,
"relative_error": relative_error
})
# Check if layer should be processed
should_process = should_process_layer(key, weight, architecture)
if should_process:
stats["eligible_layers"] += 1
# Only process if error is significant
if relative_error > settings["error_threshold"]:
stats["layers_with_error"] += 1
if settings["method"] == "correction":
# Use correction factors for VAE
correction = extract_correction_factors(weight, weight_fp8)
if correction is not None:
correction_factors[f"correction.{key}"] = correction
stats["correction_layers"] += 1
stats["processed_layers"] += 1
else:
# Use LoRA decomposition for other architectures
try:
A, B = low_rank_decomposition_error(
error,
rank=settings["rank"],
min_error_threshold=settings["error_threshold"]
)
if A is not None and B is not None:
lora_weights[f"lora_A.{key}"] = A.to(torch.float16)
lora_weights[f"lora_B.{key}"] = B.to(torch.float16)
stats["processed_layers"] += 1
else:
stats["skipped_layers"].append(f"{key}: decomposition failed")
except Exception as e:
stats["skipped_layers"].append(f"{key}: error - {str(e)}")
else:
stats["skipped_layers"].append(f"{key}: error too small ({relative_error:.6f})")
else:
sd_fp8[key] = weight
stats["skipped_layers"].append(f"{key}: non-float dtype")
# Calculate average error
if stats["error_magnitudes"]:
errors = [e["relative_error"] for e in stats["error_magnitudes"]]
stats["avg_error"] = sum(errors) / len(errors) if errors else 0
stats["max_error"] = max(errors) if errors else 0
base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})
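        # Storing float8 tensors in safetensors requires reasonably recent torch/safetensors
        # releases; older versions may reject these dtypes.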
# Save precision recovery weights
if lora_weights:
lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
lora_metadata = {
"format": "pt",
"lora_rank": str(lora_rank),
"architecture": architecture,
"stats": json.dumps(stats),
"method": "lora"
}
save_file(lora_weights, lora_path, metadata=lora_metadata)
if correction_factors:
correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
correction_metadata = {
"format": "pt",
"architecture": architecture,
"stats": json.dumps(stats),
"method": "correction"
}
save_file(correction_factors, correction_path, metadata=correction_metadata)
progress(0.9, desc="Saved FP8 and precision recovery files.")
progress(1.0, desc="✅ FP8 + precision recovery extraction complete!")
stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
stats_msg += f"Architecture: {architecture}\n"
stats_msg += f"Method: {settings['method']}\n"
stats_msg += f"Average quantization error: {stats.get('avg_error', 0):.6f}\n"
if settings["method"] == "correction":
stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
else:
stats_msg += f"LoRA generated for {stats['processed_layers']}/{stats['eligible_layers']} eligible layers (rank {lora_rank})."
if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
stats_msg += "\n⚠️ No precision recovery weights were generated. FP8 quantization error may be too small."
return True, stats_msg, stats
except Exception as e:
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
return False, error_msg, None
def parse_hf_url(url):
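    """Split a Hugging Face model URL into (repo_id, subfolder).

    Example: https://huggingface.co/org/model/tree/main/text_encoder -> ("org/model", "text_encoder")
    """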
url = url.strip().rstrip("/")
if not url.startswith("https://huggingface.co/"):
raise ValueError("URL must start with https://huggingface.co/")
path = url.replace("https://huggingface.co/", "")
parts = path.split("/")
if len(parts) < 2:
raise ValueError("Invalid repo format")
repo_id = "/".join(parts[:2])
subfolder = ""
if len(parts) > 3 and parts[2] == "tree":
subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
elif len(parts) > 2:
subfolder = "/".join(parts[2:])
return repo_id, subfolder
def download_safetensors_file(source_type, repo_url, filename, hf_token=None, progress=gr.Progress()):
temp_dir = tempfile.mkdtemp()
try:
if source_type == "huggingface":
repo_id, subfolder = parse_hf_url(repo_url)
safetensors_path = hf_hub_download(
repo_id=repo_id,
filename=filename,
subfolder=subfolder or None,
cache_dir=temp_dir,
                token=hf_token
)
elif source_type == "modelscope":
            if not MODELSCOPE_AVAILABLE:
raise ImportError("ModelScope not installed")
repo_id = repo_url.strip()
safetensors_path = ms_file_download(model_id=repo_id, file_path=filename)
else:
raise ValueError("Unknown source")
return safetensors_path, temp_dir
except Exception as e:
shutil.rmtree(temp_dir, ignore_errors=True)
raise e
def upload_to_target(target_type, new_repo_id, output_dir, fp8_format, hf_token=None, modelscope_token=None, private_repo=False):
if target_type == "huggingface":
api = HfApi(token=hf_token)
api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
return f"https://huggingface.co/{new_repo_id}"
elif target_type == "modelscope":
api = ModelScopeApi()
if modelscope_token:
api.login(modelscope_token)
api.push_model(model_id=new_repo_id, model_dir=output_dir)
return f"https://modelscope.cn/models/{new_repo_id}"
else:
raise ValueError("Unknown target")
def process_and_upload_fp8(
source_type,
repo_url,
safetensors_filename,
fp8_format,
lora_rank,
architecture,
target_type,
new_repo_id,
hf_token,
modelscope_token,
private_repo,
progress=gr.Progress()
):
if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
return None, "❌ Invalid repo ID format. Use 'username/model-name'.", ""
if source_type == "huggingface" and not hf_token:
return None, "❌ Hugging Face token required for source.", ""
if target_type == "huggingface" and not hf_token:
return None, "❌ Hugging Face token required for target.", ""
    # Gradio sliders may deliver floats; downstream tensor slicing needs an integer rank
    lora_rank = int(lora_rank)
    if lora_rank < 8:
        return None, "❌ LoRA rank must be at least 8.", ""
temp_dir = None
output_dir = tempfile.mkdtemp()
try:
progress(0.05, desc="Downloading model...")
safetensors_path, temp_dir = download_safetensors_file(
source_type, repo_url, safetensors_filename, hf_token, progress
)
progress(0.25, desc="Converting to FP8 with precision recovery...")
success, msg, stats = convert_safetensors_to_fp8_with_lora(
safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
)
if not success:
return None, f"❌ Conversion failed: {msg}", ""
progress(0.9, desc="Uploading...")
repo_url_final = upload_to_target(
target_type, new_repo_id, output_dir, fp8_format, hf_token, modelscope_token, private_repo
)
base_name = os.path.splitext(safetensors_filename)[0]
fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"
# Determine which precision recovery file was generated
precision_recovery_file = ""
precision_recovery_type = ""
if stats.get("method") == "correction" and stats.get("correction_layers", 0) > 0:
precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
precision_recovery_type = "Correction Factors"
elif stats.get("method") == "lora" and stats.get("processed_layers", 0) > 0:
precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
precision_recovery_type = "LoRA"
readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- precision-recovery
- diffusion
- converted-by-gradio
---
# FP8 Model with Precision Recovery
- **Source**: `{repo_url}`
- **File**: `{safetensors_filename}`
- **FP8 Format**: `{fp8_format.upper()}`
- **Architecture**: {architecture}
- **Precision Recovery Type**: {precision_recovery_type}
- **Precision Recovery File**: `{precision_recovery_file or 'none generated'}`
- **FP8 File**: `{fp8_filename}`
## Usage (Inference)
```python
from safetensors.torch import load_file
import torch
# Load FP8 model
fp8_state = load_file("{fp8_filename}")
# Load precision recovery file if available
recovery_state = {{}}
if "{precision_recovery_file}":
recovery_state = load_file("{precision_recovery_file}")
# Reconstruct high-precision weights
reconstructed = {{}}
for key in fp8_state:
# Dequantize FP8 to target precision
fp_weight = fp8_state[key].to(torch.float32)
if recovery_state:
# For LoRA approach
if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
            # Conv LoRA factors are stored in 4D layout; flatten to 2D before multiplying
            error_correction = (B.flatten(1) @ A.flatten(1)).view_as(fp_weight)
reconstructed[key] = fp_weight + error_correction
# For correction factor approach
elif f"correction.{{key}}" in recovery_state:
correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
reconstructed[key] = fp_weight + correction
else:
reconstructed[key] = fp_weight
else:
reconstructed[key] = fp_weight
print("Model reconstructed with FP8 error recovery")
```
> **Note**: This precision recovery targets FP8 quantization errors.
> Average quantization error: {stats.get('avg_error', 0):.6f}
"""
with open(os.path.join(output_dir, "README.md"), "w") as f:
f.write(readme)
if target_type == "huggingface":
HfApi(token=hf_token).upload_file(
path_or_fileobj=os.path.join(output_dir, "README.md"),
path_in_repo="README.md",
repo_id=new_repo_id,
repo_type="model",
token=hf_token
)
progress(1.0, desc="✅ Done!")
result_html = f"""
✅ Success!
Model uploaded to: {new_repo_id}
Includes: FP8 model + precision recovery ({precision_recovery_type}).
Average quantization error: {stats.get('avg_error', 0):.6f}
"""
if stats['processed_layers'] > 0 or stats['correction_layers'] > 0:
result_html += f"
Precision recovery applied to {stats['processed_layers'] + stats['correction_layers']} layers."
return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg
except Exception as e:
error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
return None, error_msg, ""
finally:
if temp_dir:
shutil.rmtree(temp_dir, ignore_errors=True)
shutil.rmtree(output_dir, ignore_errors=True)
with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
gr.Markdown("# 🔄 FP8 Converter with Architecture-Specific Precision Recovery")
gr.Markdown("Convert models to **FP8** with **error-based precision recovery**.")
with gr.Row():
with gr.Column():
source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")
with gr.Accordion("Advanced Settings", open=True):
fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128,
label="LoRA Rank (for text/transformers)")
architecture = gr.Dropdown(
choices=[
("Auto-detect architecture", "auto"),
("Text Encoder (LoRA)", "text_encoder"),
("Transformer blocks (LoRA)", "transformer"),
("VAE (Correction Factors)", "vae"),
("UNet Convolutions (LoRA)", "unet_conv"),
("All layers (LoRA where applicable)", "all")
],
value="auto",
label="Target Architecture"
)
with gr.Accordion("Authentication", open=False):
hf_token = gr.Textbox(label="Hugging Face Token", type="password")
                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password",
                                              visible=MODELSCOPE_AVAILABLE)
with gr.Column():
target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-precision")
private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
status_output = gr.Markdown()
detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
repo_link_output = gr.HTML()
convert_btn.click(
fn=process_and_upload_fp8,
inputs=[
source_type,
repo_url,
safetensors_filename,
fp8_format,
lora_rank,
architecture,
target_type,
new_repo_id,
hf_token,
modelscope_token,
private_repo
],
outputs=[repo_link_output, status_output, detailed_log],
show_progress=True
)
gr.Examples(
examples=[
["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
"model.safetensors", "e5m2", 96, "text_encoder"],
["huggingface", "https://huggingface.co/stabilityai/sdxl-vae",
"diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae"],
["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main",
"unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer"]
],
inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
label="Example Conversions"
)
gr.Markdown("""
## 🎯 What This Tool Does
Unlike traditional LoRA fine-tuning, this tool:
1. **Quantizes** the model to FP8 (loses precision)
2. **Measures** the quantization error for each weight
3. **Extracts recovery weights** that specifically recover this error
4. **Only applies** recovery where the error is significant (above the per-architecture threshold, e.g. >0.001% relative); see the sketch below
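
At load time the recovery file is simply added back onto the dequantized FP8 weights. A minimal sketch (filenames are illustrative; see the generated README for the full reconstruction loop):

```python
import torch
from safetensors.torch import load_file

fp8 = load_file("model-fp8-e5m2.safetensors")                # quantized weights
rec = load_file("model-lora-r128-transformer.safetensors")   # recovery file, if one was generated

key, w = next(iter(fp8.items()))
w = w.to(torch.float32)
if f"lora_A.{key}" in rec:                                   # LoRA-style recovery
    B = rec[f"lora_B.{key}"].float().flatten(1)
    A = rec[f"lora_A.{key}"].float().flatten(1)
    w = w + (B @ A).view_as(w)
elif f"correction.{key}" in rec:                             # per-channel correction (VAE)
    w = w + rec[f"correction.{key}"].float()
```
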
## 💡 Recommended Settings
- **Text Encoders**: rank 64-96 (text is sensitive)
- **Transformers**: rank 96-128
- **VAE**: Uses correction factors (no rank needed)
- **UNet Convolutions**: rank 32-64
## ⚠️ Important Notes
- This recovers **FP8 quantization errors**, not fine-tuning changes
- If a layer's relative FP8 error stays below its architecture threshold (roughly 0.001%-0.01%), no recovery is generated for that layer
- Higher rank ≠ better for error recovery (use recommended ranges)
""")
demo.launch()