# infer.py
# A command-line inference script to test the FluxMoDTilingPipeline.
# This script runs the first example from the Gradio app to verify functionality
# and observe the progress bar in the terminal.
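#
# Typical invocations (assuming the script is run from the repository root, with
# flux_pipeline_mod.py in the same directory):
#   python infer.py                  # use mmgp offloading when the library is installed
#   USE_MMGP=false python infer.py   # skip mmgp and load the pipeline onto the GPU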

import os
import torch
import time

# Make sure flux_pipeline_mod.py is in the same directory
from flux_pipeline_mod import FluxMoDTilingPipeline
# Conditional MMGP Setup based on Environment Variable
USE_MMGP_ENV = os.getenv('USE_MMGP', 'true').lower()
USE_MMGP = USE_MMGP_ENV not in ('false', '0', 'no', 'none')

# Optional: for memory offloading
if USE_MMGP:
    try:
        from mmgp import offload, profile_type
    except ImportError:
        print("Warning: 'mmgp' library not found. Offload will not be applied.")
        offload = None
else:
    print("INFO: MMGP is disabled.")
    offload = None  # keep the name defined so the later `if offload:` check does not raise NameError
    
def main():
    """Main function to run the inference process."""
    
    # 1. Load Model
    print("--- 1. Loading Model ---")
    # !!! IMPORTANT: Set MODEL_PATH to a Hugging Face repo id or a local path on your system !!!
    # MODEL_PATH = "F:\\Models\\FLUX.1-schnell"
    MODEL_PATH = "black-forest-labs/FLUX.1-schnell"
    
    start_load_time = time.time()
    if USE_MMGP:
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH, 
            torch_dtype=torch.bfloat16
        )
    else:
        pipe = FluxMoDTilingPipeline.from_pretrained(
            MODEL_PATH, 
            torch_dtype=torch.bfloat16
        ).to("cuda")
        
    # Apply memory optimization
    if offload:
        print("Applying LowRAM_LowVRAM offload profile...")
        offload.profile(pipe, profile_type.LowRAM_LowVRAM)
    else:
        print("Attempting to use the standard Diffusers offload...")
        try:
            pipe.enable_model_cpu_offload()
        except Exception as e:
            print(f"Could not apply standard offload: {e}")
            
    # With offloading enabled, components are moved to the GPU only when they are needed;
    # the VAE can also be moved to the GPU explicitly for the final decode if desired.
    
    end_load_time = time.time()
    print(f"Pipeline loaded successfully in {end_load_time - start_load_time:.2f} seconds.")

    # 2. Set Up Inference Parameters
    print("\n--- 2. Setting Up Inference Parameters (from Gradio Example 1) ---")
    
    # Prompts: one row of three tile prompts, one per horizontal tile of the panorama
    prompt_grid = [[
        "Iron Man, repulsor rays blasting enemies in destroyed cityscape, cinematic lighting, photorealistic. Focus: Iron Man.",
        "Captain America charging forward, vibranium shield deflecting energy blasts in destroyed cityscape, cinematic composition. Focus: Captain America.",
        "Thor wielding Stormbreaker in destroyed cityscape, lightning crackling, powerful strike downwards, cinematic photography. Focus: Thor."
    ]]
        
    # Tiling and Dimensions
    target_height = 1024
    target_width = 3072
    tile_overlap = 160
    tile_weighting_method = "Cosine"
    
    # Generation
    num_inference_steps = 4
    guidance_scale = 0.0
    seed = 619517442
    
    # Create a generator for reproducibility
    generator = torch.Generator("cuda").manual_seed(seed)
    
    print(f"Resolution: {target_width}x{target_height}, Steps: {num_inference_steps}, Guidance: {guidance_scale}")

    # 3. Start Inference
    print("\n--- 3. Starting Inference ---")        
    start_inference_time = time.time()
    
    # The main call to the pipeline
    image = pipe(
        prompt=prompt_grid,
        height=target_height,
        width=target_width,        
        tile_overlap=tile_overlap,
        guidance_scale=guidance_scale,
        generator=generator,
        tile_weighting_method=tile_weighting_method,
        num_inference_steps=num_inference_steps
    ).images[0]
    
    end_inference_time = time.time()
    print(f"\nInference finished in {end_inference_time - start_inference_time:.2f} seconds.")

    # 4. Save Output
    print("\n--- 4. Saving Output ---")
    os.makedirs("outputs", exist_ok=True)  # make sure the output directory exists
    output_filename = "outputs/inference_output.png"
    image.save(output_filename)
    print(f"Image successfully saved as '{output_filename}'")


if __name__ == "__main__":
    main()