# infer.py
# A command-line inference script to test the FluxMoDTilingPipeline.
# This script runs the first example from the Gradio app to verify functionality
# and observe the progress bar in the terminal.
import os
import torch
import time
# Make sure flux_pipeline_mod.py is in the same directory
from flux_pipeline_mod import FluxMoDTilingPipeline
# Conditional MMGP Setup based on Environment Variable
USE_MMGP_ENV = os.getenv('USE_MMGP', 'true').lower()
USE_MMGP = USE_MMGP_ENV not in ('false', '0', 'no', 'none')
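# MMGP is enabled by default; set USE_MMGP=false (or 0/no/none) to disable it.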
# Optional: for memory offloading
# Default to None so the `if offload:` check in main() works even when mmgp is disabled or missing
offload = None
if USE_MMGP:
    try:
        from mmgp import offload, profile_type
    except ImportError:
        print("Warning: 'mmgp' library not found. Offload will not be applied.")
        offload = None
else:
    print("INFO: MMGP is disabled.")


def main():
    """Main function to run the inference process."""

    # 1. Load Model
    print("--- 1. Loading Model ---")
    # !!! IMPORTANT: Make sure this path is correct for your system !!!
    # MODEL_PATH = "F:\\Models\\FLUX.1-schnell"
    MODEL_PATH = "black-forest-labs/FLUX.1-schnell"

    start_load_time = time.time()
    if USE_MMGP:
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16
        )
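        # Note: no .to("cuda") in this branch; with mmgp the pipeline stays on the CPU
        # and the offload profile applied below moves components to the GPU on demand.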
    else:
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch.bfloat16
        ).to("cuda")

    # Apply memory optimization
    if offload:
        print("Applying LowRAM_LowVRAM offload profile...")
        offload.profile(pipe, profile_type.LowRAM_LowVRAM)
    else:
        print("Attempting to use the standard Diffusers offload...")
        try:
            pipe.enable_model_cpu_offload()
        except Exception as e:
            print(f"Could not apply standard offload: {e}")

    # The pipeline moves components to the GPU when needed.
    # We can explicitly move the VAE to the GPU for decoding at the end.
    end_load_time = time.time()
    print(f"Pipeline loaded successfully in {end_load_time - start_load_time:.2f} seconds.")

    # 2. Set Up Inference Parameters
    print("\n--- 2. Setting Up Inference Parameters (from Gradio Example 1) ---")

    # Prompts
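    # A grid of per-tile prompts: the outer list holds rows, each inner list the prompts
    # for the tiles in that row. Here: a single row of three prompts for a wide canvas.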
    prompt_grid = [[
        "Iron Man, repulsor rays blasting enemies in destroyed cityscape, cinematic lighting, photorealistic. Focus: Iron Man.",
        "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
        "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor."
    ]]

    # Tiling and Dimensions
    target_height = 1024
    target_width = 3072
    tile_overlap = 160
    tile_weighting_method = "Cosine"
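    # The overlap between adjacent tiles and the weighting scheme used to blend
    # the overlapping regions where tiles meet.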

    # Generation
    num_inference_steps = 4
    guidance_scale = 0.0
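    # FLUX.1-schnell is a timestep-distilled model, so a handful of steps and a
    # guidance scale of 0.0 are the recommended settings.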
    seed = 619517442

    # Create a generator for reproducibility
    generator = torch.Generator("cuda").manual_seed(seed)
    print(f"Resolution: {target_width}x{target_height}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")

    # 3. Start Inference
    print("\n--- 3. Starting Inference ---")
    start_inference_time = time.time()

    # The main call to the pipeline
    image = pipe(
        prompt=prompt_grid,
        height=target_height,
        width=target_width,
        tile_overlap=tile_overlap,
        guidance_scale=guidance_scale,
        generator=generator,
        tile_weighting_method=tile_weighting_method,
        num_inference_steps=num_inference_steps
    ).images[0]
    end_inference_time = time.time()
    print(f"\nInference finished in {end_inference_time - start_inference_time:.2f} seconds.")

    # 4. Save Output
    print("\n--- 4. Saving Output ---")
    os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
    output_filename = "outputs/inference_output.png"
    image.save(output_filename)
    print(f"Image successfully saved as '{output_filename}'")


if __name__ == "__main__":
    main()