# Copyright 2025, DEVAIEXP, Black Forest Labs, The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
from enum import Enum
from typing import List, Optional, Union

import numpy as np
import torch

from diffusers.pipelines.flux.pipeline_flux import (
    FluxPipeline,
    FluxPipelineOutput,
    calculate_shift,
    retrieve_timesteps,
)
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor

logger = logging.get_logger(__name__)


def _adaptive_tile_size(image_size, base_tile_size=512, max_tile_size=1280):
    """Pick a tile size that preserves the image's aspect ratio, clamped to [base_tile_size, max_tile_size]."""
    width, height = image_size
    aspect_ratio = width / height
    if aspect_ratio > 1:  # landscape: clamp width first, derive height
        tile_width = min(width, max_tile_size)
        tile_height = min(int(tile_width / aspect_ratio), max_tile_size)
    else:  # portrait or square: clamp height first, derive width
        tile_height = min(height, max_tile_size)
        tile_width = min(int(tile_height * aspect_ratio), max_tile_size)
    return max(tile_width, base_tile_size), max(tile_height, base_tile_size)
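
# Illustrative example: for a 2048x1024 image with max_tile_size=1280, the
# aspect ratio is 2.0, so the width is clamped to 1280, the height derived as
# 640, and both are floored at base_tile_size, giving 1280x640 tiles.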


def _calculate_tile_positions(image_dim: int, tile_dim: int, overlap: int) -> List[int]:
    """Compute tile start offsets along one axis so consecutive tiles overlap by `overlap` pixels."""
    if image_dim <= tile_dim:
        return [0]
    positions = []
    current_pos = 0
    stride = tile_dim - overlap
    while True:
        positions.append(current_pos)
        if current_pos + tile_dim >= image_dim:
            break
        current_pos += stride
        if current_pos > image_dim - tile_dim:
            break
    # Snap a final tile to the image edge if the last stride left a gap.
    if positions[-1] + tile_dim < image_dim:
        positions.append(image_dim - tile_dim)
    return sorted(set(positions))
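
# Illustrative example: _calculate_tile_positions(2048, 1280, 256) steps
# 0 -> 1024 (stride = 1280 - 256), overshoots the last valid start (768),
# and snaps the final tile to the edge, returning [0, 768]; those two tiles
# then overlap by 512 pixels rather than the requested 256.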


def _tile2pixel_indices(tile_row_pos, tile_col_pos, tile_width, tile_height, image_width, image_height):
    """Map a tile's start position to pixel-space slice bounds, clipped to the image."""
    px_row_init, px_col_init = tile_row_pos, tile_col_pos
    px_row_end = min(px_row_init + tile_height, image_height)
    px_col_end = min(px_col_init + tile_width, image_width)
    return px_row_init, px_row_end, px_col_init, px_col_end


def _tile2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end, vae_scale_factor):
    """Convert pixel-space slice bounds to latent-space bounds (divide by the VAE scale factor)."""
    return px_row_init // vae_scale_factor, px_row_end // vae_scale_factor, px_col_init // vae_scale_factor, px_col_end // vae_scale_factor
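
# Illustrative example: a 1280x1280 tile starting at row 768 of a 2048x2048
# image maps to pixel bounds (768, 2048, 0, 1280); with vae_scale_factor=8,
# _tile2latent_indices turns those into latent bounds (96, 256, 0, 160).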


def release_memory(device):
    gc.collect()
    if torch.cuda.is_available():
        with torch.cuda.device(device):
            torch.cuda.empty_cache()
            torch.cuda.synchronize()


class FluxMoDTilingPipeline(FluxPipeline):
    class TileWeightingMethod(Enum):
        COSINE = "Cosine"
        GAUSSIAN = "Gaussian"

    def _generate_gaussian_weights(self, tile_width, tile_height, nbatches, device, dtype, sigma=0.4):
        """Per-pixel blending weights from a 2D Gaussian centered on the tile."""
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x = np.linspace(-1, 1, latent_width)
        y = np.linspace(-1, 1, latent_height)
        xx, yy = np.meshgrid(x, y)
        gaussian_weight_np = np.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        # Build in float32 first, then cast, to avoid precision loss in low-precision dtypes.
        weights = torch.tensor(gaussian_weight_np, device=device, dtype=torch.float32).to(dtype)
        return torch.tile(weights, (nbatches, self.transformer.config.in_channels // 4, 1, 1))

    def _generate_cosine_weights(self, tile_width, tile_height, nbatches, device, dtype):
        """Per-pixel blending weights from a separable cosine window over the tile."""
        latent_width = tile_width // self.vae_scale_factor
        latent_height = tile_height // self.vae_scale_factor
        x, y = np.arange(latent_width), np.arange(latent_height)
        mid_x, mid_y = (latent_width - 1) / 2, (latent_height - 1) / 2
        x_probs = np.cos(np.pi * (x - mid_x) / latent_width)
        y_probs = np.cos(np.pi * (y - mid_y) / latent_height)
        weights = torch.tensor(np.outer(y_probs, x_probs), device=device, dtype=dtype)
        return torch.tile(weights, (nbatches, self.transformer.config.in_channels // 4, 1, 1))
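
    # Illustrative shape check: for a 1280x640 tile with vae_scale_factor=8 and
    # FLUX.1's transformer (in_channels=64), either generator returns weights of
    # shape (nbatches, 16, 80, 160), matching the unpacked latent tile it blends.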

    def prepare_tiles_weights(self, y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, dtype):
        """Precompute a blending-weight tensor for every tile in the grid (edge tiles may be smaller)."""
        tile_weights = np.empty((len(y_steps), len(x_steps)), dtype=object)
        for row, y_start in enumerate(y_steps):
            for col, x_start in enumerate(x_steps):
                _, px_row_end, _, px_col_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
                current_tile_h = px_row_end - y_start
                current_tile_w = px_col_end - x_start
                if tile_weighting_method == self.TileWeightingMethod.COSINE.value:
                    tile_weights[row, col] = self._generate_cosine_weights(current_tile_w, current_tile_h, batch_size, device, dtype)
                else:
                    tile_weights[row, col] = self._generate_gaussian_weights(current_tile_w, current_tile_h, batch_size, device, dtype, sigma=tile_gaussian_sigma)
        return tile_weights

    def __call__(
        self,
        prompt: Union[str, List[List[str]]],
        height: int = 1024,
        width: int = 1024,
        negative_prompt: Optional[Union[str, List[List[str]]]] = "",  # accepted for API symmetry; unused below
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        max_tile_size: int = 1024,
        tile_overlap: int = 256,
        tile_weighting_method: str = "Cosine",
        tile_gaussian_sigma: float = 0.4,
        guidance_scale_tiles: Optional[List[List[float]]] = None,
        max_sequence_length: int = 512,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ):
        device = self._execution_device
        batch_size = 1
        # A list of lists of prompts means one prompt per tile on a fixed grid.
        is_prompt_grid = isinstance(prompt, list) and all(isinstance(row, list) for row in prompt)
        PIXEL_MULTIPLE = self.vae_scale_factor * 2  # 16 for FLUX: tiles must align to packed-latent boundaries
        if is_prompt_grid:
            grid_rows, grid_cols = len(prompt), len(prompt[0])
            # Derive a tile size that covers the requested canvas with the given overlap...
            tile_width = (width + (grid_cols - 1) * tile_overlap) // grid_cols
            tile_height = (height + (grid_rows - 1) * tile_overlap) // grid_rows
            # ...then round down to the pixel multiple and recompute the actual canvas size.
            tile_width -= tile_width % PIXEL_MULTIPLE
            tile_height -= tile_height % PIXEL_MULTIPLE
            final_width = tile_width * grid_cols - (grid_cols - 1) * tile_overlap
            final_height = tile_height * grid_rows - (grid_rows - 1) * tile_overlap
            x_steps = [i * (tile_width - tile_overlap) for i in range(grid_cols)]
            y_steps = [i * (tile_height - tile_overlap) for i in range(grid_rows)]
            logger.info(f"Prompt grid provided. Using fixed {grid_rows}x{grid_cols} grid. Actual resolution: {final_width}x{final_height}.")
        else:  # single prompt: tile adaptively over the requested resolution
            final_width, final_height = width, height
            tile_width, tile_height = _adaptive_tile_size((final_width, final_height), max_tile_size=max_tile_size)
            tile_width -= tile_width % PIXEL_MULTIPLE
            tile_height -= tile_height % PIXEL_MULTIPLE
            y_steps = _calculate_tile_positions(final_height, tile_height, tile_overlap)
            x_steps = _calculate_tile_positions(final_width, tile_width, tile_overlap)
            grid_rows, grid_cols = len(y_steps), len(x_steps)
            logger.info(f"Processing image in a {grid_rows}x{grid_cols} grid of tiles.")

        # Encode one prompt per tile (or reuse the single prompt for every tile).
        text_embeddings = []
        for r in range(grid_rows):
            row_embeddings = []
            for c in range(grid_cols):
                p = prompt[r][c] if is_prompt_grid else prompt
                prompt_embeds, pooled, text_ids = self.encode_prompt(prompt=p, prompt_2=None, device=device, max_sequence_length=max_sequence_length)
                row_embeddings.append({"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled, "txt_ids": text_ids})
            text_embeddings.append(row_embeddings)
        prompt_dtype = text_embeddings[0][0]["prompt_embeds"].dtype

        # Sample the full-resolution latent canvas; tiles are denoised as views into it.
        num_channels_latents = self.transformer.config.in_channels // 4
        latents_shape = (batch_size, num_channels_latents, final_height // self.vae_scale_factor, final_width // self.vae_scale_factor)
        latents = randn_tensor(latents_shape, generator=generator, device=device, dtype=prompt_dtype)

        # Shift the timestep schedule for the per-tile (not full-image) sequence length.
        image_seq_len = (tile_height // self.vae_scale_factor // 2) * (tile_width // self.vae_scale_factor // 2)
        mu = calculate_shift(image_seq_len)
        timesteps, _ = retrieve_timesteps(self.scheduler, num_inference_steps, device, mu=mu)
        tile_weights = self.prepare_tiles_weights(y_steps, x_steps, tile_height, tile_width, final_height, final_width, tile_weighting_method, tile_gaussian_sigma, batch_size, device, latents.dtype)

        # Text encoders are no longer needed; move them off the GPU before denoising.
        self.text_encoder.to("cpu")
        self.text_encoder_2.to("cpu")
        release_memory(device)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                # Predict noise independently for each tile.
                noise_preds_tiles = np.empty((grid_rows, grid_cols), dtype=object)
                for r, y_start in enumerate(y_steps):
                    for c, x_start in enumerate(x_steps):
                        px_r_init, px_r_end, px_c_init, px_c_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
                        current_tile_pixel_height = px_r_end - px_r_init
                        current_tile_pixel_width = px_c_end - px_c_init
                        r_init, r_end, c_init, c_end = _tile2latent_indices(px_r_init, px_r_end, px_c_init, px_c_end, self.vae_scale_factor)
                        tile_latents = latents[:, :, r_init:r_end, c_init:c_end]
                        b, chan, h, w = tile_latents.shape
                        packed_latents = self._pack_latents(tile_latents, b, chan, h, w)
                        latent_image_ids = self._prepare_latent_image_ids(b, h // 2, w // 2, device, packed_latents.dtype)
                        embeds = text_embeddings[r][c]
                        timestep = t.expand(b).to(packed_latents.dtype)
                        # A per-tile guidance override only applies in prompt-grid mode.
                        current_gs_value = guidance_scale_tiles[r][c] if (is_prompt_grid and guidance_scale_tiles) else guidance_scale
                        current_guidance = torch.tensor([current_gs_value], device=device) if self.transformer.config.guidance_embeds else None
                        noise_pred_packed = self.transformer(
                            hidden_states=packed_latents,
                            timestep=timestep / 1000,
                            guidance=current_guidance,
                            pooled_projections=embeds["pooled_prompt_embeds"],
                            encoder_hidden_states=embeds["prompt_embeds"],
                            txt_ids=embeds["txt_ids"],
                            img_ids=latent_image_ids,
                        )[0]
                        noise_pred_tile = self._unpack_latents(noise_pred_packed, current_tile_pixel_height, current_tile_pixel_width, self.vae_scale_factor)
                        noise_preds_tiles[r, c] = noise_pred_tile

                # Stitch tile predictions into one noise map, weighted-averaging the overlaps.
                noise_pred = torch.zeros_like(latents)
                contributors = torch.zeros_like(latents)
                for r, y_start in enumerate(y_steps):
                    for c, x_start in enumerate(x_steps):
                        px_r_init, px_r_end, px_c_init, px_c_end = _tile2pixel_indices(y_start, x_start, tile_width, tile_height, final_width, final_height)
                        r_init, r_end, c_init, c_end = _tile2latent_indices(px_r_init, px_r_end, px_c_init, px_c_end, self.vae_scale_factor)
                        noise_pred[:, :, r_init:r_end, c_init:c_end] += noise_preds_tiles[r, c] * tile_weights[r, c]
                        contributors[:, :, r_init:r_end, c_init:c_end] += tile_weights[r, c]
                noise_pred /= contributors

                # Single scheduler step on the stitched full-resolution prediction.
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents)[0]
                if latents.dtype != latents_dtype:
                    latents = latents.to(latents_dtype)
                progress_bar.update()

        # Post-processing: decode the full latent canvas in one pass.
        if output_type == "latent":
            image = latents
        else:
            self.vae.to(device)
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
            image = self.vae.decode(latents.to(self.vae.dtype))[0]
            image = self.image_processor.postprocess(image, output_type=output_type)
        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return FluxPipelineOutput(images=image)
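

# --- Minimal usage sketch (illustrative, not part of the pipeline) ---
# Assumes the FLUX.1-schnell checkpoint and a CUDA device; the model id,
# resolution, and seed below are placeholders to adapt to your setup.
if __name__ == "__main__":
    pipe = FluxMoDTilingPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipe(
        prompt="a panoramic photograph of a mountain lake at sunrise",
        height=1024,
        width=2048,
        num_inference_steps=4,
        tile_overlap=256,
        tile_weighting_method="Gaussian",
        generator=torch.Generator("cuda").manual_seed(0),
    ).images[0]
    image.save("flux_tiled.png")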