# infer.py
# A command-line inference script to test the FluxMoDTilingPipeline.
# This script runs the first example from the Gradio app to verify functionality
# and observe the progress bar in the terminal.
import os
import torch
import time

# Make sure flux_pipeline_mod.py is in the same directory
from flux_pipeline_mod import FluxMoDTilingPipeline
# Conditional MMGP setup based on an environment variable
USE_MMGP_ENV = os.getenv('USE_MMGP', 'true').lower()
USE_MMGP = USE_MMGP_ENV not in ('false', '0', 'no', 'none')

# Optional: for memory offloading. Default `offload` to None so later checks
# are safe even when MMGP is disabled or the library is not installed.
offload = None
if USE_MMGP:
    try:
        from mmgp import offload, profile_type
    except ImportError:
        print("Warning: 'mmgp' library not found. Offload will not be applied.")
else:
    print("INFO: MMGP is disabled.")


def main():
    """Main function to run the inference process."""
    # 1. Load Model
    print("--- 1. Loading Model ---")
    # !!! IMPORTANT: Make sure this path is correct for your system !!!
    # MODEL_PATH = "F:\\Models\\FLUX.1-schnell"
    MODEL_PATH = "black-forest-labs/FLUX.1-schnell"
    start_load_time = time.time()
    if USE_MMGP:
        # Keep the pipeline on the CPU; the offload profile below streams
        # components to the GPU on demand.
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16
        )
    else:
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16
        ).to("cuda")
    # Apply memory optimization (only relevant when MMGP was requested)
    if USE_MMGP:
        if offload is not None:
            print("Applying LowRAM_LowVRAM offload profile...")
            offload.profile(pipe, profile_type.LowRAM_LowVRAM)
        else:
            print("Attempting to use the standard Diffusers offload instead...")
            try:
                pipe.enable_model_cpu_offload()
            except Exception as e:
                print(f"Could not apply standard offload: {e}")
    # The pipeline moves components to the GPU when needed.
    # We can explicitly move the VAE to the GPU for decoding at the end.
    end_load_time = time.time()
    print(f"Pipeline loaded successfully in {end_load_time - start_load_time:.2f} seconds.")
    # 2. Set Up Inference Parameters
    print("\n--- 2. Setting Up Inference Parameters (from Gradio Example 1) ---")
    # Prompts
    prompt_grid = [[
        "Iron Man, repulsor rays blasting enemies in destroyed cityscape, cinematic lighting, photorealistic. Focus: Iron Man.",
        "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
        "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor."
    ]]
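    # A single row of three prompts: the tiling pipeline presumably renders
    # the wide canvas as a 1x3 grid, with each tile guided by its own prompt
    # from left to right.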
    # Tiling and Dimensions
    target_height = 1024
    target_width = 3072
    tile_overlap = 160
    tile_weighting_method = "Cosine"
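    # Adjacent tiles share a 160-px strip that is blended when the tiles are
    # merged; "Cosine" presumably applies a cosine-shaped weight ramp across
    # that overlap so seams fade smoothly (method name taken from the app).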
    # Generation
    num_inference_steps = 4
    guidance_scale = 0.0
    seed = 619517442
    # Create a generator for reproducibility
    generator = torch.Generator("cuda").manual_seed(seed)
    print(f"Resolution: {target_width}x{target_height}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")
    # 3. Start Inference
    print("\n--- 3. Starting Inference ---")
    start_inference_time = time.time()
    # The main call to the pipeline
    image = pipe(
        prompt=prompt_grid,
        height=target_height,
        width=target_width,
        tile_overlap=tile_overlap,
        guidance_scale=guidance_scale,
        generator=generator,
        tile_weighting_method=tile_weighting_method,
        num_inference_steps=num_inference_steps
    ).images[0]
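    # The call returns a pipeline output whose `images` field holds PIL
    # images; index 0 is the single stitched panorama.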
    end_inference_time = time.time()
    print(f"\nInference finished in {end_inference_time - start_inference_time:.2f} seconds.")
    # 4. Save Output
    print("\n--- 4. Saving Output ---")
    os.makedirs("outputs", exist_ok=True)  # create the output directory if missing
    output_filename = "outputs/inference_output.png"
    image.save(output_filename)
    print(f"Image successfully saved as '{output_filename}'")


if __name__ == "__main__":
    main()
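
# A typical run (assumption: a CUDA GPU and all dependencies are available):
#   python infer.py
# writes outputs/inference_output.png, a 3072x1024 three-tile panorama.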