From aefb8691793f5ca964691a98500960f7e242ff0d Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 30 Sep 2025 16:07:16 +0000 Subject: [PATCH 1/6] add sana controlnet implementation --- .../models/controlnets/controlnet_sana.py | 290 ++++++++++++++++++ 1 file changed, 290 insertions(+) create mode 100644 src/diffusers/models/controlnets/controlnet_sana.py diff --git a/src/diffusers/models/controlnets/controlnet_sana.py b/src/diffusers/models/controlnets/controlnet_sana.py new file mode 100644 index 000000000000..ed521adbedda --- /dev/null +++ b/src/diffusers/models/controlnets/controlnet_sana.py @@ -0,0 +1,290 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from ...configuration_utils import ConfigMixin, register_to_config +from ...loaders import PeftAdapterMixin +from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers +from ..attention_processor import AttentionProcessor +from ..embeddings import PatchEmbed, PixArtAlphaTextProjection +from ..modeling_outputs import Transformer2DModelOutput +from ..modeling_utils import ModelMixin +from ..normalization import AdaLayerNormSingle, RMSNorm +from ..transformers.sana_transformer import SanaTransformerBlock +from .controlnet import zero_module + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class SanaControlNetOutput(BaseOutput): + controlnet_block_samples: Tuple[torch.Tensor] + + +class SanaControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): + _supports_gradient_checkpointing = True + _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"] + _skip_layerwise_casting_patterns = ["patch_embed", "norm"] + + @register_to_config + def __init__( + self, + in_channels: int = 32, + out_channels: Optional[int] = 32, + num_attention_heads: int = 70, + attention_head_dim: int = 32, + num_layers: int = 7, + num_cross_attention_heads: Optional[int] = 20, + cross_attention_head_dim: Optional[int] = 112, + cross_attention_dim: Optional[int] = 2240, + caption_channels: int = 2304, + mlp_ratio: float = 2.5, + dropout: float = 0.0, + attention_bias: bool = False, + sample_size: int = 32, + patch_size: int = 1, + norm_elementwise_affine: bool = False, + norm_eps: float = 1e-6, + interpolation_scale: Optional[int] = None, + ) -> None: + super().__init__() + + out_channels = out_channels or in_channels + inner_dim = num_attention_heads * attention_head_dim + + # 1. Patch Embedding + self.patch_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + pos_embed_type="sincos" if interpolation_scale is not None else None, + ) + + # 2. 
Additional condition embeddings + self.time_embed = AdaLayerNormSingle(inner_dim) + + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) + + # 3. Transformer blocks + self.transformer_blocks = nn.ModuleList( + [ + SanaTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + num_cross_attention_heads=num_cross_attention_heads, + cross_attention_head_dim=cross_attention_head_dim, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + mlp_ratio=mlp_ratio, + ) + for _ in range(num_layers) + ] + ) + + # controlnet_blocks + self.controlnet_blocks = nn.ModuleList([]) + + self.input_block = zero_module(nn.Linear(inner_dim, inner_dim)) + for _ in range(len(self.transformer_blocks)): + controlnet_block = nn.Linear(inner_dim, inner_dim) + controlnet_block = zero_module(controlnet_block) + self.controlnet_blocks.append(controlnet_block) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
+ ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: torch.LongTensor, + controlnet_cond: torch.Tensor, + conditioning_scale: float = 1.0, + encoder_attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + return_dict: bool = True, + ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + batch_size, num_channels, height, width = hidden_states.shape + p = self.config.patch_size + post_patch_height, post_patch_width = height // p, width // p + + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states + self.input_block(self.patch_embed(controlnet_cond.to(hidden_states.dtype))) + + timestep, embedded_timestep = self.time_embed( + timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + encoder_hidden_states = self.caption_norm(encoder_hidden_states) + + # 2. 
Transformer blocks + block_res_samples = () + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.transformer_blocks: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + post_patch_height, + post_patch_width, + ) + block_res_samples = block_res_samples + (hidden_states,) + else: + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + post_patch_height, + post_patch_width, + ) + block_res_samples = block_res_samples + (hidden_states,) + + # 3. ControlNet blocks + controlnet_block_res_samples = () + for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks): + block_res_sample = controlnet_block(block_res_sample) + controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples] + + if not return_dict: + return (controlnet_block_res_samples,) + + return SanaControlNetOutput(controlnet_block_samples=controlnet_block_res_samples) From 09ee6d8ae3ed1be0a18a0bb0b23e7c3beba60be7 Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 30 Sep 2025 21:26:04 +0000 Subject: [PATCH 2/6] update __init__ files to add sana --- src/diffusers/__init__.py | 136 +- src/diffusers/models/__init__.py | 94 +- src/diffusers/models/controlnets/__init__.py | 20 +- .../models/transformers/file-changes.diff | 929 +++++++++++++ src/diffusers/pipelines/__init__.py | 142 +- src/diffusers/pipelines/sana/__init__.py | 3 +- .../sana/pipeline_sana_controlnet.py | 1236 +++++++++++++++++ 7 files changed, 2457 insertions(+), 103 deletions(-) create mode 100644 src/diffusers/models/transformers/file-changes.diff create mode 100644 src/diffusers/pipelines/sana/pipeline_sana_controlnet.py diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 9304c34b4e01..4d373b2a5ded 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -23,7 +23,6 @@ is_transformers_available, ) - # Lazy Import based on # https://github.com/huggingface/transformers/blob/main/src/transformers/__init__.py @@ -60,7 +59,11 @@ } try: - if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available(): + if ( + not is_torch_available() + and not is_accelerate_available() + and not is_bitsandbytes_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_bitsandbytes_objects @@ -72,7 +75,11 @@ _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig") try: - if not is_torch_available() and not is_accelerate_available() and not is_gguf_available(): + if ( + not is_torch_available() + and not is_accelerate_available() + and not is_gguf_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_gguf_objects @@ -84,7 +91,11 @@ _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig") try: - if not is_torch_available() and not is_accelerate_available() and not is_torchao_available(): + if ( + not is_torch_available() + and not is_accelerate_available() + and not is_torchao_available() + ): raise OptionalDependencyNotAvailable() except 
OptionalDependencyNotAvailable: from .utils import dummy_torchao_objects @@ -96,7 +107,11 @@ _import_structure["quantizers.quantization_config"].append("TorchAoConfig") try: - if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): + if ( + not is_torch_available() + and not is_accelerate_available() + and not is_optimum_quanto_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_optimum_quanto_objects @@ -126,7 +141,9 @@ except OptionalDependencyNotAvailable: from .utils import dummy_pt_objects # noqa F403 - _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + _import_structure["utils.dummy_pt_objects"] = [ + name for name in dir(dummy_pt_objects) if not name.startswith("_") + ] else: _import_structure["hooks"].extend( @@ -187,6 +204,7 @@ "OmniGenTransformer2DModel", "PixArtTransformer2DModel", "PriorTransformer", + "SanaControlNetModel", "SanaTransformer2DModel", "SD3ControlNetModel", "SD3MultiControlNetModel", @@ -303,11 +321,15 @@ from .utils import dummy_torch_and_torchsde_objects # noqa F403 _import_structure["utils.dummy_torch_and_torchsde_objects"] = [ - name for name in dir(dummy_torch_and_torchsde_objects) if not name.startswith("_") + name + for name in dir(dummy_torch_and_torchsde_objects) + if not name.startswith("_") ] else: - _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend( + ["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"] + ) try: if not (is_torch_available() and is_transformers_available()): @@ -316,7 +338,9 @@ from .utils import dummy_torch_and_transformers_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_objects"] = [ - name for name in dir(dummy_torch_and_transformers_objects) if not name.startswith("_") + name + for name in dir(dummy_torch_and_transformers_objects) + if not name.startswith("_") ] else: @@ -424,6 +448,7 @@ "PixArtSigmaPAGPipeline", "PixArtSigmaPipeline", "ReduxImageEncoder", + "SanaControlNetPipeline", "SanaPAGPipeline", "SanaPipeline", "SanaSprintPipeline", @@ -517,39 +542,63 @@ ) try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_k_diffusion_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") + name + for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) + if not name.startswith("_") ] else: - _import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"]) + _import_structure["pipelines"].extend( + ["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"] + ) try: - if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_sentencepiece_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa 
F403 + from .utils import ( # noqa F403 + dummy_torch_and_transformers_and_sentencepiece_objects, + ) - _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_") + _import_structure[ + "utils.dummy_torch_and_transformers_and_sentencepiece_objects" + ] = [ + name + for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) + if not name.startswith("_") ] else: - _import_structure["pipelines"].extend(["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"]) + _import_structure["pipelines"].extend( + ["KolorsImg2ImgPipeline", "KolorsPAGPipeline", "KolorsPipeline"] + ) try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + if not ( + is_torch_available() and is_transformers_available() and is_onnx_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ - name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") + name + for name in dir(dummy_torch_and_transformers_and_onnx_objects) + if not name.startswith("_") ] else: @@ -571,20 +620,26 @@ from .utils import dummy_torch_and_librosa_objects # noqa F403 _import_structure["utils.dummy_torch_and_librosa_objects"] = [ - name for name in dir(dummy_torch_and_librosa_objects) if not name.startswith("_") + name + for name in dir(dummy_torch_and_librosa_objects) + if not name.startswith("_") ] else: _import_structure["pipelines"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + if not ( + is_transformers_available() and is_torch_available() and is_note_seq_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ - name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") + name + for name in dir(dummy_transformers_and_torch_and_note_seq_objects) + if not name.startswith("_") ] @@ -605,7 +660,9 @@ else: _import_structure["models.controlnets.controlnet_flax"] = ["FlaxControlNetModel"] _import_structure["models.modeling_flax_utils"] = ["FlaxModelMixin"] - _import_structure["models.unets.unet_2d_condition_flax"] = ["FlaxUNet2DConditionModel"] + _import_structure["models.unets.unet_2d_condition_flax"] = [ + "FlaxUNet2DConditionModel" + ] _import_structure["models.vae_flax"] = ["FlaxAutoencoderKL"] _import_structure["pipelines"].extend(["FlaxDiffusionPipeline"]) _import_structure["schedulers"].extend( @@ -630,7 +687,9 @@ from .utils import dummy_flax_and_transformers_objects # noqa F403 _import_structure["utils.dummy_flax_and_transformers_objects"] = [ - name for name in dir(dummy_flax_and_transformers_objects) if not name.startswith("_") + name + for name in dir(dummy_flax_and_transformers_objects) + if not name.startswith("_") ] @@ -763,6 +822,7 @@ OmniGenTransformer2DModel, PixArtTransformer2DModel, PriorTransformer, + SanaControlNetModel, SanaTransformer2DModel, SD3ControlNetModel, SD3MultiControlNetModel, @@ -979,6 +1039,7 @@ PixArtSigmaPAGPipeline, PixArtSigmaPipeline, ReduxImageEncoder, + 
SanaControlNetPipeline, SanaPAGPipeline, SanaPipeline, SanaSprintPipeline, @@ -1070,22 +1131,35 @@ ) try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_k_diffusion_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403 else: - from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline + from .pipelines import ( + StableDiffusionKDiffusionPipeline, + StableDiffusionXLKDiffusionPipeline, + ) try: - if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_sentencepiece_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_sentencepiece_objects import * # noqa F403 else: from .pipelines import KolorsImg2ImgPipeline, KolorsPAGPipeline, KolorsPipeline try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + if not ( + is_torch_available() and is_transformers_available() and is_onnx_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_transformers_and_onnx_objects import * # noqa F403 @@ -1108,7 +1182,11 @@ from .pipelines import AudioDiffusionPipeline, Mel try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + if not ( + is_transformers_available() + and is_torch_available() + and is_note_seq_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from .utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f7d70f1d9826..f300dda7fd32 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -21,67 +21,110 @@ is_torch_available, ) - _import_structure = {} if is_torch_available(): + _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_allegro"] = ["AutoencoderKLAllegro"] - _import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"] - _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = ["AutoencoderKLHunyuanVideo"] + _import_structure["autoencoders.autoencoder_kl_cogvideox"] = [ + "AutoencoderKLCogVideoX" + ] + _import_structure["autoencoders.autoencoder_kl_hunyuan_video"] = [ + "AutoencoderKLHunyuanVideo" + ] _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"] _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"] _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"] - _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] + _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = [ + "AutoencoderKLTemporalDecoder" + ] 
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"] _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] - _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] + _import_structure["autoencoders.consistency_decoder_vae"] = [ + "ConsistencyDecoderVAE" + ] _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] - _import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"] + _import_structure["controlnets.controlnet_flux"] = [ + "FluxControlNetModel", + "FluxMultiControlNetModel", + ] _import_structure["controlnets.controlnet_hunyuan"] = [ "HunyuanDiT2DControlNetModel", "HunyuanDiT2DMultiControlNetModel", ] - _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"] + _import_structure["controlnets.controlnet_sd3"] = [ + "SD3ControlNetModel", + "SD3MultiControlNetModel", + ] _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"] _import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"] - _import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"] + _import_structure["controlnets.controlnet_xs"] = [ + "ControlNetXSAdapter", + "UNetControlNetXSModel", + ] _import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"] - _import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"] + _import_structure["controlnets.multicontrolnet_union"] = [ + "MultiControlNetUnionModel" + ] _import_structure["embeddings"] = ["ImageProjection"] _import_structure["modeling_utils"] = ["ModelMixin"] - _import_structure["transformers.auraflow_transformer_2d"] = ["AuraFlowTransformer2DModel"] - _import_structure["transformers.cogvideox_transformer_3d"] = ["CogVideoXTransformer3DModel"] - _import_structure["transformers.consisid_transformer_3d"] = ["ConsisIDTransformer3DModel"] + _import_structure["transformers.auraflow_transformer_2d"] = [ + "AuraFlowTransformer2DModel" + ] + _import_structure["transformers.cogvideox_transformer_3d"] = [ + "CogVideoXTransformer3DModel" + ] + _import_structure["transformers.consisid_transformer_3d"] = [ + "ConsisIDTransformer3DModel" + ] _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"] _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"] _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"] _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"] _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] - _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] + _import_structure["transformers.pixart_transformer_2d"] = [ + "PixArtTransformer2DModel" + ] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] _import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"] _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] - _import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"] - 
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"] - _import_structure["transformers.transformer_cogview4"] = ["CogView4Transformer2DModel"] - _import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"] + _import_structure["transformers.transformer_allegro"] = [ + "AllegroTransformer3DModel" + ] + _import_structure["transformers.transformer_cogview3plus"] = [ + "CogView3PlusTransformer2DModel" + ] + _import_structure["transformers.transformer_cogview4"] = [ + "CogView4Transformer2DModel" + ] + _import_structure["transformers.transformer_easyanimate"] = [ + "EasyAnimateTransformer3DModel" + ] _import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"] - _import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"] + _import_structure["transformers.transformer_hunyuan_video"] = [ + "HunyuanVideoTransformer3DModel" + ] _import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"] - _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"] + _import_structure["transformers.transformer_lumina2"] = [ + "Lumina2Transformer2DModel" + ] _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"] - _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"] + _import_structure["transformers.transformer_omnigen"] = [ + "OmniGenTransformer2DModel" + ] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] - _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"] + _import_structure["transformers.transformer_temporal"] = [ + "TransformerTemporalModel" + ] _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"] _import_structure["unets.unet_1d"] = ["UNet1DModel"] _import_structure["unets.unet_2d"] = ["UNet2DModel"] @@ -90,7 +133,9 @@ _import_structure["unets.unet_i2vgen_xl"] = ["I2VGenXLUNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] - _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["unets.unet_spatio_temporal_condition"] = [ + "UNetSpatioTemporalConditionModel" + ] _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"] _import_structure["unets.uvit_2d"] = ["UVit2DModel"] @@ -131,6 +176,7 @@ HunyuanDiT2DMultiControlNetModel, MultiControlNetModel, MultiControlNetUnionModel, + SanaControlNetModel, SD3ControlNetModel, SD3MultiControlNetModel, SparseControlNetModel, @@ -189,4 +235,6 @@ else: import sys - sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) + sys.modules[__name__] = _LazyModule( + __name__, globals()["__file__"], _import_structure, module_spec=__spec__ + ) diff --git a/src/diffusers/models/controlnets/__init__.py b/src/diffusers/models/controlnets/__init__.py index 1dd92e51a44c..621de4329868 100644 --- a/src/diffusers/models/controlnets/__init__.py +++ b/src/diffusers/models/controlnets/__init__.py @@ -1,22 +1,34 @@ from ...utils import is_flax_available, is_torch_available - if is_torch_available(): from .controlnet import ControlNetModel, ControlNetOutput - from .controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel + from .controlnet_flux import ( + FluxControlNetModel, + FluxControlNetOutput, + 
FluxMultiControlNetModel, + ) from .controlnet_hunyuan import ( HunyuanControlNetOutput, HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel, ) - from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel + from .controlnet_sana import SanaControlNetModel + from .controlnet_sd3 import ( + SD3ControlNetModel, + SD3ControlNetOutput, + SD3MultiControlNetModel, + ) from .controlnet_sparsectrl import ( SparseControlNetConditioningEmbedding, SparseControlNetModel, SparseControlNetOutput, ) from .controlnet_union import ControlNetUnionModel - from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel + from .controlnet_xs import ( + ControlNetXSAdapter, + ControlNetXSOutput, + UNetControlNetXSModel, + ) from .multicontrolnet import MultiControlNetModel from .multicontrolnet_union import MultiControlNetUnionModel diff --git a/src/diffusers/models/transformers/file-changes.diff b/src/diffusers/models/transformers/file-changes.diff new file mode 100644 index 000000000000..584bb5170bb7 --- /dev/null +++ b/src/diffusers/models/transformers/file-changes.diff @@ -0,0 +1,929 @@ +diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py +index 7ab371a1a..b64920c37 100644 +--- a/src/diffusers/models/transformers/transformer_flux.py ++++ b/src/diffusers/models/transformers/transformer_flux.py +@@ -1,4 +1,4 @@ +-# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. ++# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. + # + # Licensed under the Apache License, Version 2.0 (the "License"); + # you may not use this file except in compliance with the License. +@@ -12,340 +12,65 @@ + # See the License for the specific language governing permissions and + # limitations under the License. 
+ +-import inspect +-from typing import Any, Dict, List, Optional, Tuple, Union ++ ++from typing import Any, Dict, Optional, Tuple, Union + + import numpy as np + import torch + import torch.nn as nn +-import torch.nn.functional as F + + from ...configuration_utils import ConfigMixin, register_to_config +-from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin +-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers ++from ...loaders import ( ++ FluxTransformer2DLoadersMixin, ++ FromOriginalModelMixin, ++ PeftAdapterMixin, ++) ++from ...models.attention import FeedForward ++from ...models.attention_processor import ( ++ Attention, ++ AttentionProcessor, ++ FluxAttnProcessor2_0, ++ FluxAttnProcessor2_0_NPU, ++ FusedFluxAttnProcessor2_0, ++) ++from ...models.modeling_utils import ModelMixin ++from ...models.normalization import ( ++ AdaLayerNormContinuous, ++ AdaLayerNormZero, ++ AdaLayerNormZeroSingle, ++) ++from ...utils import ( ++ USE_PEFT_BACKEND, ++ deprecate, ++ logging, ++ scale_lora_layers, ++ unscale_lora_layers, ++) ++from ...utils.import_utils import is_torch_npu_available + from ...utils.torch_utils import maybe_allow_in_graph +-from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward +-from ..attention_dispatch import dispatch_attention_fn + from ..cache_utils import CacheMixin + from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, +- apply_rotary_emb, +- get_1d_rotary_pos_embed, ++ FluxPosEmbed, ++ TimestepEmbedding, ++ Timesteps, + ) + from ..modeling_outputs import Transformer2DModelOutput +-from ..modeling_utils import ModelMixin +-from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle +- + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +-def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): +- query = attn.to_q(hidden_states) +- key = attn.to_k(hidden_states) +- value = attn.to_v(hidden_states) +- +- encoder_query = encoder_key = encoder_value = None +- if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: +- encoder_query = attn.add_q_proj(encoder_hidden_states) +- encoder_key = attn.add_k_proj(encoder_hidden_states) +- encoder_value = attn.add_v_proj(encoder_hidden_states) +- +- return query, key, value, encoder_query, encoder_key, encoder_value +- +- +-def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): +- query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) +- +- encoder_query = encoder_key = encoder_value = (None,) +- if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): +- encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) +- +- return query, key, value, encoder_query, encoder_key, encoder_value +- +- +-def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): +- if attn.fused_projections: +- return _get_fused_projections(attn, hidden_states, encoder_hidden_states) +- return _get_projections(attn, hidden_states, encoder_hidden_states) +- +- +-class FluxAttnProcessor: +- _attention_backend = None +- +- def __init__(self): +- if not hasattr(F, "scaled_dot_product_attention"): +- raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. 
Please upgrade your pytorch version.") +- +- def __call__( +- self, +- attn: "FluxAttention", +- hidden_states: torch.Tensor, +- encoder_hidden_states: torch.Tensor = None, +- attention_mask: Optional[torch.Tensor] = None, +- image_rotary_emb: Optional[torch.Tensor] = None, +- ) -> torch.Tensor: +- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( +- attn, hidden_states, encoder_hidden_states +- ) +- +- query = query.unflatten(-1, (attn.heads, -1)) +- key = key.unflatten(-1, (attn.heads, -1)) +- value = value.unflatten(-1, (attn.heads, -1)) +- +- query = attn.norm_q(query) +- key = attn.norm_k(key) +- +- if attn.added_kv_proj_dim is not None: +- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) +- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) +- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) +- +- encoder_query = attn.norm_added_q(encoder_query) +- encoder_key = attn.norm_added_k(encoder_key) +- +- query = torch.cat([encoder_query, query], dim=1) +- key = torch.cat([encoder_key, key], dim=1) +- value = torch.cat([encoder_value, value], dim=1) +- +- if image_rotary_emb is not None: +- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) +- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) +- +- hidden_states = dispatch_attention_fn( +- query, key, value, attn_mask=attention_mask, backend=self._attention_backend +- ) +- hidden_states = hidden_states.flatten(2, 3) +- hidden_states = hidden_states.to(query.dtype) +- +- if encoder_hidden_states is not None: +- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( +- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 +- ) +- hidden_states = attn.to_out[0](hidden_states) +- hidden_states = attn.to_out[1](hidden_states) +- encoder_hidden_states = attn.to_add_out(encoder_hidden_states) +- +- return hidden_states, encoder_hidden_states +- else: +- return hidden_states +- +- +-class FluxIPAdapterAttnProcessor(torch.nn.Module): +- """Flux Attention processor for IP-Adapter.""" +- +- _attention_backend = None +- +- def __init__( +- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None +- ): +- super().__init__() +- +- if not hasattr(F, "scaled_dot_product_attention"): +- raise ImportError( +- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
+- ) +- +- self.hidden_size = hidden_size +- self.cross_attention_dim = cross_attention_dim +- +- if not isinstance(num_tokens, (tuple, list)): +- num_tokens = [num_tokens] +- +- if not isinstance(scale, list): +- scale = [scale] * len(num_tokens) +- if len(scale) != len(num_tokens): +- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") +- self.scale = scale +- +- self.to_k_ip = nn.ModuleList( +- [ +- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) +- for _ in range(len(num_tokens)) +- ] +- ) +- self.to_v_ip = nn.ModuleList( +- [ +- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) +- for _ in range(len(num_tokens)) +- ] +- ) +- +- def __call__( +- self, +- attn: "FluxAttention", +- hidden_states: torch.Tensor, +- encoder_hidden_states: torch.Tensor = None, +- attention_mask: Optional[torch.Tensor] = None, +- image_rotary_emb: Optional[torch.Tensor] = None, +- ip_hidden_states: Optional[List[torch.Tensor]] = None, +- ip_adapter_masks: Optional[torch.Tensor] = None, +- ) -> torch.Tensor: +- batch_size = hidden_states.shape[0] +- +- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( +- attn, hidden_states, encoder_hidden_states +- ) +- +- query = query.unflatten(-1, (attn.heads, -1)) +- key = key.unflatten(-1, (attn.heads, -1)) +- value = value.unflatten(-1, (attn.heads, -1)) +- +- query = attn.norm_q(query) +- key = attn.norm_k(key) +- ip_query = query +- +- if encoder_hidden_states is not None: +- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) +- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) +- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) +- +- encoder_query = attn.norm_added_q(encoder_query) +- encoder_key = attn.norm_added_k(encoder_key) +- +- query = torch.cat([encoder_query, query], dim=1) +- key = torch.cat([encoder_key, key], dim=1) +- value = torch.cat([encoder_value, value], dim=1) +- +- if image_rotary_emb is not None: +- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) +- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) +- +- hidden_states = dispatch_attention_fn( +- query, +- key, +- value, +- attn_mask=attention_mask, +- dropout_p=0.0, +- is_causal=False, +- backend=self._attention_backend, +- ) +- hidden_states = hidden_states.flatten(2, 3) +- hidden_states = hidden_states.to(query.dtype) +- +- if encoder_hidden_states is not None: +- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( +- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 +- ) +- hidden_states = attn.to_out[0](hidden_states) +- hidden_states = attn.to_out[1](hidden_states) +- encoder_hidden_states = attn.to_add_out(encoder_hidden_states) +- +- # IP-adapter +- ip_attn_output = torch.zeros_like(hidden_states) +- +- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( +- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip +- ): +- ip_key = to_k_ip(current_ip_hidden_states) +- ip_value = to_v_ip(current_ip_hidden_states) +- +- ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) +- ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) +- +- current_ip_hidden_states = dispatch_attention_fn( +- ip_query, +- ip_key, +- ip_value, +- attn_mask=None, +- dropout_p=0.0, +- is_causal=False, +- backend=self._attention_backend, +- ) +- current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, 
-1, attn.heads * attn.head_dim) +- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) +- ip_attn_output += scale * current_ip_hidden_states +- +- return hidden_states, encoder_hidden_states, ip_attn_output +- else: +- return hidden_states +- +- +-class FluxAttention(torch.nn.Module, AttentionModuleMixin): +- _default_processor_cls = FluxAttnProcessor +- _available_processors = [ +- FluxAttnProcessor, +- FluxIPAdapterAttnProcessor, +- ] +- ++@maybe_allow_in_graph ++class FluxSingleTransformerBlock(nn.Module): + def __init__( + self, +- query_dim: int, +- heads: int = 8, +- dim_head: int = 64, +- dropout: float = 0.0, +- bias: bool = False, +- added_kv_proj_dim: Optional[int] = None, +- added_proj_bias: Optional[bool] = True, +- out_bias: bool = True, +- eps: float = 1e-5, +- out_dim: int = None, +- context_pre_only: Optional[bool] = None, +- pre_only: bool = False, +- elementwise_affine: bool = True, +- processor=None, ++ dim: int, ++ num_attention_heads: int, ++ attention_head_dim: int, ++ mlp_ratio: float = 4.0, + ): + super().__init__() +- +- self.head_dim = dim_head +- self.inner_dim = out_dim if out_dim is not None else dim_head * heads +- self.query_dim = query_dim +- self.use_bias = bias +- self.dropout = dropout +- self.out_dim = out_dim if out_dim is not None else query_dim +- self.context_pre_only = context_pre_only +- self.pre_only = pre_only +- self.heads = out_dim // dim_head if out_dim is not None else heads +- self.added_kv_proj_dim = added_kv_proj_dim +- self.added_proj_bias = added_proj_bias +- +- self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) +- self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) +- self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) +- self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) +- self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) +- +- if not self.pre_only: +- self.to_out = torch.nn.ModuleList([]) +- self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) +- self.to_out.append(torch.nn.Dropout(dropout)) +- +- if added_kv_proj_dim is not None: +- self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) +- self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) +- self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) +- self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) +- self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) +- self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) +- +- if processor is None: +- processor = self._default_processor_cls() +- self.set_processor(processor) +- +- def forward( +- self, +- hidden_states: torch.Tensor, +- encoder_hidden_states: Optional[torch.Tensor] = None, +- attention_mask: Optional[torch.Tensor] = None, +- image_rotary_emb: Optional[torch.Tensor] = None, +- **kwargs, +- ) -> torch.Tensor: +- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) +- quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} +- unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] +- if len(unused_kwargs) > 0: +- logger.warning( +- f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
+- ) +- kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} +- return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) +- +- +-@maybe_allow_in_graph +-class FluxSingleTransformerBlock(nn.Module): +- def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): +- super().__init__() + self.mlp_hidden_dim = int(dim * mlp_ratio) + + self.norm = AdaLayerNormZeroSingle(dim) +@@ -353,13 +78,25 @@ class FluxSingleTransformerBlock(nn.Module): + self.act_mlp = nn.GELU(approximate="tanh") + self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) + +- self.attn = FluxAttention( ++ if is_torch_npu_available(): ++ deprecation_message = ( ++ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " ++ "should be set explicitly using the `set_attn_processor` method." ++ ) ++ deprecate("npu_processor", "0.34.0", deprecation_message) ++ processor = FluxAttnProcessor2_0_NPU() ++ else: ++ processor = FluxAttnProcessor2_0() ++ ++ self.attn = Attention( + query_dim=dim, ++ cross_attention_dim=None, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + bias=True, +- processor=FluxAttnProcessor(), ++ processor=processor, ++ qk_norm="rms_norm", + eps=1e-6, + pre_only=True, + ) +@@ -367,14 +104,11 @@ class FluxSingleTransformerBlock(nn.Module): + def forward( + self, + hidden_states: torch.Tensor, +- encoder_hidden_states: torch.Tensor, + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, +- ) -> Tuple[torch.Tensor, torch.Tensor]: +- text_seq_len = encoder_hidden_states.shape[1] +- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) +- ++ attention_mask: Optional[torch.Tensor] = None, ++ ) -> torch.Tensor: + residual = hidden_states + norm_hidden_states, gate = self.norm(hidden_states, emb=temb) + mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) +@@ -382,6 +116,7 @@ class FluxSingleTransformerBlock(nn.Module): + attn_output = self.attn( + hidden_states=norm_hidden_states, + image_rotary_emb=image_rotary_emb, ++ attention_mask=attention_mask, + **joint_attention_kwargs, + ) + +@@ -392,29 +127,35 @@ class FluxSingleTransformerBlock(nn.Module): + if hidden_states.dtype == torch.float16: + hidden_states = hidden_states.clip(-65504, 65504) + +- encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] +- return encoder_hidden_states, hidden_states ++ return hidden_states + + + @maybe_allow_in_graph + class FluxTransformerBlock(nn.Module): + def __init__( +- self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 ++ self, ++ dim: int, ++ num_attention_heads: int, ++ attention_head_dim: int, ++ qk_norm: str = "rms_norm", ++ eps: float = 1e-6, + ): + super().__init__() + + self.norm1 = AdaLayerNormZero(dim) + self.norm1_context = AdaLayerNormZero(dim) + +- self.attn = FluxAttention( ++ self.attn = Attention( + query_dim=dim, ++ cross_attention_dim=None, + added_kv_proj_dim=dim, + dim_head=attention_head_dim, + heads=num_attention_heads, + out_dim=dim, + context_pre_only=False, + bias=True, +- processor=FluxAttnProcessor(), ++ processor=FluxAttnProcessor2_0(), ++ qk_norm=qk_norm, + eps=eps, + ) + +@@ -422,7 +163,9 @@ class FluxTransformerBlock(nn.Module): + self.ff = FeedForward(dim=dim, dim_out=dim, 
activation_fn="gelu-approximate") + + self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) +- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") ++ self.ff_context = FeedForward( ++ dim=dim, dim_out=dim, activation_fn="gelu-approximate" ++ ) + + def forward( + self, +@@ -431,19 +174,22 @@ class FluxTransformerBlock(nn.Module): + temb: torch.Tensor, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, ++ attention_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: +- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) ++ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( ++ hidden_states, emb=temb ++ ) + +- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( +- encoder_hidden_states, emb=temb ++ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( ++ self.norm1_context(encoder_hidden_states, emb=temb) + ) + joint_attention_kwargs = joint_attention_kwargs or {} +- + # Attention. + attention_outputs = self.attn( + hidden_states=norm_hidden_states, + encoder_hidden_states=norm_encoder_hidden_states, + image_rotary_emb=image_rotary_emb, ++ attention_mask=attention_mask, + **joint_attention_kwargs, + ) + +@@ -457,7 +203,9 @@ class FluxTransformerBlock(nn.Module): + hidden_states = hidden_states + attn_output + + norm_hidden_states = self.norm2(hidden_states) +- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ++ norm_hidden_states = ( ++ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] ++ ) + + ff_output = self.ff(norm_hidden_states) + ff_output = gate_mlp.unsqueeze(1) * ff_output +@@ -467,51 +215,26 @@ class FluxTransformerBlock(nn.Module): + hidden_states = hidden_states + ip_attn_output + + # Process attention outputs for the `encoder_hidden_states`. 
++ + context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output + encoder_hidden_states = encoder_hidden_states + context_attn_output + + norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) +- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] ++ norm_encoder_hidden_states = ( ++ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) ++ + c_shift_mlp[:, None] ++ ) + + context_ff_output = self.ff_context(norm_encoder_hidden_states) +- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output ++ encoder_hidden_states = ( ++ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output ++ ) + if encoder_hidden_states.dtype == torch.float16: + encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) + + return encoder_hidden_states, hidden_states + + +-class FluxPosEmbed(nn.Module): +- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 +- def __init__(self, theta: int, axes_dim: List[int]): +- super().__init__() +- self.theta = theta +- self.axes_dim = axes_dim +- +- def forward(self, ids: torch.Tensor) -> torch.Tensor: +- n_axes = ids.shape[-1] +- cos_out = [] +- sin_out = [] +- pos = ids.float() +- is_mps = ids.device.type == "mps" +- is_npu = ids.device.type == "npu" +- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 +- for i in range(n_axes): +- cos, sin = get_1d_rotary_pos_embed( +- self.axes_dim[i], +- pos[:, i], +- theta=self.theta, +- repeat_interleave_real=True, +- use_real=True, +- freqs_dtype=freqs_dtype, +- ) +- cos_out.append(cos) +- sin_out.append(sin) +- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) +- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) +- return freqs_cos, freqs_sin +- +- + class FluxTransformer2DModel( + ModelMixin, + ConfigMixin, +@@ -519,7 +242,6 @@ class FluxTransformer2DModel( + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, +- AttentionMixin, + ): + """ + The Transformer model introduced in Flux. 
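[Reviewer sketch, not part of the patch] The hunk above rewrites the feed-forward modulation as `x * (1 + scale[:, None]) + shift[:, None]` followed by a gated residual add. A minimal runnable sketch of that shift/scale/gate pattern, with toy shapes and a GELU stand-in for the block's FeedForward assumed purely for illustration:

import torch
import torch.nn.functional as F

batch, seq_len, dim = 2, 16, 64
hidden_states = torch.randn(batch, seq_len, dim)
# shift/scale/gate come per-sample from the AdaLayerNormZero conditioning in the real block
shift_mlp, scale_mlp, gate_mlp = torch.randn(3, batch, dim).unbind(0)

norm_hidden_states = F.layer_norm(hidden_states, (dim,))
# affine modulation: broadcast the per-sample (batch, dim) parameters over the sequence axis
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

ff_output = F.gelu(norm_hidden_states)  # stand-in for the block's FeedForward
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output  # gated residual update
print(hidden_states.shape)  # torch.Size([2, 16, 64])
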
+@@ -555,7 +277,6 @@ class FluxTransformer2DModel( + _supports_gradient_checkpointing = True + _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + _skip_layerwise_casting_patterns = ["pos_embed", "norm"] +- _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] + + @register_to_config + def __init__( +@@ -570,7 +291,8 @@ class FluxTransformer2DModel( + joint_attention_dim: int = 4096, + pooled_projection_dim: int = 768, + guidance_embeds: bool = False, +- axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), ++ additional_timestep_embeds: bool = False, ++ axes_dims_rope: Tuple[int] = (16, 56, 56), + ): + super().__init__() + self.out_channels = out_channels or in_channels +@@ -579,12 +301,23 @@ class FluxTransformer2DModel( + self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) + + text_time_guidance_cls = ( +- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ++ CombinedTimestepGuidanceTextProjEmbeddings ++ if guidance_embeds ++ else CombinedTimestepTextProjEmbeddings + ) ++ + self.time_text_embed = text_time_guidance_cls( + embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim + ) + ++ if additional_timestep_embeds: ++ self.additional_time_proj = Timesteps( ++ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 ++ ) ++ self.additional_timestep_embedder = TimestepEmbedding( ++ in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False ++ ) ++ + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) + self.x_embedder = nn.Linear(in_channels, self.inner_dim) + +@@ -610,11 +343,123 @@ class FluxTransformer2DModel( + ] + ) + +- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) +- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) ++ self.norm_out = AdaLayerNormContinuous( ++ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 ++ ) ++ self.proj_out = nn.Linear( ++ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True ++ ) + + self.gradient_checkpointing = False + ++ @property ++ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors ++ def attn_processors(self) -> Dict[str, AttentionProcessor]: ++ r""" ++ Returns: ++ `dict` of attention processors: A dictionary containing all attention processors used in the model with ++ indexed by its weight name. ++ """ ++ # set recursively ++ processors = {} ++ ++ def fn_recursive_add_processors( ++ name: str, ++ module: torch.nn.Module, ++ processors: Dict[str, AttentionProcessor], ++ ): ++ if hasattr(module, "get_processor"): ++ processors[f"{name}.processor"] = module.get_processor() ++ ++ for sub_name, child in module.named_children(): ++ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) ++ ++ return processors ++ ++ for name, module in self.named_children(): ++ fn_recursive_add_processors(name, module, processors) ++ ++ return processors ++ ++ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor ++ def set_attn_processor( ++ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] ++ ): ++ r""" ++ Sets the attention processor to use to compute attention. 
++ ++ Parameters: ++ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): ++ The instantiated processor class or a dictionary of processor classes that will be set as the processor ++ for **all** `Attention` layers. ++ ++ If `processor` is a dict, the key needs to define the path to the corresponding cross attention ++ processor. This is strongly recommended when setting trainable attention processors. ++ ++ """ ++ count = len(self.attn_processors.keys()) ++ ++ if isinstance(processor, dict) and len(processor) != count: ++ raise ValueError( ++ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" ++ f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ++ ) ++ ++ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): ++ if hasattr(module, "set_processor"): ++ if not isinstance(processor, dict): ++ module.set_processor(processor) ++ else: ++ module.set_processor(processor.pop(f"{name}.processor")) ++ ++ for sub_name, child in module.named_children(): ++ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) ++ ++ for name, module in self.named_children(): ++ fn_recursive_attn_processor(name, module, processor) ++ ++ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 ++ def fuse_qkv_projections(self): ++ """ ++ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) ++ are fused. For cross-attention modules, key and value projection matrices are fused. ++ ++ ++ ++ This API is 🧪 experimental. ++ ++ ++ """ ++ self.original_attn_processors = None ++ ++ for _, attn_processor in self.attn_processors.items(): ++ if "Added" in str(attn_processor.__class__.__name__): ++ raise ValueError( ++ "`fuse_qkv_projections()` is not supported for models having added KV projections." ++ ) ++ ++ self.original_attn_processors = self.attn_processors ++ ++ for module in self.modules(): ++ if isinstance(module, Attention): ++ module.fuse_projections(fuse=True) ++ ++ self.set_attn_processor(FusedFluxAttnProcessor2_0()) ++ ++ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections ++ def unfuse_qkv_projections(self): ++ """Disables the fused QKV projection if enabled. ++ ++ ++ ++ This API is 🧪 experimental. ++ ++ ++ ++ """ ++ if self.original_attn_processors is not None: ++ self.set_attn_processor(self.original_attn_processors) ++ + def forward( + self, + hidden_states: torch.Tensor, +@@ -624,11 +469,13 @@ class FluxTransformer2DModel( + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, ++ additional_timestep_embeddings: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, ++ attention_mask: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. 
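The new `additional_timestep_embeds` path threads one extra scalar condition per sample through a sinusoidal projection and a `TimestepEmbedding`, and sums the result into `temb` in the forward pass below. A rough sketch of that flow with made-up widths (the real modules are constructed in `__init__` above):

```py
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

inner_dim = 512  # illustrative; the transformer uses its own inner_dim

additional_time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
additional_timestep_embedder = TimestepEmbedding(
    in_channels=256, time_embed_dim=inner_dim, sample_proj_bias=False
)

temb = torch.randn(2, inner_dim)                             # stand-in for time_text_embed output
additional_timestep_embeddings = torch.tensor([0.25, 0.75])  # one extra scalar per sample

proj = additional_time_proj(additional_timestep_embeddings)  # (2, 256) sinusoidal features
temb = temb + additional_timestep_embedder(proj.to(dtype=temb.dtype))  # (2, inner_dim)
```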
+@@ -666,7 +513,10 @@ class FluxTransformer2DModel( + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: +- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: ++ if ( ++ joint_attention_kwargs is not None ++ and joint_attention_kwargs.get("scale", None) is not None ++ ): + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) +@@ -676,12 +526,24 @@ class FluxTransformer2DModel( + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 ++ else: ++ guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) ++ ++ if additional_timestep_embeddings is not None: ++ additional_timestep_embed_proj = self.additional_time_proj( ++ additional_timestep_embeddings ++ ) ++ additional_class_emb = self.additional_timestep_embedder( ++ additional_timestep_embed_proj.to(dtype=temb.dtype) ++ ) ++ temb = temb + additional_class_emb ++ + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: +@@ -700,20 +562,27 @@ class FluxTransformer2DModel( + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + +- if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: +- ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") ++ if ( ++ joint_attention_kwargs is not None ++ and "ip_adapter_image_embeds" in joint_attention_kwargs ++ ): ++ ip_adapter_image_embeds = joint_attention_kwargs.pop( ++ "ip_adapter_image_embeds" ++ ) + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: +- encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( +- block, +- hidden_states, +- encoder_hidden_states, +- temb, +- image_rotary_emb, +- joint_attention_kwargs, ++ encoder_hidden_states, hidden_states = ( ++ self._gradient_checkpointing_func( ++ block, ++ hidden_states, ++ encoder_hidden_states, ++ temb, ++ image_rotary_emb, ++ attention_mask=attention_mask, ++ ) + ) + + else: +@@ -723,45 +592,61 @@ class FluxTransformer2DModel( + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ++ attention_mask=attention_mask, + ) + + # controlnet residual + if controlnet_block_samples is not None: +- interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) ++ interval_control = len(self.transformer_blocks) / len( ++ controlnet_block_samples ++ ) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. 
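The XLabs branch below differs from the default path only in how the residual index is chosen for each transformer block. A small worked example of the two indexing schemes, with toy counts:

```py
import numpy as np

num_transformer_blocks = 19
controlnet_block_samples = list(range(4))  # stand-ins for the residual tensors

interval_control = int(np.ceil(num_transformer_blocks / len(controlnet_block_samples)))  # 5

default_idx = [i // interval_control for i in range(num_transformer_blocks)]
xlabs_idx = [i % len(controlnet_block_samples) for i in range(num_transformer_blocks)]

print(default_idx)  # [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3]
print(xlabs_idx)    # [0, 1, 2, 3, 0, 1, 2, 3, ...] cycles through the residuals
```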
+ if controlnet_blocks_repeat: + hidden_states = ( +- hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] ++ hidden_states ++ + controlnet_block_samples[ ++ index_block % len(controlnet_block_samples) ++ ] + ) + else: +- hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] ++ hidden_states = ( ++ hidden_states ++ + controlnet_block_samples[index_block // interval_control] ++ ) ++ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: +- encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( ++ hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, +- encoder_hidden_states, + temb, + image_rotary_emb, +- joint_attention_kwargs, ++ attention_mask=attention_mask, + ) + + else: +- encoder_hidden_states, hidden_states = block( ++ hidden_states = block( + hidden_states=hidden_states, +- encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, ++ attention_mask=attention_mask, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: +- interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) ++ interval_control = len(self.single_transformer_blocks) / len( ++ controlnet_single_block_samples ++ ) + interval_control = int(np.ceil(interval_control)) +- hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] ++ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( ++ hidden_states[:, encoder_hidden_states.shape[1] :, ...] ++ + controlnet_single_block_samples[index_block // interval_control] ++ ) ++ ++ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
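Downstream of the single-stream blocks the sequence layout matters: text tokens sit in front of image tokens, the single-block ControlNet residuals touch only the image slice, and only the image slice reaches `norm_out`/`proj_out`. An illustrative sketch with made-up shapes:

```py
import torch

txt_len, img_len, dim = 512, 4096, 3072
encoder_hidden_states = torch.randn(1, txt_len, dim)
hidden_states = torch.randn(1, img_len, dim)

# single-stream input: [text | image] along the sequence axis
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

controlnet_residual = torch.randn(1, img_len, dim)
hidden_states[:, txt_len:, ...] = hidden_states[:, txt_len:, ...] + controlnet_residual

# keep only the image tokens for the output head
hidden_states = hidden_states[:, txt_len:, ...]
assert hidden_states.shape == (1, img_len, dim)
```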
+ + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index b901d42d9cf7..90c85547d652 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -16,7 +16,6 @@ is_transformers_available, ) - # These modules contain pipelines from multiple libraries/frameworks _dummy_objects = {} _import_structure = { @@ -78,12 +77,16 @@ _import_structure["deprecated"].extend(["AudioDiffusionPipeline", "Mel"]) try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + if not ( + is_transformers_available() and is_torch_available() and is_note_seq_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 - _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects)) + _dummy_objects.update( + get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects) + ) else: _import_structure["deprecated"].extend( [ @@ -117,7 +120,11 @@ ] ) _import_structure["allegro"] = ["AllegroPipeline"] - _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] + _import_structure["amused"] = [ + "AmusedImg2ImgPipeline", + "AmusedInpaintPipeline", + "AmusedPipeline", + ] _import_structure["animatediff"] = [ "AnimateDiffPipeline", "AnimateDiffControlNetPipeline", @@ -264,7 +271,11 @@ ] ) _import_structure["latte"] = ["LattePipeline"] - _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline", "LTXConditionPipeline"] + _import_structure["ltx"] = [ + "LTXPipeline", + "LTXImageToVideoPipeline", + "LTXConditionPipeline", + ] _import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"] _import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"] _import_structure["marigold"].extend( @@ -280,7 +291,11 @@ _import_structure["paint_by_example"] = ["PaintByExamplePipeline"] _import_structure["pia"] = ["PIAPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] - _import_structure["sana"] = ["SanaPipeline", "SanaSprintPipeline"] + _import_structure["sana"] = [ + "SanaPipeline", + "SanaSprintPipeline", + "SanaControlNetPipeline", + ] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["stable_audio"] = [ @@ -314,7 +329,9 @@ "StableDiffusion3Img2ImgPipeline", "StableDiffusion3InpaintPipeline", ] - _import_structure["stable_diffusion_attend_and_excite"] = ["StableDiffusionAttendAndExcitePipeline"] + _import_structure["stable_diffusion_attend_and_excite"] = [ + "StableDiffusionAttendAndExcitePipeline" + ] _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"] _import_structure["stable_diffusion_sag"] = ["StableDiffusionSAGPipeline"] _import_structure["stable_diffusion_gligen"] = [ @@ -356,7 +373,11 @@ "WuerstchenDecoderPipeline", "WuerstchenPriorPipeline", ] - _import_structure["wan"] = ["WanPipeline", "WanImageToVideoPipeline", "WanVideoToVideoPipeline"] + _import_structure["wan"] = [ + "WanPipeline", + "WanImageToVideoPipeline", + "WanVideoToVideoPipeline", + ] try: if not is_onnx_available(): raise OptionalDependencyNotAvailable() @@ -367,12 +388,16 @@ else: _import_structure["onnx_utils"] = ["OnnxRuntimeModel"] 
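With the `_import_structure` entries above and the matching `TYPE_CHECKING` imports further down, the new pipeline should resolve through the package's lazy-import machinery. A quick sketch, assuming this patch is installed:

```py
# both routes should point at the same class object
from diffusers import SanaControlNetPipeline
from diffusers.pipelines.sana import SanaControlNetPipeline as SanaControlNetPipelineDirect

assert SanaControlNetPipeline is SanaControlNetPipelineDirect
```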
try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + if not ( + is_torch_available() and is_transformers_available() and is_onnx_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects)) + _dummy_objects.update( + get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects) + ) else: _import_structure["stable_diffusion"].extend( [ @@ -385,14 +410,18 @@ ) try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_k_diffusion_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils import ( - dummy_torch_and_transformers_and_k_diffusion_objects, - ) + from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects)) + _dummy_objects.update( + get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects) + ) else: _import_structure["stable_diffusion_k_diffusion"] = [ "StableDiffusionKDiffusionPipeline", @@ -400,14 +429,18 @@ ] try: - if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_sentencepiece_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from ..utils import ( - dummy_torch_and_transformers_and_sentencepiece_objects, - ) + from ..utils import dummy_torch_and_transformers_and_sentencepiece_objects - _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects)) + _dummy_objects.update( + get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects) + ) else: _import_structure["kolors"] = [ "KolorsPipeline", @@ -462,7 +495,13 @@ from .dance_diffusion import DanceDiffusionPipeline from .ddim import DDIMPipeline from .ddpm import DDPMPipeline - from .deprecated import KarrasVePipeline, LDMPipeline, PNDMPipeline, RePaintPipeline, ScoreSdeVePipeline + from .deprecated import ( + KarrasVePipeline, + LDMPipeline, + PNDMPipeline, + RePaintPipeline, + ScoreSdeVePipeline, + ) from .dit import DiTPipeline from .latent_diffusion import LDMSuperResolutionPipeline from .pipeline_utils import ( @@ -525,10 +564,11 @@ StableDiffusionXLControlNetUnionInpaintPipeline, StableDiffusionXLControlNetUnionPipeline, ) - from .controlnet_hunyuandit import ( - HunyuanDiTControlNetPipeline, + from .controlnet_hunyuandit import HunyuanDiTControlNetPipeline + from .controlnet_sd3 import ( + StableDiffusion3ControlNetInpaintingPipeline, + StableDiffusion3ControlNetPipeline, ) - from .controlnet_sd3 import StableDiffusion3ControlNetInpaintingPipeline, StableDiffusion3ControlNetPipeline from .controlnet_xs import ( StableDiffusionControlNetXSPipeline, StableDiffusionXLControlNetXSPipeline, @@ -602,10 +642,7 @@ KandinskyV22PriorEmb2EmbPipeline, KandinskyV22PriorPipeline, ) - from .kandinsky3 import ( - Kandinsky3Img2ImgPipeline, - Kandinsky3Pipeline, - ) + from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline from .latent_consistency_models import ( LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline, @@ -651,7 
+688,7 @@ from .paint_by_example import PaintByExamplePipeline from .pia import PIAPipeline from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline - from .sana import SanaPipeline, SanaSprintPipeline + from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .stable_audio import StableAudioPipeline, StableAudioProjectionModel @@ -678,9 +715,14 @@ StableDiffusion3InpaintPipeline, StableDiffusion3Pipeline, ) - from .stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline + from .stable_diffusion_attend_and_excite import ( + StableDiffusionAttendAndExcitePipeline, + ) from .stable_diffusion_diffedit import StableDiffusionDiffEditPipeline - from .stable_diffusion_gligen import StableDiffusionGLIGENPipeline, StableDiffusionGLIGENTextImagePipeline + from .stable_diffusion_gligen import ( + StableDiffusionGLIGENPipeline, + StableDiffusionGLIGENTextImagePipeline, + ) from .stable_diffusion_ldm3d import StableDiffusionLDM3DPipeline from .stable_diffusion_panorama import StableDiffusionPanoramaPipeline from .stable_diffusion_safe import StableDiffusionPipelineSafe @@ -726,7 +768,11 @@ from .onnx_utils import OnnxRuntimeModel try: - if not (is_torch_available() and is_transformers_available() and is_onnx_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_onnx_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_onnx_objects import * @@ -740,7 +786,11 @@ ) try: - if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_k_diffusion_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import * @@ -751,15 +801,16 @@ ) try: - if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): + if not ( + is_torch_available() + and is_transformers_available() + and is_sentencepiece_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_sentencepiece_objects import * else: - from .kolors import ( - KolorsImg2ImgPipeline, - KolorsPipeline, - ) + from .kolors import KolorsImg2ImgPipeline, KolorsPipeline try: if not is_flax_available(): @@ -781,21 +832,20 @@ FlaxStableDiffusionInpaintPipeline, FlaxStableDiffusionPipeline, ) - from .stable_diffusion_xl import ( - FlaxStableDiffusionXLPipeline, - ) + from .stable_diffusion_xl import FlaxStableDiffusionXLPipeline try: - if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): + if not ( + is_transformers_available() + and is_torch_available() + and is_note_seq_available() + ): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ..utils.dummy_transformers_and_torch_and_note_seq_objects import * # noqa F403 else: - from .deprecated import ( - MidiProcessor, - SpectrogramDiffusionPipeline, - ) + from .deprecated import MidiProcessor, SpectrogramDiffusionPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/__init__.py b/src/diffusers/pipelines/sana/__init__.py index 1393b37e2d3a..c5814b2eb4da 100644 --- a/src/diffusers/pipelines/sana/__init__.py +++ 
b/src/diffusers/pipelines/sana/__init__.py @@ -9,7 +9,6 @@ is_transformers_available, ) - _dummy_objects = {} _import_structure = {} @@ -24,6 +23,7 @@ else: _import_structure["pipeline_sana"] = ["SanaPipeline"] _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"] + _import_structure["pipeline_sana_controlnet"] = ["SanaControlNetPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -34,6 +34,7 @@ from ...utils.dummy_torch_and_transformers_objects import * else: from .pipeline_sana import SanaPipeline + from .pipeline_sana_controlnet import SanaControlNetPipeline from .pipeline_sana_sprint import SanaSprintPipeline else: import sys diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py new file mode 100644 index 000000000000..8a23486d6f80 --- /dev/null +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -0,0 +1,1236 @@ +# Copyright 2025 PixArt-Sigma Authors and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import html +import inspect +import re +import urllib.parse as ul +import warnings +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast + +from ...callbacks import MultiPipelineCallbacks, PipelineCallback +from ...image_processor import PipelineImageInput, PixArtImageProcessor +from ...loaders import SanaLoraLoaderMixin +from ...models import AutoencoderDC, SanaControlNetModel, SanaTransformer2DModel +from ...schedulers import DPMSolverMultistepScheduler +from ...utils import ( + BACKENDS_MAPPING, + USE_PEFT_BACKEND, + deprecate, + is_bs4_available, + is_ftfy_available, + is_torch_xla_available, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from ...utils.torch_utils import get_device, is_torch_version, randn_tensor +from ..pipeline_utils import DiffusionPipeline +from ..pixart_alpha.pipeline_pixart_alpha import ( + ASPECT_RATIO_512_BIN, + ASPECT_RATIO_1024_BIN, +) +from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN +from .pipeline_output import SanaPipelineOutput + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + + +ASPECT_RATIO_4096_BIN = { + "0.25": [2048.0, 8192.0], + "0.26": [2048.0, 7936.0], + "0.27": [2048.0, 7680.0], + "0.28": [2048.0, 7424.0], + "0.32": [2304.0, 7168.0], + "0.33": [2304.0, 6912.0], + "0.35": [2304.0, 6656.0], + "0.4": [2560.0, 6400.0], + "0.42": [2560.0, 6144.0], + "0.48": [2816.0, 5888.0], + "0.5": [2816.0, 5632.0], + "0.52": [2816.0, 5376.0], + "0.57": [3072.0, 5376.0], + "0.6": [3072.0, 5120.0], + "0.68": [3328.0, 4864.0], + "0.72": [3328.0, 4608.0], + "0.78": [3584.0, 4608.0], + "0.82": 
[3584.0, 4352.0], + "0.88": [3840.0, 4352.0], + "0.94": [3840.0, 4096.0], + "1.0": [4096.0, 4096.0], + "1.07": [4096.0, 3840.0], + "1.13": [4352.0, 3840.0], + "1.21": [4352.0, 3584.0], + "1.29": [4608.0, 3584.0], + "1.38": [4608.0, 3328.0], + "1.46": [4864.0, 3328.0], + "1.67": [5120.0, 3072.0], + "1.75": [5376.0, 3072.0], + "2.0": [5632.0, 2816.0], + "2.09": [5888.0, 2816.0], + "2.4": [6144.0, 2560.0], + "2.5": [6400.0, 2560.0], + "2.89": [6656.0, 2304.0], + "3.0": [6912.0, 2304.0], + "3.11": [7168.0, 2304.0], + "3.62": [7424.0, 2048.0], + "3.75": [7680.0, 2048.0], + "3.88": [7936.0, 2048.0], + "4.0": [8192.0, 2048.0], +} + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import SanaControlNetPipeline + >>> from diffusers.utils import load_image + + >>> pipe = SanaControlNetPipeline.from_pretrained( + ... "ishan24/Sana_600M_1024px_ControlNetPlus_diffusers", + ... variant="fp16", + ... torch_dtype={"default": torch.bfloat16, "controlnet": torch.float16, "transformer": torch.float16}, + ... device_map="balanced", + ... ) + >>> cond_image = load_image( + ... "https://huggingface.co/ishan24/Sana_600M_1024px_ControlNet_diffusers/resolve/main/hed_example.png" + ... ) + >>> prompt = 'a cat with a neon sign that says "Sana"' + >>> image = pipe( + ... prompt, + ... control_image=cond_image, + ... ).images[0] + >>> image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class SanaControlNetPipeline(DiffusionPipeline, SanaLoraLoaderMixin): + r""" + Pipeline for text-to-image generation using [Sana](https://huggingface.co/papers/2410.10629). + """ + + # fmt: off + bad_punct_regex = re.compile(r"[" + "#®•©™&@·º½¾¿¡§~" + r"\)" + r"\(" + r"\]" + r"\[" + r"\}" + r"\{" + r"\|" + "\\" + r"\/" + r"\*" + r"]{1,}") + # fmt: on + + model_cpu_offload_seq = "text_encoder->controlnet->transformer->vae" + _callback_tensor_inputs = [ + "latents", + "control_image", + "prompt_embeds", + "negative_prompt_embeds", + ] + + def __init__( + self, + tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast], + text_encoder: Gemma2PreTrainedModel, + vae: AutoencoderDC, + transformer: SanaTransformer2DModel, + controlnet: SanaControlNetModel, + scheduler: DPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + controlnet=controlnet, + scheduler=scheduler, + ) + + self.vae_scale_factor = ( + 2 ** (len(self.vae.config.encoder_block_out_channels) - 1) + if hasattr(self, "vae") and self.vae is not None + else 32 + ) + self.image_processor = PixArtImageProcessor( + vae_scale_factor=self.vae_scale_factor + ) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`." + deprecate( + "enable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`." + deprecate( + "disable_vae_slicing", + "0.40.0", + depr_message, + ) + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`." 
+ deprecate( + "enable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`." + deprecate( + "disable_vae_tiling", + "0.40.0", + depr_message, + ) + self.vae.disable_tiling() + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds + def _get_gemma_prompt_embeds( + self, + prompt: Union[str, List[str]], + device: torch.device, + dtype: torch.dtype, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + prompt = [prompt] if isinstance(prompt, str) else prompt + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + + # prepare complex human instruction + if not complex_human_instruction: + max_length_all = max_sequence_length + else: + chi_prompt = "\n".join(complex_human_instruction) + prompt = [chi_prompt + p for p in prompt] + num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt)) + max_length_all = num_chi_prompt_tokens + max_sequence_length - 2 + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length_all, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device) + + return prompt_embeds, prompt_attention_mask + + # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + clean_caption: bool = False, + max_sequence_length: int = 300, + complex_human_instruction: Optional[List[str]] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + PixArt-Alpha, this should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string. + clean_caption (`bool`, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt. + complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`): + If `complex_human_instruction` is not empty, the function will use the complex Human instruction for + the prompt. + """ + + if device is None: + device = self._execution_device + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + else: + dtype = None + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if self.text_encoder is not None and USE_PEFT_BACKEND: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if getattr(self, "tokenizer", None) is not None: + self.tokenizer.padding_side = "right" + + # See Section 3.1. of the paper. 
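The token-selection logic right below is easy to misread, so here is a worked example with illustrative numbers: in `_get_gemma_prompt_embeds` above, the complex-human-instruction (CHI) prefix widens the tokenizer's max length so it does not consume the 300-token prompt budget, and `select_index` then keeps the BOS position plus the last `max_sequence_length - 1` positions, i.e. the user prompt rather than the instruction.

```py
max_sequence_length = 300
num_chi_prompt_tokens = 120  # hypothetical tokenized CHI length
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2  # 418

select_index = [0] + list(range(-max_sequence_length + 1, 0))
assert len(select_index) == max_sequence_length                   # 300 positions kept
assert select_index[:3] == [0, -299, -298] and select_index[-1] == -1
```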
+ max_length = max_sequence_length + select_index = [0] + list(range(-max_length + 1, 0)) + + if prompt_embeds is None: + prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds( + prompt=prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + ) + + prompt_embeds = prompt_embeds[:, select_index] + prompt_attention_mask = prompt_attention_mask[:, select_index] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1) + prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = ( + [negative_prompt] * batch_size + if isinstance(negative_prompt, str) + else negative_prompt + ) + negative_prompt_embeds, negative_prompt_attention_mask = ( + self._get_gemma_prompt_embeds( + prompt=negative_prompt, + device=device, + dtype=dtype, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=False, + ) + ) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + bs_embed, -1 + ) + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat( + num_images_per_prompt, 1 + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + if self.text_encoder is not None: + if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_on_step_end_tensor_inputs=None, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 32 != 0 or width % 32 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 32 but are {height} and {width}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs + for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError( + "Must provide `prompt_attention_mask` when specifying `prompt_embeds`." + ) + + if ( + negative_prompt_embeds is not None + and negative_prompt_attention_mask is None + ): + raise ValueError( + "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." 
+ ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warning( + BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`") + ) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warning( + BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`") + ) + logger.warning("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip addresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." 
+ caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub( + r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption + ) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub( + self.bad_punct_regex, r" ", caption + ) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub( + r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption + ) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub( + r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption + ) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + if latents is not None: + return latents.to(device=device, dtype=dtype) + + shape = ( + batch_size, + num_channels_latents, + int(height) // self.vae_scale_factor, + int(width) // self.vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + return latents + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def attention_kwargs(self): + return self._attention_kwargs + + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1.0 + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + sigmas: List[float] = None, + guidance_scale: float = 4.5, + control_image: PipelineImageInput = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + num_images_per_prompt: Optional[int] = 1, + height: int = 1024, + width: int = 1024, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + prompt_attention_mask: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + clean_caption: bool = False, + use_resolution_binning: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + max_sequence_length: int = 300, + complex_human_instruction: List[str] = [ + "Given a user prompt, generate an 'Enhanced prompt' that provides detailed visual descriptions suitable for image generation. Evaluate the level of detail in the user prompt:", + "- If the prompt is simple, focus on adding specifics about colors, shapes, sizes, textures, and spatial relationships to create vivid and concrete scenes.", + "- If the prompt is already detailed, refine and enhance the existing details slightly without overcomplicating.", + "Here are examples of how to transform or refine prompts:", + "- User Prompt: A cat sleeping -> Enhanced: A small, fluffy white cat curled up in a round shape, sleeping peacefully on a warm sunny windowsill, surrounded by pots of blooming red flowers.", + "- User Prompt: A busy city street -> Enhanced: A bustling city street scene at dusk, featuring glowing street lamps, a diverse crowd of people in colorful clothing, and a double-decker bus passing by towering glass skyscrapers.", + "Please generate only the enhanced description for the prompt below and avoid including any additional commentary or evaluations:", + "User Prompt: ", + ], + ) -> Union[SanaPipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 20): + The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + sigmas (`List[float]`, *optional*): + Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in + their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed + will be used. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion + Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. + of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting + `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to + the text `prompt`, usually at the expense of lower image quality. + control_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.Tensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted + as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or + width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`, + images must be passed as a list such that each element of the list can be correctly batched for input + to a single ControlNet. + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set + the corresponding scale as a list. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only + applies to [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.Tensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + attention_kwargs: + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeline class. + max_sequence_length (`int` defaults to `300`): + Maximum sequence length to use with the `prompt`. + complex_human_instruction (`List[str]`, *optional*): + Instructions for complex human attention: + https://github.com/NVlabs/Sana/blob/main/configs/sana_app_config/Sana_1600M_app.yaml#L55. + + Examples: + + Returns: + [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images + """ + + if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): + callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs + + # 1. Check inputs. 
Raise error if not correct + if use_resolution_binning: + if self.transformer.config.sample_size == 128: + aspect_ratio_bin = ASPECT_RATIO_4096_BIN + elif self.transformer.config.sample_size == 64: + aspect_ratio_bin = ASPECT_RATIO_2048_BIN + elif self.transformer.config.sample_size == 32: + aspect_ratio_bin = ASPECT_RATIO_1024_BIN + elif self.transformer.config.sample_size == 16: + aspect_ratio_bin = ASPECT_RATIO_512_BIN + else: + raise ValueError("Invalid sample size") + orig_height, orig_width = height, width + height, width = self.image_processor.classify_height_width_bin( + height, width, ratios=aspect_ratio_bin + ) + + self.check_inputs( + prompt, + height, + width, + callback_on_step_end_tensor_inputs, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + self._guidance_scale = guidance_scale + self._attention_kwargs = attention_kwargs + self._interrupt = False + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + lora_scale = ( + self.attention_kwargs.get("scale", None) + if self.attention_kwargs is not None + else None + ) + + # 3. Encode input prompt + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + self.do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + max_sequence_length=max_sequence_length, + complex_human_instruction=complex_human_instruction, + lora_scale=lora_scale, + ) + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + prompt_attention_mask = torch.cat( + [negative_prompt_attention_mask, prompt_attention_mask], dim=0 + ) + + # 4. Prepare control image + if isinstance(self.controlnet, SanaControlNetModel): + control_image = self.prepare_image( + image=control_image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=self.vae.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=False, + ) + height, width = control_image.shape[-2:] + + control_image = self.vae.encode(control_image).latent + control_image = control_image * self.vae.config.scaling_factor + else: + raise ValueError("`controlnet` must be of type `SanaControlNetModel`.") + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, num_inference_steps, device, timesteps, sigmas + ) + + # 6. Prepare latents. + latent_channels = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + latent_channels, + height, + width, + torch.float32, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + self._num_timesteps = len(timesteps) + + controlnet_dtype = self.controlnet.dtype + transformer_dtype = self.transformer.dtype + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep = t.expand(latent_model_input.shape[0]) + + # controlnet(s) inference + controlnet_block_samples = self.controlnet( + latent_model_input.to(dtype=controlnet_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=controlnet_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_cond=control_image, + conditioning_scale=controlnet_conditioning_scale, + )[0] + + # predict noise model_output + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_block_samples=tuple( + t.to(dtype=transformer_dtype) for t in controlnet_block_samples + ), + )[0] + noise_pred = noise_pred.float() + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # learned sigma + if self.transformer.config.out_channels // 2 == latent_channels: + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop( + "negative_prompt_embeds", negative_prompt_embeds + ) + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if XLA_AVAILABLE: + xm.mark_step() + + if output_type == "latent": + image = latents + else: + latents = latents.to(self.vae.dtype) + torch_accelerator_module = getattr(torch, get_device(), torch.cuda) + oom_error = ( + torch.OutOfMemoryError + if is_torch_version(">=", "2.5.0") + else torch_accelerator_module.OutOfMemoryError + ) + try: + image = self.vae.decode( + latents / self.vae.config.scaling_factor, return_dict=False + )[0] + except oom_error as e: + warnings.warn( + f"{e}. \n" + f"Try to use VAE tiling for large images. 
For example: \n" + f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)" + ) + if use_resolution_binning: + image = self.image_processor.resize_and_crop_tensor( + image, orig_width, orig_height + ) + + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return SanaPipelineOutput(images=image) From fd0add6e6cd52c49f62efed389e8d969deefa69e Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 30 Sep 2025 21:43:31 +0000 Subject: [PATCH 3/6] add torch utils --- src/diffusers/utils/torch_utils.py | 64 ++++++++++++++++++++++++------ 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/src/diffusers/utils/torch_utils.py b/src/diffusers/utils/torch_utils.py index 3c8911773e39..06be5cb961ac 100644 --- a/src/diffusers/utils/torch_utils.py +++ b/src/diffusers/utils/torch_utils.py @@ -14,13 +14,12 @@ """ PyTorch utilities: Utilities related to PyTorch """ - +import functools from typing import List, Optional, Tuple, Union from . import logging from .import_utils import is_torch_available, is_torch_version - if is_torch_available(): import torch from torch.fft import fftn, fftshift, ifftn, ifftshift @@ -54,7 +53,11 @@ def randn_tensor( device = device or torch.device("cpu") if generator is not None: - gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + gen_device_type = ( + generator.device.type + if not isinstance(generator, list) + else generator[0].device.type + ) if gen_device_type != device.type and gen_device_type == "cpu": rand_device = "cpu" if device != "mps": @@ -64,7 +67,9 @@ def randn_tensor( f" slighly speed up this function by passing a generator that was created on the {device} device." ) elif gen_device_type != device.type and gen_device_type == "cuda": - raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + raise ValueError( + f"Cannot generate a {device} tensor from a generator of type {gen_device_type}." 
+ ) # make sure generator list of length 1 is treated like a non-list if isinstance(generator, list) and len(generator) == 1: @@ -73,12 +78,20 @@ def randn_tensor( if isinstance(generator, list): shape = (1,) + shape[1:] latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + torch.randn( + shape, + generator=generator[i], + device=rand_device, + dtype=dtype, + layout=layout, + ) for i in range(batch_size) ] latents = torch.cat(latents, dim=0).to(device) else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + latents = torch.randn( + shape, generator=generator, device=rand_device, dtype=dtype, layout=layout + ).to(device) return latents @@ -114,7 +127,9 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T mask = torch.ones((B, C, H, W), device=x.device) crow, ccol = H // 2, W // 2 - mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale + mask[ + ..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold + ] = scale x_freq = x_freq * mask # IFFT @@ -125,7 +140,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T def apply_freeu( - resolution_idx: int, hidden_states: "torch.Tensor", res_hidden_states: "torch.Tensor", **freeu_kwargs + resolution_idx: int, + hidden_states: "torch.Tensor", + res_hidden_states: "torch.Tensor", + **freeu_kwargs, ) -> Tuple["torch.Tensor", "torch.Tensor"]: """Applies the FreeU mechanism as introduced in https: //arxiv.org/abs/2309.11497. Adapted from the official code repository: https://github.com/ChenyangSi/FreeU. @@ -141,12 +159,20 @@ def apply_freeu( """ if resolution_idx == 0: num_half_channels = hidden_states.shape[1] // 2 - hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] - res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s1"]) + hidden_states[:, :num_half_channels] = ( + hidden_states[:, :num_half_channels] * freeu_kwargs["b1"] + ) + res_hidden_states = fourier_filter( + res_hidden_states, threshold=1, scale=freeu_kwargs["s1"] + ) if resolution_idx == 1: num_half_channels = hidden_states.shape[1] // 2 - hidden_states[:, :num_half_channels] = hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] - res_hidden_states = fourier_filter(res_hidden_states, threshold=1, scale=freeu_kwargs["s2"]) + hidden_states[:, :num_half_channels] = ( + hidden_states[:, :num_half_channels] * freeu_kwargs["b2"] + ) + res_hidden_states = fourier_filter( + res_hidden_states, threshold=1, scale=freeu_kwargs["s2"] + ) return hidden_states, res_hidden_states @@ -159,3 +185,17 @@ def get_torch_cuda_device_capability(): return float(compute_capability) else: return None + + +@functools.lru_cache +def get_device(): + if torch.cuda.is_available(): + return "cuda" + elif is_torch_npu_available(): + return "npu" + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + return "xpu" + elif torch.backends.mps.is_available(): + return "mps" + else: + return "cpu" From 9f94f5687cf081f42d6b046027fc1f21ae9bac45 Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 30 Sep 2025 21:44:03 +0000 Subject: [PATCH 4/6] remove redundant file --- .../models/transformers/file-changes.diff | 929 ------------------ 1 file changed, 929 deletions(-) delete mode 100644 src/diffusers/models/transformers/file-changes.diff diff --git a/src/diffusers/models/transformers/file-changes.diff 
b/src/diffusers/models/transformers/file-changes.diff deleted file mode 100644 index 584bb5170bb7..000000000000 --- a/src/diffusers/models/transformers/file-changes.diff +++ /dev/null @@ -1,929 +0,0 @@ -diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py -index 7ab371a1a..b64920c37 100644 ---- a/src/diffusers/models/transformers/transformer_flux.py -+++ b/src/diffusers/models/transformers/transformer_flux.py -@@ -1,4 +1,4 @@ --# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. -+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. - # - # Licensed under the Apache License, Version 2.0 (the "License"); - # you may not use this file except in compliance with the License. -@@ -12,340 +12,65 @@ - # See the License for the specific language governing permissions and - # limitations under the License. - --import inspect --from typing import Any, Dict, List, Optional, Tuple, Union -+ -+from typing import Any, Dict, Optional, Tuple, Union - - import numpy as np - import torch - import torch.nn as nn --import torch.nn.functional as F - - from ...configuration_utils import ConfigMixin, register_to_config --from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin --from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers -+from ...loaders import ( -+ FluxTransformer2DLoadersMixin, -+ FromOriginalModelMixin, -+ PeftAdapterMixin, -+) -+from ...models.attention import FeedForward -+from ...models.attention_processor import ( -+ Attention, -+ AttentionProcessor, -+ FluxAttnProcessor2_0, -+ FluxAttnProcessor2_0_NPU, -+ FusedFluxAttnProcessor2_0, -+) -+from ...models.modeling_utils import ModelMixin -+from ...models.normalization import ( -+ AdaLayerNormContinuous, -+ AdaLayerNormZero, -+ AdaLayerNormZeroSingle, -+) -+from ...utils import ( -+ USE_PEFT_BACKEND, -+ deprecate, -+ logging, -+ scale_lora_layers, -+ unscale_lora_layers, -+) -+from ...utils.import_utils import is_torch_npu_available - from ...utils.torch_utils import maybe_allow_in_graph --from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward --from ..attention_dispatch import dispatch_attention_fn - from ..cache_utils import CacheMixin - from ..embeddings import ( - CombinedTimestepGuidanceTextProjEmbeddings, - CombinedTimestepTextProjEmbeddings, -- apply_rotary_emb, -- get_1d_rotary_pos_embed, -+ FluxPosEmbed, -+ TimestepEmbedding, -+ Timesteps, - ) - from ..modeling_outputs import Transformer2DModelOutput --from ..modeling_utils import ModelMixin --from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -- - - logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - --def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): -- query = attn.to_q(hidden_states) -- key = attn.to_k(hidden_states) -- value = attn.to_v(hidden_states) -- -- encoder_query = encoder_key = encoder_value = None -- if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None: -- encoder_query = attn.add_q_proj(encoder_hidden_states) -- encoder_key = attn.add_k_proj(encoder_hidden_states) -- encoder_value = attn.add_v_proj(encoder_hidden_states) -- -- return query, key, value, encoder_query, encoder_key, encoder_value -- -- --def _get_fused_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): -- 
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1) -- -- encoder_query = encoder_key = encoder_value = (None,) -- if encoder_hidden_states is not None and hasattr(attn, "to_added_qkv"): -- encoder_query, encoder_key, encoder_value = attn.to_added_qkv(encoder_hidden_states).chunk(3, dim=-1) -- -- return query, key, value, encoder_query, encoder_key, encoder_value -- -- --def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None): -- if attn.fused_projections: -- return _get_fused_projections(attn, hidden_states, encoder_hidden_states) -- return _get_projections(attn, hidden_states, encoder_hidden_states) -- -- --class FluxAttnProcessor: -- _attention_backend = None -- -- def __init__(self): -- if not hasattr(F, "scaled_dot_product_attention"): -- raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.") -- -- def __call__( -- self, -- attn: "FluxAttention", -- hidden_states: torch.Tensor, -- encoder_hidden_states: torch.Tensor = None, -- attention_mask: Optional[torch.Tensor] = None, -- image_rotary_emb: Optional[torch.Tensor] = None, -- ) -> torch.Tensor: -- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( -- attn, hidden_states, encoder_hidden_states -- ) -- -- query = query.unflatten(-1, (attn.heads, -1)) -- key = key.unflatten(-1, (attn.heads, -1)) -- value = value.unflatten(-1, (attn.heads, -1)) -- -- query = attn.norm_q(query) -- key = attn.norm_k(key) -- -- if attn.added_kv_proj_dim is not None: -- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) -- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) -- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) -- -- encoder_query = attn.norm_added_q(encoder_query) -- encoder_key = attn.norm_added_k(encoder_key) -- -- query = torch.cat([encoder_query, query], dim=1) -- key = torch.cat([encoder_key, key], dim=1) -- value = torch.cat([encoder_value, value], dim=1) -- -- if image_rotary_emb is not None: -- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) -- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) -- -- hidden_states = dispatch_attention_fn( -- query, key, value, attn_mask=attention_mask, backend=self._attention_backend -- ) -- hidden_states = hidden_states.flatten(2, 3) -- hidden_states = hidden_states.to(query.dtype) -- -- if encoder_hidden_states is not None: -- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( -- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 -- ) -- hidden_states = attn.to_out[0](hidden_states) -- hidden_states = attn.to_out[1](hidden_states) -- encoder_hidden_states = attn.to_add_out(encoder_hidden_states) -- -- return hidden_states, encoder_hidden_states -- else: -- return hidden_states -- -- --class FluxIPAdapterAttnProcessor(torch.nn.Module): -- """Flux Attention processor for IP-Adapter.""" -- -- _attention_backend = None -- -- def __init__( -- self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None -- ): -- super().__init__() -- -- if not hasattr(F, "scaled_dot_product_attention"): -- raise ImportError( -- f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." 
-- ) -- -- self.hidden_size = hidden_size -- self.cross_attention_dim = cross_attention_dim -- -- if not isinstance(num_tokens, (tuple, list)): -- num_tokens = [num_tokens] -- -- if not isinstance(scale, list): -- scale = [scale] * len(num_tokens) -- if len(scale) != len(num_tokens): -- raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.") -- self.scale = scale -- -- self.to_k_ip = nn.ModuleList( -- [ -- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) -- for _ in range(len(num_tokens)) -- ] -- ) -- self.to_v_ip = nn.ModuleList( -- [ -- nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype) -- for _ in range(len(num_tokens)) -- ] -- ) -- -- def __call__( -- self, -- attn: "FluxAttention", -- hidden_states: torch.Tensor, -- encoder_hidden_states: torch.Tensor = None, -- attention_mask: Optional[torch.Tensor] = None, -- image_rotary_emb: Optional[torch.Tensor] = None, -- ip_hidden_states: Optional[List[torch.Tensor]] = None, -- ip_adapter_masks: Optional[torch.Tensor] = None, -- ) -> torch.Tensor: -- batch_size = hidden_states.shape[0] -- -- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections( -- attn, hidden_states, encoder_hidden_states -- ) -- -- query = query.unflatten(-1, (attn.heads, -1)) -- key = key.unflatten(-1, (attn.heads, -1)) -- value = value.unflatten(-1, (attn.heads, -1)) -- -- query = attn.norm_q(query) -- key = attn.norm_k(key) -- ip_query = query -- -- if encoder_hidden_states is not None: -- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1)) -- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1)) -- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1)) -- -- encoder_query = attn.norm_added_q(encoder_query) -- encoder_key = attn.norm_added_k(encoder_key) -- -- query = torch.cat([encoder_query, query], dim=1) -- key = torch.cat([encoder_key, key], dim=1) -- value = torch.cat([encoder_value, value], dim=1) -- -- if image_rotary_emb is not None: -- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) -- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) -- -- hidden_states = dispatch_attention_fn( -- query, -- key, -- value, -- attn_mask=attention_mask, -- dropout_p=0.0, -- is_causal=False, -- backend=self._attention_backend, -- ) -- hidden_states = hidden_states.flatten(2, 3) -- hidden_states = hidden_states.to(query.dtype) -- -- if encoder_hidden_states is not None: -- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes( -- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 -- ) -- hidden_states = attn.to_out[0](hidden_states) -- hidden_states = attn.to_out[1](hidden_states) -- encoder_hidden_states = attn.to_add_out(encoder_hidden_states) -- -- # IP-adapter -- ip_attn_output = torch.zeros_like(hidden_states) -- -- for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip( -- ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip -- ): -- ip_key = to_k_ip(current_ip_hidden_states) -- ip_value = to_v_ip(current_ip_hidden_states) -- -- ip_key = ip_key.view(batch_size, -1, attn.heads, attn.head_dim) -- ip_value = ip_value.view(batch_size, -1, attn.heads, attn.head_dim) -- -- current_ip_hidden_states = dispatch_attention_fn( -- ip_query, -- ip_key, -- ip_value, -- attn_mask=None, -- dropout_p=0.0, -- is_causal=False, -- backend=self._attention_backend, -- ) -- current_ip_hidden_states = current_ip_hidden_states.reshape(batch_size, 
-1, attn.heads * attn.head_dim) -- current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype) -- ip_attn_output += scale * current_ip_hidden_states -- -- return hidden_states, encoder_hidden_states, ip_attn_output -- else: -- return hidden_states -- -- --class FluxAttention(torch.nn.Module, AttentionModuleMixin): -- _default_processor_cls = FluxAttnProcessor -- _available_processors = [ -- FluxAttnProcessor, -- FluxIPAdapterAttnProcessor, -- ] -- -+@maybe_allow_in_graph -+class FluxSingleTransformerBlock(nn.Module): - def __init__( - self, -- query_dim: int, -- heads: int = 8, -- dim_head: int = 64, -- dropout: float = 0.0, -- bias: bool = False, -- added_kv_proj_dim: Optional[int] = None, -- added_proj_bias: Optional[bool] = True, -- out_bias: bool = True, -- eps: float = 1e-5, -- out_dim: int = None, -- context_pre_only: Optional[bool] = None, -- pre_only: bool = False, -- elementwise_affine: bool = True, -- processor=None, -+ dim: int, -+ num_attention_heads: int, -+ attention_head_dim: int, -+ mlp_ratio: float = 4.0, - ): - super().__init__() -- -- self.head_dim = dim_head -- self.inner_dim = out_dim if out_dim is not None else dim_head * heads -- self.query_dim = query_dim -- self.use_bias = bias -- self.dropout = dropout -- self.out_dim = out_dim if out_dim is not None else query_dim -- self.context_pre_only = context_pre_only -- self.pre_only = pre_only -- self.heads = out_dim // dim_head if out_dim is not None else heads -- self.added_kv_proj_dim = added_kv_proj_dim -- self.added_proj_bias = added_proj_bias -- -- self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) -- self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine) -- self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) -- self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) -- self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias) -- -- if not self.pre_only: -- self.to_out = torch.nn.ModuleList([]) -- self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)) -- self.to_out.append(torch.nn.Dropout(dropout)) -- -- if added_kv_proj_dim is not None: -- self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps) -- self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps) -- self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) -- self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) -- self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) -- self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias) -- -- if processor is None: -- processor = self._default_processor_cls() -- self.set_processor(processor) -- -- def forward( -- self, -- hidden_states: torch.Tensor, -- encoder_hidden_states: Optional[torch.Tensor] = None, -- attention_mask: Optional[torch.Tensor] = None, -- image_rotary_emb: Optional[torch.Tensor] = None, -- **kwargs, -- ) -> torch.Tensor: -- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys()) -- quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"} -- unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters] -- if len(unused_kwargs) > 0: -- logger.warning( -- f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored." 
-- ) -- kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters} -- return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs) -- -- --@maybe_allow_in_graph --class FluxSingleTransformerBlock(nn.Module): -- def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): -- super().__init__() - self.mlp_hidden_dim = int(dim * mlp_ratio) - - self.norm = AdaLayerNormZeroSingle(dim) -@@ -353,13 +78,25 @@ class FluxSingleTransformerBlock(nn.Module): - self.act_mlp = nn.GELU(approximate="tanh") - self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim) - -- self.attn = FluxAttention( -+ if is_torch_npu_available(): -+ deprecation_message = ( -+ "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors " -+ "should be set explicitly using the `set_attn_processor` method." -+ ) -+ deprecate("npu_processor", "0.34.0", deprecation_message) -+ processor = FluxAttnProcessor2_0_NPU() -+ else: -+ processor = FluxAttnProcessor2_0() -+ -+ self.attn = Attention( - query_dim=dim, -+ cross_attention_dim=None, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - bias=True, -- processor=FluxAttnProcessor(), -+ processor=processor, -+ qk_norm="rms_norm", - eps=1e-6, - pre_only=True, - ) -@@ -367,14 +104,11 @@ class FluxSingleTransformerBlock(nn.Module): - def forward( - self, - hidden_states: torch.Tensor, -- encoder_hidden_states: torch.Tensor, - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, -- ) -> Tuple[torch.Tensor, torch.Tensor]: -- text_seq_len = encoder_hidden_states.shape[1] -- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) -- -+ attention_mask: Optional[torch.Tensor] = None, -+ ) -> torch.Tensor: - residual = hidden_states - norm_hidden_states, gate = self.norm(hidden_states, emb=temb) - mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) -@@ -382,6 +116,7 @@ class FluxSingleTransformerBlock(nn.Module): - attn_output = self.attn( - hidden_states=norm_hidden_states, - image_rotary_emb=image_rotary_emb, -+ attention_mask=attention_mask, - **joint_attention_kwargs, - ) - -@@ -392,29 +127,35 @@ class FluxSingleTransformerBlock(nn.Module): - if hidden_states.dtype == torch.float16: - hidden_states = hidden_states.clip(-65504, 65504) - -- encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:] -- return encoder_hidden_states, hidden_states -+ return hidden_states - - - @maybe_allow_in_graph - class FluxTransformerBlock(nn.Module): - def __init__( -- self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 -+ self, -+ dim: int, -+ num_attention_heads: int, -+ attention_head_dim: int, -+ qk_norm: str = "rms_norm", -+ eps: float = 1e-6, - ): - super().__init__() - - self.norm1 = AdaLayerNormZero(dim) - self.norm1_context = AdaLayerNormZero(dim) - -- self.attn = FluxAttention( -+ self.attn = Attention( - query_dim=dim, -+ cross_attention_dim=None, - added_kv_proj_dim=dim, - dim_head=attention_head_dim, - heads=num_attention_heads, - out_dim=dim, - context_pre_only=False, - bias=True, -- processor=FluxAttnProcessor(), -+ processor=FluxAttnProcessor2_0(), -+ qk_norm=qk_norm, - eps=eps, - ) - -@@ -422,7 +163,9 @@ class FluxTransformerBlock(nn.Module): - self.ff = FeedForward(dim=dim, dim_out=dim, 
activation_fn="gelu-approximate") - - self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) -- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") -+ self.ff_context = FeedForward( -+ dim=dim, dim_out=dim, activation_fn="gelu-approximate" -+ ) - - def forward( - self, -@@ -431,19 +174,22 @@ class FluxTransformerBlock(nn.Module): - temb: torch.Tensor, - image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, -+ attention_mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: -- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) -+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( -+ hidden_states, emb=temb -+ ) - -- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( -- encoder_hidden_states, emb=temb -+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( -+ self.norm1_context(encoder_hidden_states, emb=temb) - ) - joint_attention_kwargs = joint_attention_kwargs or {} -- - # Attention. - attention_outputs = self.attn( - hidden_states=norm_hidden_states, - encoder_hidden_states=norm_encoder_hidden_states, - image_rotary_emb=image_rotary_emb, -+ attention_mask=attention_mask, - **joint_attention_kwargs, - ) - -@@ -457,7 +203,9 @@ class FluxTransformerBlock(nn.Module): - hidden_states = hidden_states + attn_output - - norm_hidden_states = self.norm2(hidden_states) -- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] -+ norm_hidden_states = ( -+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] -+ ) - - ff_output = self.ff(norm_hidden_states) - ff_output = gate_mlp.unsqueeze(1) * ff_output -@@ -467,51 +215,26 @@ class FluxTransformerBlock(nn.Module): - hidden_states = hidden_states + ip_attn_output - - # Process attention outputs for the `encoder_hidden_states`. 
-+ - context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output - encoder_hidden_states = encoder_hidden_states + context_attn_output - - norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) -- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] -+ norm_encoder_hidden_states = ( -+ norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) -+ + c_shift_mlp[:, None] -+ ) - - context_ff_output = self.ff_context(norm_encoder_hidden_states) -- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output -+ encoder_hidden_states = ( -+ encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output -+ ) - if encoder_hidden_states.dtype == torch.float16: - encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) - - return encoder_hidden_states, hidden_states - - --class FluxPosEmbed(nn.Module): -- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 -- def __init__(self, theta: int, axes_dim: List[int]): -- super().__init__() -- self.theta = theta -- self.axes_dim = axes_dim -- -- def forward(self, ids: torch.Tensor) -> torch.Tensor: -- n_axes = ids.shape[-1] -- cos_out = [] -- sin_out = [] -- pos = ids.float() -- is_mps = ids.device.type == "mps" -- is_npu = ids.device.type == "npu" -- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64 -- for i in range(n_axes): -- cos, sin = get_1d_rotary_pos_embed( -- self.axes_dim[i], -- pos[:, i], -- theta=self.theta, -- repeat_interleave_real=True, -- use_real=True, -- freqs_dtype=freqs_dtype, -- ) -- cos_out.append(cos) -- sin_out.append(sin) -- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) -- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) -- return freqs_cos, freqs_sin -- -- - class FluxTransformer2DModel( - ModelMixin, - ConfigMixin, -@@ -519,7 +242,6 @@ class FluxTransformer2DModel( - FromOriginalModelMixin, - FluxTransformer2DLoadersMixin, - CacheMixin, -- AttentionMixin, - ): - """ - The Transformer model introduced in Flux. 
-@@ -555,7 +277,6 @@ class FluxTransformer2DModel( - _supports_gradient_checkpointing = True - _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - _skip_layerwise_casting_patterns = ["pos_embed", "norm"] -- _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"] - - @register_to_config - def __init__( -@@ -570,7 +291,8 @@ class FluxTransformer2DModel( - joint_attention_dim: int = 4096, - pooled_projection_dim: int = 768, - guidance_embeds: bool = False, -- axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), -+ additional_timestep_embeds: bool = False, -+ axes_dims_rope: Tuple[int] = (16, 56, 56), - ): - super().__init__() - self.out_channels = out_channels or in_channels -@@ -579,12 +301,23 @@ class FluxTransformer2DModel( - self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) - - text_time_guidance_cls = ( -- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings -+ CombinedTimestepGuidanceTextProjEmbeddings -+ if guidance_embeds -+ else CombinedTimestepTextProjEmbeddings - ) -+ - self.time_text_embed = text_time_guidance_cls( - embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim - ) - -+ if additional_timestep_embeds: -+ self.additional_time_proj = Timesteps( -+ num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 -+ ) -+ self.additional_timestep_embedder = TimestepEmbedding( -+ in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False -+ ) -+ - self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) - self.x_embedder = nn.Linear(in_channels, self.inner_dim) - -@@ -610,11 +343,123 @@ class FluxTransformer2DModel( - ] - ) - -- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) -- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) -+ self.norm_out = AdaLayerNormContinuous( -+ self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 -+ ) -+ self.proj_out = nn.Linear( -+ self.inner_dim, patch_size * patch_size * self.out_channels, bias=True -+ ) - - self.gradient_checkpointing = False - -+ @property -+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors -+ def attn_processors(self) -> Dict[str, AttentionProcessor]: -+ r""" -+ Returns: -+ `dict` of attention processors: A dictionary containing all attention processors used in the model with -+ indexed by its weight name. -+ """ -+ # set recursively -+ processors = {} -+ -+ def fn_recursive_add_processors( -+ name: str, -+ module: torch.nn.Module, -+ processors: Dict[str, AttentionProcessor], -+ ): -+ if hasattr(module, "get_processor"): -+ processors[f"{name}.processor"] = module.get_processor() -+ -+ for sub_name, child in module.named_children(): -+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) -+ -+ return processors -+ -+ for name, module in self.named_children(): -+ fn_recursive_add_processors(name, module, processors) -+ -+ return processors -+ -+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor -+ def set_attn_processor( -+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] -+ ): -+ r""" -+ Sets the attention processor to use to compute attention. 
-+ -+ Parameters: -+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): -+ The instantiated processor class or a dictionary of processor classes that will be set as the processor -+ for **all** `Attention` layers. -+ -+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention -+ processor. This is strongly recommended when setting trainable attention processors. -+ -+ """ -+ count = len(self.attn_processors.keys()) -+ -+ if isinstance(processor, dict) and len(processor) != count: -+ raise ValueError( -+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" -+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes." -+ ) -+ -+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): -+ if hasattr(module, "set_processor"): -+ if not isinstance(processor, dict): -+ module.set_processor(processor) -+ else: -+ module.set_processor(processor.pop(f"{name}.processor")) -+ -+ for sub_name, child in module.named_children(): -+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) -+ -+ for name, module in self.named_children(): -+ fn_recursive_attn_processor(name, module, processor) -+ -+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0 -+ def fuse_qkv_projections(self): -+ """ -+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value) -+ are fused. For cross-attention modules, key and value projection matrices are fused. -+ -+ -+ -+ This API is 🧪 experimental. -+ -+ -+ """ -+ self.original_attn_processors = None -+ -+ for _, attn_processor in self.attn_processors.items(): -+ if "Added" in str(attn_processor.__class__.__name__): -+ raise ValueError( -+ "`fuse_qkv_projections()` is not supported for models having added KV projections." -+ ) -+ -+ self.original_attn_processors = self.attn_processors -+ -+ for module in self.modules(): -+ if isinstance(module, Attention): -+ module.fuse_projections(fuse=True) -+ -+ self.set_attn_processor(FusedFluxAttnProcessor2_0()) -+ -+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections -+ def unfuse_qkv_projections(self): -+ """Disables the fused QKV projection if enabled. -+ -+ -+ -+ This API is 🧪 experimental. -+ -+ -+ -+ """ -+ if self.original_attn_processors is not None: -+ self.set_attn_processor(self.original_attn_processors) -+ - def forward( - self, - hidden_states: torch.Tensor, -@@ -624,11 +469,13 @@ class FluxTransformer2DModel( - img_ids: torch.Tensor = None, - txt_ids: torch.Tensor = None, - guidance: torch.Tensor = None, -+ additional_timestep_embeddings: torch.Tensor = None, - joint_attention_kwargs: Optional[Dict[str, Any]] = None, - controlnet_block_samples=None, - controlnet_single_block_samples=None, - return_dict: bool = True, - controlnet_blocks_repeat: bool = False, -+ attention_mask: Optional[torch.Tensor] = None, - ) -> Union[torch.Tensor, Transformer2DModelOutput]: - """ - The [`FluxTransformer2DModel`] forward method. 
-@@ -666,7 +513,10 @@ class FluxTransformer2DModel( - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - else: -- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: -+ if ( -+ joint_attention_kwargs is not None -+ and joint_attention_kwargs.get("scale", None) is not None -+ ): - logger.warning( - "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." - ) -@@ -676,12 +526,24 @@ class FluxTransformer2DModel( - timestep = timestep.to(hidden_states.dtype) * 1000 - if guidance is not None: - guidance = guidance.to(hidden_states.dtype) * 1000 -+ else: -+ guidance = None - - temb = ( - self.time_text_embed(timestep, pooled_projections) - if guidance is None - else self.time_text_embed(timestep, guidance, pooled_projections) - ) -+ -+ if additional_timestep_embeddings is not None: -+ additional_timestep_embed_proj = self.additional_time_proj( -+ additional_timestep_embeddings -+ ) -+ additional_class_emb = self.additional_timestep_embedder( -+ additional_timestep_embed_proj.to(dtype=temb.dtype) -+ ) -+ temb = temb + additional_class_emb -+ - encoder_hidden_states = self.context_embedder(encoder_hidden_states) - - if txt_ids.ndim == 3: -@@ -700,20 +562,27 @@ class FluxTransformer2DModel( - ids = torch.cat((txt_ids, img_ids), dim=0) - image_rotary_emb = self.pos_embed(ids) - -- if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: -- ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") -+ if ( -+ joint_attention_kwargs is not None -+ and "ip_adapter_image_embeds" in joint_attention_kwargs -+ ): -+ ip_adapter_image_embeds = joint_attention_kwargs.pop( -+ "ip_adapter_image_embeds" -+ ) - ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) - joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) - - for index_block, block in enumerate(self.transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: -- encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( -- block, -- hidden_states, -- encoder_hidden_states, -- temb, -- image_rotary_emb, -- joint_attention_kwargs, -+ encoder_hidden_states, hidden_states = ( -+ self._gradient_checkpointing_func( -+ block, -+ hidden_states, -+ encoder_hidden_states, -+ temb, -+ image_rotary_emb, -+ attention_mask=attention_mask, -+ ) - ) - - else: -@@ -723,45 +592,61 @@ class FluxTransformer2DModel( - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, -+ attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_block_samples is not None: -- interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) -+ interval_control = len(self.transformer_blocks) / len( -+ controlnet_block_samples -+ ) - interval_control = int(np.ceil(interval_control)) - # For Xlabs ControlNet. 
- if controlnet_blocks_repeat: - hidden_states = ( -- hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] -+ hidden_states -+ + controlnet_block_samples[ -+ index_block % len(controlnet_block_samples) -+ ] - ) - else: -- hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] -+ hidden_states = ( -+ hidden_states -+ + controlnet_block_samples[index_block // interval_control] -+ ) -+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) - - for index_block, block in enumerate(self.single_transformer_blocks): - if torch.is_grad_enabled() and self.gradient_checkpointing: -- encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( -+ hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, -- encoder_hidden_states, - temb, - image_rotary_emb, -- joint_attention_kwargs, -+ attention_mask=attention_mask, - ) - - else: -- encoder_hidden_states, hidden_states = block( -+ hidden_states = block( - hidden_states=hidden_states, -- encoder_hidden_states=encoder_hidden_states, - temb=temb, - image_rotary_emb=image_rotary_emb, - joint_attention_kwargs=joint_attention_kwargs, -+ attention_mask=attention_mask, - ) - - # controlnet residual - if controlnet_single_block_samples is not None: -- interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) -+ interval_control = len(self.single_transformer_blocks) / len( -+ controlnet_single_block_samples -+ ) - interval_control = int(np.ceil(interval_control)) -- hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control] -+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( -+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] -+ + controlnet_single_block_samples[index_block // interval_control] -+ ) -+ -+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] - - hidden_states = self.norm_out(hidden_states, temb) - output = self.proj_out(hidden_states) From 9a71e7f4edf6f6272d755f252901a06d787655f4 Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 30 Sep 2025 21:52:27 +0000 Subject: [PATCH 5/6] add controlnet blocks --- .../models/transformers/sana_transformer.py | 158 ++++++++++++++---- 1 file changed, 124 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/transformers/sana_transformer.py b/src/diffusers/models/transformers/sana_transformer.py index 48b731406191..54e996e13d42 100644 --- a/src/diffusers/models/transformers/sana_transformer.py +++ b/src/diffusers/models/transformers/sana_transformer.py @@ -1,4 +1,4 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. +# Copyright 2025 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -26,12 +26,16 @@ AttentionProcessor, SanaLinearAttnProcessor2_0, ) -from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps +from ..embeddings import ( + PatchEmbed, + PixArtAlphaTextProjection, + TimestepEmbedding, + Timesteps, +) from ..modeling_outputs import Transformer2DModelOutput from ..modeling_utils import ModelMixin from ..normalization import AdaLayerNormSingle, RMSNorm - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -52,12 +56,21 @@ def __init__( self.nonlinearity = nn.SiLU() self.conv_inverted = nn.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0) - self.conv_depth = nn.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2) + self.conv_depth = nn.Conv2d( + hidden_channels * 2, + hidden_channels * 2, + 3, + 1, + 1, + groups=hidden_channels * 2, + ) self.conv_point = nn.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False) self.norm = None if norm_type == "rms_norm": - self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True) + self.norm = RMSNorm( + out_channels, eps=1e-5, elementwise_affine=True, bias=True + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.residual_connection: @@ -88,10 +101,15 @@ def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6 self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps) def forward( - self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor + self, + hidden_states: torch.Tensor, + temb: torch.Tensor, + scale_shift_table: torch.Tensor, ) -> torch.Tensor: hidden_states = self.norm(hidden_states) - shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1) + shift, scale = ( + scale_shift_table[None] + temb[:, None].to(scale_shift_table.device) + ).chunk(2, dim=1) hidden_states = hidden_states * (1 + scale) + shift return hidden_states @@ -99,18 +117,33 @@ def forward( class SanaCombinedTimestepGuidanceEmbeddings(nn.Module): def __init__(self, embedding_dim): super().__init__() - self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.time_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.timestep_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) - self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) - self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) + self.guidance_condition_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.guidance_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=embedding_dim + ) self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) - def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None): + def forward( + self, + timestep: torch.Tensor, + guidance: torch.Tensor = None, + hidden_dtype: torch.dtype = None, + ): timesteps_proj = self.time_proj(timestep) - timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + timesteps_emb = self.timestep_embedder( + timesteps_proj.to(dtype=hidden_dtype) + ) # (N, D) guidance_proj = self.guidance_condition_proj(guidance) guidance_emb = 
self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype)) @@ -126,7 +159,9 @@ class SanaAttnProcessor2_0: def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): - raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + raise ImportError( + "SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) def __call__( self, @@ -136,14 +171,20 @@ def __call__( attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape ) if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) query = attn.to_q(hidden_states) @@ -172,7 +213,9 @@ def __call__( query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) hidden_states = hidden_states.to(query.dtype) # linear proj @@ -224,7 +267,9 @@ def __init__( # 2. Cross Attention if cross_attention_dim is not None: - self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + self.norm2 = nn.LayerNorm( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) self.attn2 = Attention( query_dim=dim, qk_norm=qk_norm, @@ -239,7 +284,9 @@ def __init__( ) # 3. Feed-forward - self.ff = GLUMBConv(dim, dim, mlp_ratio, norm_type=None, residual_connection=False) + self.ff = GLUMBConv( + dim, dim, mlp_ratio, norm_type=None, residual_connection=False + ) self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) @@ -281,7 +328,9 @@ def forward( norm_hidden_states = self.norm2(hidden_states) norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp - norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute(0, 3, 1, 2) + norm_hidden_states = norm_hidden_states.unflatten(1, (height, width)).permute( + 0, 3, 1, 2 + ) ff_output = self.ff(norm_hidden_states) ff_output = ff_output.flatten(2, 3).permute(0, 2, 1) hidden_states = hidden_states + gate_mlp * ff_output @@ -289,7 +338,9 @@ def forward( return hidden_states -class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin): +class SanaTransformer2DModel( + ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin +): r""" A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models. @@ -383,7 +434,9 @@ def __init__( else: self.time_embed = AdaLayerNormSingle(inner_dim) - self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True) # 3. 
Transformer blocks @@ -408,7 +461,9 @@ def __init__( ) # 4. Output blocks - self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6) self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels) @@ -425,7 +480,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() @@ -440,7 +499,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): r""" Sets the attention processor to use to compute attention. @@ -483,6 +544,7 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples: Optional[Tuple[torch.Tensor]] = None, return_dict: bool = True, ) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]: if attention_kwargs is not None: @@ -495,7 +557,10 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + if ( + attention_kwargs is not None + and attention_kwargs.get("scale", None) is not None + ): logger.warning( "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." ) @@ -520,7 +585,9 @@ def forward( # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: - encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 1. Input @@ -540,13 +607,15 @@ def forward( ) encoder_hidden_states = self.caption_projection(encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) encoder_hidden_states = self.caption_norm(encoder_hidden_states) # 2. 
Transformer blocks if torch.is_grad_enabled() and self.gradient_checkpointing: - for block in self.transformer_blocks: + for index_block, block in enumerate(self.transformer_blocks): hidden_states = self._gradient_checkpointing_func( block, hidden_states, @@ -557,9 +626,15 @@ def forward( post_patch_height, post_patch_width, ) + if controlnet_block_samples is not None and 0 < index_block <= len( + controlnet_block_samples + ): + hidden_states = ( + hidden_states + controlnet_block_samples[index_block - 1] + ) else: - for block in self.transformer_blocks: + for index_block, block in enumerate(self.transformer_blocks): hidden_states = block( hidden_states, attention_mask, @@ -569,18 +644,33 @@ def forward( post_patch_height, post_patch_width, ) + if controlnet_block_samples is not None and 0 < index_block <= len( + controlnet_block_samples + ): + hidden_states = ( + hidden_states + controlnet_block_samples[index_block - 1] + ) # 3. Normalization - hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table) + hidden_states = self.norm_out( + hidden_states, embedded_timestep, self.scale_shift_table + ) hidden_states = self.proj_out(hidden_states) # 5. Unpatchify hidden_states = hidden_states.reshape( - batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1 + batch_size, + post_patch_height, + post_patch_width, + self.config.patch_size, + self.config.patch_size, + -1, ) hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4) - output = hidden_states.reshape(batch_size, -1, post_patch_height * p, post_patch_width * p) + output = hidden_states.reshape( + batch_size, -1, post_patch_height * p, post_patch_width * p + ) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer From d6816ffc9257b1489886b385be34aee2ce5240bf Mon Sep 17 00:00:00 2001 From: onurxtasar Date: Tue, 14 Oct 2025 11:29:19 +0000 Subject: [PATCH 6/6] move the sana controlnet import next to other conrolnet imports --- src/diffusers/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index f300dda7fd32..719325de13ef 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -24,7 +24,6 @@ _import_structure = {} if is_torch_available(): - _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"] _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -51,6 +50,7 @@ _import_structure["autoencoders.vq_model"] = ["VQModel"] _import_structure["cache_utils"] = ["CacheMixin"] _import_structure["controlnets.controlnet"] = ["ControlNetModel"] + _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"] _import_structure["controlnets.controlnet_flux"] = [ "FluxControlNetModel", "FluxMultiControlNetModel",
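
Note on usage (not part of the patches above): the sketch below shows how the new Sana ControlNet support could be exercised end to end. The pipeline class name `SanaControlNetPipeline`, the checkpoint identifiers, and the conditioning-image path are placeholders assumed for illustration only; what this series itself introduces is `SanaControlNetModel`, the `controlnet_block_samples` hook in `SanaTransformer2DModel`, and the pipeline `__call__` arguments documented in its docstring (`control_image`, `controlnet_conditioning_scale`, `guidance_scale`, ...).

import torch

from diffusers import SanaControlNetModel
from diffusers import SanaControlNetPipeline  # assumed pipeline class name, not shown in this excerpt
from diffusers.utils import load_image

# Placeholder checkpoint ids -- replace with the repos that accompany the final PR.
controlnet = SanaControlNetModel.from_pretrained(
    "<sana-controlnet-repo>", torch_dtype=torch.bfloat16
)
pipe = SanaControlNetPipeline.from_pretrained(
    "<sana-base-repo>", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

# The conditioning image is resized and VAE-encoded inside the pipeline
# (step 4 of `__call__`: `prepare_image` followed by `self.vae.encode`).
control_image = load_image("<conditioning-image.png>")

image = pipe(
    prompt="a photo of an astronaut riding a horse on mars",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,  # forwarded to the controlnet as `conditioning_scale`
    guidance_scale=4.5,                 # classifier-free guidance weight (docstring default)
    num_inference_steps=20,
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]
image.save("sana_controlnet_output.png")

Internally, the controlnet re-uses `SanaTransformerBlock` layers and emits one zero-initialized linear projection per block; PATCH 5/6 adds these `controlnet_block_samples` to the hidden states after the corresponding `SanaTransformer2DModel` blocks (the `0 < index_block <= len(controlnet_block_samples)` branch), so the conditioning starts out as a no-op and only takes effect as the controlnet is trained.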