import gradio as gr
import os
import tempfile
import shutil
import re
import json
import datetime
from pathlib import Path
from huggingface_hub import HfApi, hf_hub_download
from safetensors.torch import load_file, save_file
import torch
import torch.nn.functional as F
import traceback
import math

try:
    from modelscope.hub.file_download import model_file_download as ms_file_download
    from modelscope.hub.api import HubApi as ModelScopeApi
    MODELSCOPE_AVAILABLE = True
except ImportError:
    MODELSCOPE_AVAILABLE = False


def get_fp8_dtype(fp8_format):
    """Get torch FP8 dtype."""
    if fp8_format == "e5m2":
        return torch.float8_e5m2
    else:
        return torch.float8_e4m3fn


def quantize_and_get_error(weight, fp8_dtype):
    """Quantize weight to FP8 and return both quantized weight and error."""
    weight_fp8 = weight.to(fp8_dtype)
    weight_dequantized = weight_fp8.to(weight.dtype)
    error = weight - weight_dequantized
    return weight_fp8, error


def low_rank_decomposition_error(error_tensor, rank=32, min_error_threshold=1e-6):
    """Decompose error tensor with proper rank reduction."""
    if error_tensor.ndim not in [2, 4]:
        return None, None

    try:
        # Calculate error magnitude
        error_norm = torch.norm(error_tensor.float())
        if error_norm < min_error_threshold:
            return None, None

        # For 2D tensors (linear layers)
        if error_tensor.ndim == 2:
            U, S, Vh = torch.linalg.svd(error_tensor.float(), full_matrices=False)

            # Calculate rank based on variance explained (keep 95% of error)
            total_variance = torch.sum(S ** 2)
            cumulative = torch.cumsum(S ** 2, dim=0)
            keep_components = torch.sum(cumulative <= 0.95 * total_variance).item() + 1

            # Limit rank to much smaller than original
            max_rank = min(error_tensor.shape)
            actual_rank = min(rank, keep_components, max_rank // 2)

            if actual_rank < 2:
                return None, None

            A = Vh[:actual_rank, :].contiguous()
            B = (U[:, :actual_rank] @ torch.diag(S[:actual_rank])).contiguous()

            return A, B

        # For 4D convolutions
        elif error_tensor.ndim == 4:
            out_ch, in_ch, kH, kW = error_tensor.shape

            # Reshape to 2D for decomposition
            error_2d = error_tensor.view(out_ch, in_ch * kH * kW)
            U, S, Vh = torch.linalg.svd(error_2d.float(), full_matrices=False)

            # Calculate rank based on variance explained (90% for conv)
            total_variance = torch.sum(S ** 2)
            cumulative = torch.cumsum(S ** 2, dim=0)
            keep_components = torch.sum(cumulative <= 0.90 * total_variance).item() + 1

            # Use even lower rank for conv
            max_rank = min(error_2d.shape)
            actual_rank = min(rank // 2, keep_components, max_rank // 4)

            if actual_rank < 2:
                return None, None

            A = Vh[:actual_rank, :].contiguous()
            B = (U[:, :actual_rank] @ torch.diag(S[:actual_rank])).contiguous()

            # Reshape back for convolutional format
            if kH == 1 and kW == 1:
                B = B.view(out_ch, actual_rank, 1, 1)
                A = A.view(actual_rank, in_ch, 1, 1)
            else:
                B = B.view(out_ch, actual_rank, 1, 1)
                A = A.view(actual_rank, in_ch, kH, kW)

            return A, B
    except Exception as e:
        print(f"Error decomposition failed: {e}")
        return None, None


def extract_correction_factors(original_weight, fp8_weight):
    """Extract simple correction factors for VAE."""
    with torch.no_grad():
        orig = original_weight.float()
        quant = fp8_weight.float()
        error = orig - quant

        error_norm = torch.norm(error)
        orig_norm = torch.norm(orig)
        if orig_norm > 1e-6 and error_norm / orig_norm < 0.001:
            return None

        # For 4D tensors (VAE), compute per-channel correction
        if orig.ndim == 4:
            channel_mean = error.mean(dim=tuple(i for i in range(1, orig.ndim)), keepdim=True)
            return channel_mean.to(original_weight.dtype)
        elif orig.ndim == 2:
            row_mean = error.mean(dim=1, keepdim=True)
            return row_mean.to(original_weight.dtype)
        else:
            return error.mean().to(original_weight.dtype)
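
# Illustrative sanity check (a minimal sketch, not called anywhere by the app;
# the tensor shape and rank are arbitrary example values). For a 2D weight,
# the factors returned by low_rank_decomposition_error are a truncated-SVD
# approximation of the FP8 quantization error, so adding B @ A back onto the
# dequantized weight should shrink the residual compared to plain FP8.
def _example_error_recovery_check(fp8_format="e5m2", rank=32):
    w = torch.randn(512, 512, dtype=torch.float32)
    w_fp8, err = quantize_and_get_error(w, get_fp8_dtype(fp8_format))
    A, B = low_rank_decomposition_error(err, rank=rank)
    if A is None:
        return None
    recovered = w_fp8.to(w.dtype) + B @ A
    # Residual after recovery should be no larger than the raw FP8 error.
    return torch.norm(w - recovered).item(), torch.norm(err).item()
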
def get_architecture_settings(architecture, base_rank):
    """Get optimal settings for different architectures."""
    settings = {
        "text_encoder": {
            "rank": base_rank,
            "error_threshold": 5e-5,
            "min_rank": 8,
            "max_rank_factor": 0.4,
            "method": "lora"
        },
        "transformer": {
            "rank": base_rank,
            "error_threshold": 1e-5,
            "min_rank": 12,
            "max_rank_factor": 0.35,
            "method": "lora"
        },
        "vae": {
            "rank": base_rank // 2,
            "error_threshold": 1e-4,
            "min_rank": 4,
            "max_rank_factor": 0.3,
            "method": "correction"
        },
        "unet_conv": {
            "rank": base_rank // 3,
            "error_threshold": 2e-5,
            "min_rank": 8,
            "max_rank_factor": 0.25,
            "method": "lora"
        },
        "auto": {
            "rank": base_rank,
            "error_threshold": 1e-5,
            "min_rank": 8,
            "max_rank_factor": 0.3,
            "method": "lora"
        },
        "all": {
            "rank": base_rank,
            "error_threshold": 1e-5,
            "min_rank": 8,
            "max_rank_factor": 0.3,
            "method": "lora"
        }
    }
    return settings.get(architecture, settings["auto"])


def should_process_layer(key, weight, architecture):
    """Determine if layer should be processed for LoRA/correction."""
    lower_key = key.lower()

    # Skip biases and normalization layers
    if 'bias' in key or 'norm' in lower_key or 'bn' in lower_key:
        return False
    if weight.numel() < 100:
        return False

    # Architecture-specific filtering
    if architecture == "text_encoder":
        return ('text' in lower_key or 'emb' in lower_key or
                'encoder' in lower_key or 'attn' in lower_key)
    elif architecture == "transformer":
        return ('attn' in lower_key or 'transformer' in lower_key or
                'mlp' in lower_key or 'to_out' in lower_key)
    elif architecture == "vae":
        return ('vae' in lower_key or 'encoder' in lower_key or
                'decoder' in lower_key or 'conv' in lower_key)
    elif architecture == "unet_conv":
        return ('conv' in lower_key or 'resnet' in lower_key or
                'downsample' in lower_key or 'upsample' in lower_key)
    elif architecture in ["all", "auto"]:
        return True
    return False
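
# Example (illustrative key names, not from any specific checkpoint; assumes the
# placeholder tensors w, b, n have at least 100 elements). With
# architecture="transformer", attention/MLP weights pass the filter while biases
# and normalization layers are skipped:
#   should_process_layer("blocks.0.attn.to_q.weight", w, "transformer")  -> True
#   should_process_layer("blocks.0.attn.to_q.bias",   b, "transformer")  -> False  (bias)
#   should_process_layer("blocks.0.norm1.weight",     n, "transformer")  -> False  (norm layer)
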
def convert_safetensors_to_fp8_with_lora(safetensors_path, output_dir, fp8_format,
                                         lora_rank=128, architecture="auto",
                                         progress=gr.Progress()):
    progress(0.1, desc="Starting FP8 conversion with error recovery...")
    try:
        def read_safetensors_metadata(path):
            with open(path, 'rb') as f:
                header_size = int.from_bytes(f.read(8), 'little')
                header_json = f.read(header_size).decode('utf-8')
                header = json.loads(header_json)
            return header.get('__metadata__', {})

        metadata = read_safetensors_metadata(safetensors_path)
        progress(0.2, desc="Loaded metadata.")

        state_dict = load_file(safetensors_path)
        progress(0.4, desc="Loaded weights.")

        # Auto-detect architecture if needed
        if architecture == "auto":
            model_keys = " ".join(state_dict.keys()).lower()
            if "vae" in model_keys or ("encoder" in model_keys and "decoder" in model_keys):
                architecture = "vae"
            elif "text" in model_keys or "emb" in model_keys:
                architecture = "text_encoder"
            elif "attn" in model_keys or "transformer" in model_keys:
                architecture = "transformer"
            elif "conv" in model_keys or "resnet" in model_keys:
                architecture = "unet_conv"
            else:
                architecture = "all"

        settings = get_architecture_settings(architecture, lora_rank)
        fp8_dtype = get_fp8_dtype(fp8_format)

        sd_fp8 = {}
        lora_weights = {}
        correction_factors = {}
        stats = {
            "total_layers": len(state_dict),
            "eligible_layers": 0,
            "layers_with_error": 0,
            "processed_layers": 0,
            "correction_layers": 0,
            "skipped_layers": [],
            "architecture": architecture,
            "method": settings["method"],
            "error_magnitudes": []
        }

        total = len(state_dict)
        for i, key in enumerate(state_dict):
            progress(0.4 + 0.4 * (i / total), desc=f"Processing {i+1}/{total}...")
            weight = state_dict[key]

            if weight.dtype in [torch.float16, torch.float32, torch.bfloat16]:
                # Quantize to FP8 and calculate error
                weight_fp8, error = quantize_and_get_error(weight, fp8_dtype)
                sd_fp8[key] = weight_fp8

                # Calculate error magnitude
                error_norm = torch.norm(error.float())
                weight_norm = torch.norm(weight.float())
                relative_error = (error_norm / weight_norm).item() if weight_norm > 0 else 0
                stats["error_magnitudes"].append({
                    "key": key,
                    "relative_error": relative_error
                })

                # Check if layer should be processed
                should_process = should_process_layer(key, weight, architecture)
                if should_process:
                    stats["eligible_layers"] += 1

                    # Only process if error is significant
                    if relative_error > settings["error_threshold"]:
                        stats["layers_with_error"] += 1

                        if settings["method"] == "correction":
                            # Use correction factors for VAE
                            correction = extract_correction_factors(weight, weight_fp8)
                            if correction is not None:
                                correction_factors[f"correction.{key}"] = correction
                                stats["correction_layers"] += 1
                                stats["processed_layers"] += 1
                        else:
                            # Use LoRA decomposition for other architectures
                            try:
                                A, B = low_rank_decomposition_error(
                                    error,
                                    rank=settings["rank"],
                                    min_error_threshold=settings["error_threshold"]
                                )
                                if A is not None and B is not None:
                                    lora_weights[f"lora_A.{key}"] = A.to(torch.float16)
                                    lora_weights[f"lora_B.{key}"] = B.to(torch.float16)
                                    stats["processed_layers"] += 1
                                else:
                                    stats["skipped_layers"].append(f"{key}: decomposition failed")
                            except Exception as e:
                                stats["skipped_layers"].append(f"{key}: error - {str(e)}")
                    else:
                        stats["skipped_layers"].append(f"{key}: error too small ({relative_error:.6f})")
            else:
                sd_fp8[key] = weight
                stats["skipped_layers"].append(f"{key}: non-float dtype")

        # Calculate average error
        if stats["error_magnitudes"]:
            errors = [e["relative_error"] for e in stats["error_magnitudes"]]
            stats["avg_error"] = sum(errors) / len(errors) if errors else 0
            stats["max_error"] = max(errors) if errors else 0

        base_name = os.path.splitext(os.path.basename(safetensors_path))[0]
        fp8_path = os.path.join(output_dir, f"{base_name}-fp8-{fp8_format}.safetensors")
        save_file(sd_fp8, fp8_path, metadata={"format": "pt", "fp8_format": fp8_format, **metadata})

        # Save precision recovery weights
        if lora_weights:
            lora_path = os.path.join(output_dir, f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors")
            lora_metadata = {
                "format": "pt",
                "lora_rank": str(lora_rank),
                "architecture": architecture,
                "stats": json.dumps(stats),
                "method": "lora"
            }
            save_file(lora_weights, lora_path, metadata=lora_metadata)

        if correction_factors:
            correction_path = os.path.join(output_dir, f"{base_name}-correction-{architecture}.safetensors")
            correction_metadata = {
                "format": "pt",
                "architecture": architecture,
                "stats": json.dumps(stats),
                "method": "correction"
            }
            save_file(correction_factors, correction_path, metadata=correction_metadata)

        progress(0.9, desc="Saved FP8 and precision recovery files.")
        progress(1.0, desc="✅ FP8 + precision recovery extraction complete!")

        stats_msg = f"FP8 ({fp8_format}) with precision recovery saved.\n"
        stats_msg += f"Architecture: {architecture}\n"
        stats_msg += f"Method: {settings['method']}\n"
        stats_msg += f"Average quantization error: {stats.get('avg_error', 0):.6f}\n"
        if settings["method"] == "correction":
            stats_msg += f"Correction factors generated for {stats['correction_layers']} layers."
        else:
            stats_msg += f"LoRA generated for {stats['processed_layers']}/{stats['eligible_layers']} eligible layers (rank {lora_rank})."
        if stats['processed_layers'] == 0 and stats['correction_layers'] == 0:
            stats_msg += "\n⚠️ No precision recovery weights were generated. FP8 quantization error may be too small."

        return True, stats_msg, stats
    except Exception as e:
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        return False, error_msg, None
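
# Optional helper (a minimal sketch, not used by the Gradio app): merge an FP8
# checkpoint produced by convert_safetensors_to_fp8_with_lora with its LoRA
# recovery file back into higher precision. The file paths are whatever the
# converter wrote; 4D (conv) factors are flattened before the matmul and the
# result is reshaped back to the weight's original shape.
def _merge_fp8_with_lora(fp8_path, lora_path, out_dtype=torch.float16):
    fp8_state = load_file(fp8_path)
    lora_state = load_file(lora_path)
    merged = {}
    for key, w in fp8_state.items():
        w = w.to(out_dtype)
        a_key, b_key = f"lora_A.{key}", f"lora_B.{key}"
        if a_key in lora_state and b_key in lora_state:
            A = lora_state[a_key].to(out_dtype)
            B = lora_state[b_key].to(out_dtype)
            # B @ A reconstructs the low-rank approximation of the FP8 error.
            delta = (B.flatten(1) @ A.flatten(1)).view(w.shape) if A.ndim == 4 else B @ A
            w = w + delta
        merged[key] = w
    return merged
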
def parse_hf_url(url):
    url = url.strip().rstrip("/")
    if not url.startswith("https://huggingface.co/"):
        raise ValueError("URL must start with https://huggingface.co/")
    path = url.replace("https://huggingface.co/", "")
    parts = path.split("/")
    if len(parts) < 2:
        raise ValueError("Invalid repo format")
    repo_id = "/".join(parts[:2])
    subfolder = ""
    if len(parts) > 3 and parts[2] == "tree":
        subfolder = "/".join(parts[4:]) if len(parts) > 4 else ""
    elif len(parts) > 2:
        subfolder = "/".join(parts[2:])
    return repo_id, subfolder


def download_safetensors_file(source_type, repo_url, filename, hf_token=None, progress=gr.Progress()):
    temp_dir = tempfile.mkdtemp()
    try:
        if source_type == "huggingface":
            repo_id, subfolder = parse_hf_url(repo_url)
            safetensors_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                subfolder=subfolder or None,
                cache_dir=temp_dir,
                token=hf_token,
                resume_download=True
            )
        elif source_type == "modelscope":
            if not MODELSCOPE_AVAILABLE:
                raise ImportError("ModelScope not installed")
            repo_id = repo_url.strip()
            safetensors_path = ms_file_download(model_id=repo_id, file_path=filename)
        else:
            raise ValueError("Unknown source")
        return safetensors_path, temp_dir
    except Exception as e:
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise e


def upload_to_target(target_type, new_repo_id, output_dir, fp8_format,
                     hf_token=None, modelscope_token=None, private_repo=False):
    if target_type == "huggingface":
        api = HfApi(token=hf_token)
        api.create_repo(repo_id=new_repo_id, private=private_repo, repo_type="model", exist_ok=True)
        api.upload_folder(repo_id=new_repo_id, folder_path=output_dir, repo_type="model", token=hf_token)
        return f"https://huggingface.co/{new_repo_id}"
    elif target_type == "modelscope":
        api = ModelScopeApi()
        if modelscope_token:
            api.login(modelscope_token)
        api.push_model(model_id=new_repo_id, model_dir=output_dir)
        return f"https://modelscope.cn/models/{new_repo_id}"
    else:
        raise ValueError("Unknown target")
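
# Example (hypothetical file names, following the naming used above): converting
# "model.safetensors" with fp8_format="e5m2", lora_rank=128 and a detected
# "transformer" architecture leaves these artifacts in output_dir:
#   model-fp8-e5m2.safetensors                (FP8 weights)
#   model-lora-r128-transformer.safetensors   (LoRA error-recovery factors)
# upload_to_target() then pushes that folder as-is to the new repository.
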
def process_and_upload_fp8(
    source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture,
    target_type, new_repo_id, hf_token, modelscope_token, private_repo,
    progress=gr.Progress()
):
    if not re.match(r"^[a-zA-Z0-9._-]+/[a-zA-Z0-9._-]+$", new_repo_id):
        return None, "❌ Invalid repo ID format. Use 'username/model-name'.", ""
    if source_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for source.", ""
    if target_type == "huggingface" and not hf_token:
        return None, "❌ Hugging Face token required for target.", ""
    if lora_rank < 8:
        return None, "❌ LoRA rank must be at least 8.", ""

    temp_dir = None
    output_dir = tempfile.mkdtemp()
    try:
        progress(0.05, desc="Downloading model...")
        safetensors_path, temp_dir = download_safetensors_file(
            source_type, repo_url, safetensors_filename, hf_token, progress
        )

        progress(0.25, desc="Converting to FP8 with precision recovery...")
        success, msg, stats = convert_safetensors_to_fp8_with_lora(
            safetensors_path, output_dir, fp8_format, lora_rank, architecture, progress
        )
        if not success:
            return None, f"❌ Conversion failed: {msg}", ""

        progress(0.9, desc="Uploading...")
        repo_url_final = upload_to_target(
            target_type, new_repo_id, output_dir, fp8_format,
            hf_token, modelscope_token, private_repo
        )

        base_name = os.path.splitext(safetensors_filename)[0]
        fp8_filename = f"{base_name}-fp8-{fp8_format}.safetensors"

        # Determine which precision recovery file was generated
        precision_recovery_file = ""
        precision_recovery_type = ""
        if stats.get("method") == "correction" and stats.get("correction_layers", 0) > 0:
            precision_recovery_file = f"{base_name}-correction-{architecture}.safetensors"
            precision_recovery_type = "Correction Factors"
        elif stats.get("method") == "lora" and stats.get("processed_layers", 0) > 0:
            precision_recovery_file = f"{base_name}-lora-r{lora_rank}-{architecture}.safetensors"
            precision_recovery_type = "LoRA"

        readme = f"""---
library_name: diffusers
tags:
- fp8
- safetensors
- precision-recovery
- diffusion
- converted-by-gradio
---

# FP8 Model with Precision Recovery

- **Source**: `{repo_url}`
- **File**: `{safetensors_filename}`
- **FP8 Format**: `{fp8_format.upper()}`
- **Architecture**: {architecture}
- **Precision Recovery Type**: {precision_recovery_type}
- **Precision Recovery File**: `{precision_recovery_file}` (if available)
- **FP8 File**: `{fp8_filename}`

## Usage (Inference)

```python
from safetensors.torch import load_file
import torch

# Load FP8 model
fp8_state = load_file("{fp8_filename}")

# Load precision recovery file if available
recovery_state = {{}}
if "{precision_recovery_file}":
    recovery_state = load_file("{precision_recovery_file}")

# Reconstruct high-precision weights
reconstructed = {{}}
for key in fp8_state:
    # Dequantize FP8 to target precision
    fp_weight = fp8_state[key].to(torch.float32)

    if recovery_state:
        # For LoRA approach
        if f"lora_A.{{key}}" in recovery_state and f"lora_B.{{key}}" in recovery_state:
            A = recovery_state[f"lora_A.{{key}}"].to(torch.float32)
            B = recovery_state[f"lora_B.{{key}}"].to(torch.float32)
            error_correction = B @ A
            reconstructed[key] = fp_weight + error_correction
        # For correction factor approach
        elif f"correction.{{key}}" in recovery_state:
            correction = recovery_state[f"correction.{{key}}"].to(torch.float32)
            reconstructed[key] = fp_weight + correction
        else:
            reconstructed[key] = fp_weight
    else:
        reconstructed[key] = fp_weight

print("Model reconstructed with FP8 error recovery")
```

> **Note**: This precision recovery targets FP8 quantization errors.
> Average quantization error: {stats.get('avg_error', 0):.6f}
"""
        with open(os.path.join(output_dir, "README.md"), "w") as f:
            f.write(readme)

        if target_type == "huggingface":
            HfApi(token=hf_token).upload_file(
                path_or_fileobj=os.path.join(output_dir, "README.md"),
                path_in_repo="README.md",
                repo_id=new_repo_id,
                repo_type="model",
                token=hf_token
            )

        progress(1.0, desc="✅ Done!")

        result_html = f"""
        ✅ Success! Model uploaded to: {new_repo_id}
        Includes: FP8 model + precision recovery ({precision_recovery_type}).
        Average quantization error: {stats.get('avg_error', 0):.6f}
        """
        if stats['processed_layers'] > 0 or stats['correction_layers'] > 0:
            result_html += f"Precision recovery applied to {stats['processed_layers'] + stats['correction_layers']} layers."

        return gr.HTML(result_html), "✅ FP8 + precision recovery upload successful!", msg
    except Exception as e:
        error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}"
        return None, error_msg, ""
    finally:
        if temp_dir:
            shutil.rmtree(temp_dir, ignore_errors=True)
        shutil.rmtree(output_dir, ignore_errors=True)

with gr.Blocks(title="FP8 + Precision Recovery Extractor") as demo:
    gr.Markdown("# 🔄 FP8 Converter with Architecture-Specific Precision Recovery")
    gr.Markdown("Convert models to **FP8** with **error-based precision recovery**.")

    with gr.Row():
        with gr.Column():
            source_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Source")
            repo_url = gr.Textbox(label="Repo URL or ID", placeholder="https://huggingface.co/... or modelscope-id")
            safetensors_filename = gr.Textbox(label="Filename", placeholder="model.safetensors")

            with gr.Accordion("Advanced Settings", open=True):
                fp8_format = gr.Radio(["e4m3fn", "e5m2"], value="e5m2", label="FP8 Format")
                lora_rank = gr.Slider(minimum=8, maximum=256, step=8, value=128,
                                      label="LoRA Rank (for text/transformers)")
                architecture = gr.Dropdown(
                    choices=[
                        ("Auto-detect architecture", "auto"),
                        ("Text Encoder (LoRA)", "text_encoder"),
                        ("Transformer blocks (LoRA)", "transformer"),
                        ("VAE (Correction Factors)", "vae"),
                        ("UNet Convolutions (LoRA)", "unet_conv"),
                        ("All layers (LoRA where applicable)", "all")
                    ],
                    value="auto",
                    label="Target Architecture"
                )

            with gr.Accordion("Authentication", open=False):
                hf_token = gr.Textbox(label="Hugging Face Token", type="password")
                modelscope_token = gr.Textbox(label="ModelScope Token (optional)", type="password",
                                              visible=MODELSCOPE_AVAILABLE)

        with gr.Column():
            target_type = gr.Radio(["huggingface", "modelscope"], value="huggingface", label="Target")
            new_repo_id = gr.Textbox(label="New Repo ID", placeholder="user/model-fp8-precision")
            private_repo = gr.Checkbox(label="Private Repository (HF only)", value=False)
            status_output = gr.Markdown()
            detailed_log = gr.Textbox(label="Processing Log", interactive=False, lines=10)
            convert_btn = gr.Button("🚀 Convert & Upload", variant="primary")
            repo_link_output = gr.HTML()

    convert_btn.click(
        fn=process_and_upload_fp8,
        inputs=[
            source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture,
            target_type, new_repo_id, hf_token, modelscope_token, private_repo
        ],
        outputs=[repo_link_output, status_output, detailed_log],
        show_progress=True
    )

    gr.Examples(
        examples=[
            ["huggingface", "https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder",
             "model.safetensors", "e5m2", 96, "text_encoder"],
            ["huggingface", "https://huggingface.co/stabilityai/sdxl-vae",
             "diffusion_pytorch_model.safetensors", "e4m3fn", 64, "vae"],
            ["huggingface", "https://huggingface.co/Yabo/FramePainter/tree/main",
             "unet_diffusion_pytorch_model.safetensors", "e5m2", 128, "transformer"]
        ],
        inputs=[source_type, repo_url, safetensors_filename, fp8_format, lora_rank, architecture],
        label="Example Conversions"
    )
    gr.Markdown("""
## 🎯 What This Tool Does

Unlike traditional LoRA fine-tuning, this tool:

1. **Quantizes** the model to FP8 (loses precision)
2. **Measures** the quantization error for each weight
3. **Extracts recovery weights** that specifically recover this error
4. **Only applies** recovery where error is significant (>0.001%)

## 💡 Recommended Settings

- **Text Encoders**: rank 64-96 (text is sensitive)
- **Transformers**: rank 96-128
- **VAE**: Uses correction factors (no rank needed)
- **UNet Convolutions**: rank 32-64

## ⚠️ Important Notes

- This recovers **FP8 quantization errors**, not fine-tuning changes
- If FP8 error is tiny (<0.0001%), recovery may not be generated
- Higher rank ≠ better for error recovery (use recommended ranges)
""")

demo.launch()