# Copyright 2025, DEVAIEXP, Black Forest Labs, The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
from enum import Enum
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, FluxPipelineOutput, retrieve_timesteps, calculate_shift
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor


logger = logging.get_logger(__name__)


def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
    """Pick a tile size that follows the image aspect ratio.

    Args:
        image_size: ``(width, height)`` of the full image in pixels.
        base_tile_size: Lower bound applied to both returned dimensions.
        max_tile_size: Upper bound applied before the lower bound is enforced.

    Returns:
        ``(tile_width, tile_height)`` in pixels.
    """
    width, height = image_size
    aspect_ratio = width / height
    if aspect_ratio > 1:
        tile_width = min(width, max_tile_size)
        tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
    else:
        tile_height = min(height, max_tile_size)
        tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
    # The lower bound is applied last, so a tile may end up larger than the image itself;
    # callers clamp tile extents to the image via `_tile2pixel_indices`.
    return max(tile_width, base_tile_size), max(tile_height, base_tile_size)


def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
    """Return the sorted start offsets of tiles covering one image dimension.

    Tiles are laid out with a stride of ``tile_dim - overlap``; a final tile flush with
    the image border is added when the regular stride would leave the end uncovered.

    Args:
        image_dim: Size of the image along this axis, in pixels.
        tile_dim: Size of one tile along this axis, in pixels.
        overlap: Number of pixels shared by neighbouring tiles.

    Returns:
        Sorted, de-duplicated list of tile start offsets (``[0]`` if one tile suffices).

    Raises:
        ValueError: If more than one tile is needed and ``overlap`` is negative or not
            smaller than ``tile_dim``.
    """
    if image_dim <= tile_dim:
        return [0]

    stride = tile_dim - overlap
    # A non-positive stride never advances (the loop below would not terminate), and a
    # negative overlap leaves uncovered gaps between tiles.
    if overlap < 0 or stride <= 0:
        raise ValueError(
            f"`overlap` must satisfy 0 <= overlap < tile_dim, got overlap={overlap} and tile_dim={tile_dim}."
        )

    positions = []
    current_pos = 0
    while True:
        positions.append(current_pos)
        if current_pos + tile_dim >= image_dim:
            break
        current_pos += stride
        if current_pos > image_dim - tile_dim:
            break
    if positions[-1] + tile_dim < image_dim:
        # Add a last tile aligned with the far border so the whole axis is covered.
        positions.append(image_dim - tile_dim)
    return sorted(set(positions))


def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
    """Return ``(row_start, row_end, col_start, col_end)`` of a tile in pixels, clamped to the image."""
    px_row_init = tile_row_pos
    px_col_init = tile_col_pos
    px_row_end = min(px_row_init + tile_height, image_height)
    px_col_end = min(px_col_init + tile_width, image_width)
    return px_row_init, px_row_end, px_col_init, px_col_end


def _tile2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end, vae_scale_factor):
    """Convert pixel bounds to latent bounds by floor-dividing each by ``vae_scale_factor``."""
    return (
        px_row_init // vae_scale_factor,
        px_row_end // vae_scale_factor,
        px_col_init // vae_scale_factor,
        px_col_end // vae_scale_factor,
    )


def _build_tile_layout(
    y_steps, x_steps, tile_width, tile_height, image_width, image_height, vae_scale_factor
) -> List[Tuple[int, int, int, int, Tuple[int, int, int, int]]]:
    """Precompute, in row-major order, the geometry of every tile.

    Returns:
        A list of ``(row, col, pixel_height, pixel_width, (r0, r1, c0, c1))`` where the
        last element holds the tile bounds in latent coordinates.
    """
    layout = []
    for row, y_start in enumerate(y_steps):
        for col, x_start in enumerate(x_steps):
            px_r0, px_r1, px_c0, px_c1 = _tile2pixel_indices(
                y_start, x_start, tile_width, tile_height, image_width, image_height
            )
            latent_box = _tile2latent_indices(px_r0, px_r1, px_c0, px_c1, vae_scale_factor)
            layout.append((row, col, px_r1 - px_r0, px_c1 - px_c0, latent_box))
    return layout


def release_memory(device):
    """Run the garbage collector and, for CUDA devices, empty the CUDA cache and synchronize.

    Args:
        device: Device whose cache should be released. Non-CUDA devices given as a string
            or ``torch.device`` only trigger garbage collection.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    # `torch.cuda.device` rejects non-CUDA devices, so skip them instead of raising when
    # the pipeline runs on CPU on a CUDA-capable machine.
    if isinstance(device, (str, torch.device)) and torch.device(device).type != "cuda":
        return
    with torch.cuda.device(device):
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


class FluxMoDTilingPipeline(FluxPipeline):
    """Flux pipeline that denoises a large latent canvas tile by tile.

    Each step predicts noise per tile (optionally with a per-tile prompt and guidance
    scale) and blends the overlapping predictions with cosine or Gaussian weights.
    """

    class TileWeightingMethod(Enum):
        """Names of the supported tile blending windows."""

        COSINE = "Cosine"
        GAUSSIAN = "Gaussian"

    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
        """Build a 2D Gaussian blending window for one tile.

        Args:
            tile_width: Tile width in pixels.
            tile_height: Tile height in pixels.
            nbatches: Repeat count along the leading (batch) dimension.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.
            sigma: Standard deviation on the normalized ``[-1, 1]`` grid.

        Returns:
            Tensor of shape ``(nbatches, in_channels // 4, latent_height, latent_width)``.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.linspace(-1, 1, latent_width)
        y = np.linspace(-1, 1, latent_height)
        xx, yy = np.meshgrid(x, y)
        gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        # Materialize in float32 first, then cast to the requested dtype.
        weights_torch_f32 = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32)
        weights_torch_target_dtype = weights_torch_f32.to(dtype)
        return torch.tile(weights_torch_target_dtype, (nbatches, self.transformer.config.in_channels // 4, 1, 1))

    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
        """Build a separable cosine blending window for one tile.

        Args:
            tile_width: Tile width in pixels.
            tile_height: Tile height in pixels.
            nbatches: Repeat count along the leading (batch) dimension.
            device: Device of the returned tensor.
            dtype: Dtype of the returned tensor.

        Returns:
            Tensor of shape ``(nbatches, in_channels // 4, latent_height, latent_width)``.
        """
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.arange(latent_width)
        y = np.arange(latent_height)
        mid_x = (latent_width - 1) / 2
        mid_y = (latent_height - 1) / 2
        x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
        y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
        window = torch.tensor(np.outer(y_probs, x_probs), device=device, dtype=dtype)
        return torch.tile(window, (nbatches, self.transformer.config.in_channels // 4, 1, 1))

    def prepare_tiles_weights(
        self,
        y_steps,
        x_steps,
        tile_height,
        tile_width,
        final_height,
        final_width,
        tile_weighting_method,
        tile_gaussian_sigma,
        batch_size,
        device,
        dtype,
    ):
        """Create the blending window for every tile of the grid.

        Args:
            y_steps: Pixel row offsets of the tile rows.
            x_steps: Pixel column offsets of the tile columns.
            tile_height: Nominal tile height in pixels.
            tile_width: Nominal tile width in pixels.
            final_height: Full image height in pixels (tiles are clamped to it).
            final_width: Full image width in pixels (tiles are clamped to it).
            tile_weighting_method: ``"Cosine"`` selects cosine windows; any other value
                selects Gaussian windows.
            tile_gaussian_sigma: Sigma forwarded to the Gaussian window.
            batch_size: Repeat count along the batch dimension of each window.
            device: Device of the returned tensors.
            dtype: Dtype of the returned tensors.

        Returns:
            Object-dtype NumPy array of shape ``(len(y_steps), len(x_steps))`` holding
            one weight tensor per tile.
        """
        tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
        use_cosine = tile_weighting_method == self.TileWeightingMethod.COSINE.value
        for row, y_start in enumerate(y_steps):
            for col, x_start in enumerate(x_steps):
                _, px_row_end, _, px_col_end = _tile2pixel_indices(
                    y_start, x_start, tile_width, tile_height, final_width, final_height
                )
                current_tile_h = px_row_end - y_start
                current_tile_w = px_col_end - x_start
                if use_cosine:
                    tile_weights[row, col] = self._generate_cosine_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype
                    )
                else:
                    tile_weights[row, col] = self._generate_gaussian_weights(
                        current_tile_w, current_tile_h, batch_size, device, dtype, sigma=tile_gaussian_sigma
                    )
        return tile_weights

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[List[str]]],
        height: int = 1024,
        width: int = 1024,
        negative_prompt: Optional[Union[str, List[List[str]]]] = "",
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        max_tile_size: int = 1024,
        tile_overlap: int = 256,
        tile_weighting_method: str = "Cosine",
        tile_gaussian_sigma: float = 0.4,
        guidance_scale_tiles: Optional[List[List[float]]] = None,
        max_sequence_length: int = 512,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        """Generate an image by denoising overlapping tiles and blending their predictions.

        Args:
            prompt: A single prompt used for every tile, or a rectangular grid
                (list of rows) of prompts that also fixes the tile grid.
            height: Requested image height in pixels.
            width: Requested image width in pixels.
            negative_prompt: Accepted for interface compatibility; not used by this
                implementation.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Guidance value passed to the transformer when its config
                enables guidance embeddings.
            generator: Generator(s) used to sample the initial latents.
            max_tile_size: Upper bound on the tile size when no prompt grid is given.
            tile_overlap: Overlap between neighbouring tiles, in pixels.
            tile_weighting_method: ``"Cosine"`` or ``"Gaussian"`` (any value other than
                ``"Cosine"`` selects Gaussian).
            tile_gaussian_sigma: Sigma of the Gaussian blending window.
            guidance_scale_tiles: Optional per-tile guidance values; only used together
                with a prompt grid.
            max_sequence_length: Forwarded to ``encode_prompt``.
            output_type: ``"latent"`` returns the raw latents; otherwise the decoded
                image is post-processed to this type.
            return_dict: If ``False``, return a plain tuple ``(images,)`` instead of a
                ``FluxPipelineOutput``.

        Returns:
            ``FluxPipelineOutput`` with the generated images, or ``(images,)`` when
            ``return_dict`` is ``False``.

        Raises:
            ValueError: If the prompt grid is empty or ragged, if ``guidance_scale_tiles``
                does not cover the prompt grid, or if ``tile_overlap`` is incompatible
                with the tile size.
        """
        device = self._execution_device
        batch_size = 1
        is_prompt_grid = isinstance(prompt, list) and all(isinstance(row, list) for row in prompt)
        pixel_multiple = self.vae_scale_factor * 2  # 16 when vae_scale_factor is 8

        if is_prompt_grid:
            if not prompt or not prompt[0]:
                raise ValueError("`prompt` grid must contain at least one row and one column.")
            grid_rows, grid_cols = len(prompt), len(prompt[0])
            if any(len(row) != grid_cols for row in prompt):
                raise ValueError("All rows of the `prompt` grid must have the same number of prompts.")

            tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
            tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple

            # With several tiles along an axis the overlap must be smaller than the tile,
            # otherwise tile origins would not advance.
            if grid_cols > 1 and not 0 <= tile_overlap < tile_width:
                raise ValueError(
                    f"`tile_overlap` ({tile_overlap}) must be in [0, {tile_width}) for a grid with {grid_cols} columns."
                )
            if grid_rows > 1 and not 0 <= tile_overlap < tile_height:
                raise ValueError(
                    f"`tile_overlap` ({tile_overlap}) must be in [0, {tile_height}) for a grid with {grid_rows} rows."
                )

            final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
            final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
            x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
            y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
            logger.info(f"Prompt grid provided. Using fixed {grid_rows}x{grid_cols} grid. Actual resolution: {final_width}x{final_height}.")
        else:
            # Tiling mode: derive the grid from the requested resolution.
            final_width, final_height = width, height
            tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
            tile_width -= tile_width % pixel_multiple
            tile_height -= tile_height % pixel_multiple
            y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
            x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
            grid_rows, grid_cols = len(y_steps), len(x_steps)
            logger.info(f"Processing image in a {grid_rows}x{grid_cols} grid of tiles.")

        use_tile_guidance = bool(is_prompt_grid and guidance_scale_tiles)
        if use_tile_guidance and (
            len(guidance_scale_tiles) < grid_rows
            or any(len(row) < grid_cols for row in guidance_scale_tiles[:grid_rows])
        ):
            raise ValueError(
                f"`guidance_scale_tiles` must provide a value for each of the {grid_rows}x{grid_cols} tiles."
            )

        def _encode(tile_prompt):
            # Encode one prompt and keep only the tensors the transformer consumes.
            prompt_embeds, pooled, text_ids = self.encode_prompt(
                tile_prompt, device=device, max_sequence_length=max_sequence_length
            )
            return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled, "txt_ids": text_ids}

        if is_prompt_grid:
            text_embeddings = [[_encode(tile_prompt) for tile_prompt in row] for row in prompt]
        else:
            # Every tile shares the same prompt, so encode it once and reuse the result.
            shared_embeddings = _encode(prompt)
            text_embeddings = [[shared_embeddings] * grid_cols for _ in range(grid_rows)]

        prompt_dtype = text_embeddings[0][0]["prompt_embeds"].dtype
        num_channels_latents = self.transformer.config.in_channels // 4
        latents_shape = (
            batch_size,
            num_channels_latents,
            final_height // self.vae_scale_factor,
            final_width // self.vae_scale_factor,
        )
        latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=prompt_dtype)

        # The scheduler shift is derived from the nominal (unclamped) tile size.
        image_seq_len = (tile_height // self.vae_scale_factor // 2) * (tile_width // self.vae_scale_factor // 2)
        mu = calculate_shift(image_seq_len)
        timesteps, _ = retrieve_timesteps(self.scheduler, num_inference_steps, device, mu=mu)

        tile_weights = self.prepare_tiles_weights(
            y_steps,
            x_steps,
            tile_height,
            tile_width,
            final_height,
            final_width,
            tile_weighting_method,
            tile_gaussian_sigma,
            batch_size,
            device,
            latents.dtype,
        )

        # Text encoders are no longer needed for this call; move them off the device.
        # NOTE(review): they are not moved back afterwards — presumably a later call
        # without offload hooks would hit a device mismatch in `encode_prompt`; confirm.
        self.text_encoder.to("cpu")
        self.text_encoder_2.to("cpu")
        release_memory(device)

        tile_layout = _build_tile_layout(
            y_steps, x_steps, tile_width, tile_height, final_width, final_height, self.vae_scale_factor
        )

        # The per-position sum of blending weights does not depend on the timestep.
        contributors = torch.zeros_like(latents)
        for row, col, _, _, (r0, r1, c0, c1) in tile_layout:
            contributors[:, :, r0:r1, c0:c1] += tile_weights[row, col]

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                noise_pred = torch.zeros_like(latents)
                for row, col, tile_px_height, tile_px_width, (r0, r1, c0, c1) in tile_layout:
                    tile_latents = latents[:, :, r0:r1, c0:c1]
                    b, chan, h, w = tile_latents.shape
                    packed_latents = self._pack_latents(tile_latents, b, chan, h, w)
                    latent_image_ids = self._prepare_latent_image_ids(
                        b, h // 2, w // 2, device, packed_latents.dtype
                    )
                    embeds = text_embeddings[row][col]
                    timestep = t.expand(b).to(packed_latents.dtype)

                    current_gs_value = guidance_scale_tiles[row][col] if use_tile_guidance else guidance_scale
                    current_guidance = (
                        torch.tensor([current_gs_value], device=device)
                        if self.transformer.config.guidance_embeds
                        else None
                    )

                    noise_pred_packed = self.transformer(
                        hidden_states=packed_latents,
                        timestep=timestep / 1000,
                        guidance=current_guidance,
                        pooled_projections=embeds["pooled_prompt_embeds"],
                        encoder_hidden_states=embeds["prompt_embeds"],
                        txt_ids=embeds["txt_ids"],
                        img_ids=latent_image_ids,
                    )[0]
                    noise_pred_tile = self._unpack_latents(
                        noise_pred_packed, tile_px_height, tile_px_width, self.vae_scale_factor
                    )
                    # `latents` is only updated after all tiles are processed, so the
                    # weighted prediction can be accumulated right away.
                    noise_pred[:, :, r0:r1, c0:c1] += noise_pred_tile * tile_weights[row, col]

                # Normalize the weighted sum into a weighted average.
                noise_pred /= contributors

                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents)[0]
                if latents.dtype != latents_dtype:
                    latents = latents.to(latents_dtype)
                progress_bar.update()

        # Post-processing
        if output_type == "latent":
            image = latents
        else:
            self.vae.to(device)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(self.vae.dtype))[0]
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image,)
        return FluxPipelineOutput(images=image)