From e91f81e102f8ee8fadc81ca2efcc8255bbb19e15 Mon Sep 17 00:00:00 2001 From: clementchadebec Date: Mon, 30 Jun 2025 09:46:50 -0700 Subject: [PATCH 1/4] Adds FLUX attention masking and additional time embedder --- .../models/transformers/transformer_flux.py | 176 ++++++++++++++---- .../models/unets/unet_2d_condition.py | 23 +++ 2 files changed, 165 insertions(+), 34 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 87537890d246..a36f4b5088f7 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -20,7 +20,11 @@ import torch.nn as nn from ...configuration_utils import ConfigMixin, register_to_config -from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin +from ...loaders import ( + FluxTransformer2DLoadersMixin, + FromOriginalModelMixin, + PeftAdapterMixin, +) from ...models.attention import FeedForward from ...models.attention_processor import ( Attention, @@ -30,21 +34,42 @@ FusedFluxAttnProcessor2_0, ) from ...models.modeling_utils import ModelMixin -from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle -from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...models.normalization import ( + AdaLayerNormContinuous, + AdaLayerNormZero, + AdaLayerNormZeroSingle, +) +from ...utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ...utils.import_utils import is_torch_npu_available from ...utils.torch_utils import maybe_allow_in_graph from ..cache_utils import CacheMixin -from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed +from ..embeddings import ( + CombinedTimestepGuidanceTextProjEmbeddings, + CombinedTimestepTextProjEmbeddings, + FluxPosEmbed, + TimestepEmbedding, + Timesteps, +) from ..modeling_outputs import Transformer2DModelOutput - logger = logging.get_logger(__name__) # pylint: disable=invalid-name @maybe_allow_in_graph class FluxSingleTransformerBlock(nn.Module): - def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + mlp_ratio: float = 4.0, + ): super().__init__() self.mlp_hidden_dim = int(dim * mlp_ratio) @@ -82,6 +107,7 @@ def forward( temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: residual = hidden_states norm_hidden_states, gate = self.norm(hidden_states, emb=temb) @@ -90,6 +116,7 @@ def forward( attn_output = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, **joint_attention_kwargs, ) @@ -106,7 +133,12 @@ def forward( @maybe_allow_in_graph class FluxTransformerBlock(nn.Module): def __init__( - self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6 + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + qk_norm: str = "rms_norm", + eps: float = 1e-6, ): super().__init__() @@ -131,7 +163,9 @@ def __init__( self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") self.norm2_context = 
nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) - self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") + self.ff_context = FeedForward( + dim=dim, dim_out=dim, activation_fn="gelu-approximate" + ) def forward( self, @@ -140,11 +174,14 @@ def forward( temb: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, + attention_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb) + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, emb=temb + ) - norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context( - encoder_hidden_states, emb=temb + norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( + self.norm1_context(encoder_hidden_states, emb=temb) ) joint_attention_kwargs = joint_attention_kwargs or {} # Attention. @@ -152,6 +189,7 @@ def forward( hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, image_rotary_emb=image_rotary_emb, + attention_mask=attention_mask, **joint_attention_kwargs, ) @@ -165,7 +203,9 @@ def forward( hidden_states = hidden_states + attn_output norm_hidden_states = self.norm2(hidden_states) - norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + norm_hidden_states = ( + norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + ) ff_output = self.ff(norm_hidden_states) ff_output = gate_mlp.unsqueeze(1) * ff_output @@ -180,10 +220,15 @@ def forward( encoder_hidden_states = encoder_hidden_states + context_attn_output norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) - norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None] + norm_encoder_hidden_states = ( + norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + + c_shift_mlp[:, None] + ) context_ff_output = self.ff_context(norm_encoder_hidden_states) - encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + encoder_hidden_states = ( + encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output + ) if encoder_hidden_states.dtype == torch.float16: encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504) @@ -191,7 +236,12 @@ def forward( class FluxTransformer2DModel( - ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin + ModelMixin, + ConfigMixin, + PeftAdapterMixin, + FromOriginalModelMixin, + FluxTransformer2DLoadersMixin, + CacheMixin, ): """ The Transformer model introduced in Flux. 
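Illustration (not part of the patch): the hunks above thread a new `attention_mask` argument from both transformer block types into `self.attn(...)`. A caller could build such a mask as a per-token keep-mask over the joint padded-text + image sequence, roughly as in the sketch below; the exact shape and the boolean-versus-additive convention expected by the attention processors are assumptions here, since the patch itself does not pin them down.

    import torch

    # Hypothetical sketch of a joint (text + image) keep-mask for the new
    # `attention_mask` argument; shapes and the True-means-attend convention
    # are illustrative assumptions only.
    batch_size, text_len, image_len = 2, 512, 4096
    prompt_lengths = torch.tensor([77, 256])                               # unpadded prompt lengths
    text_mask = torch.arange(text_len)[None, :] < prompt_lengths[:, None]  # (B, text_len)
    image_mask = torch.ones(batch_size, image_len, dtype=torch.bool)       # image tokens are never padded
    attention_mask = torch.cat([text_mask, image_mask], dim=1)             # (B, text_len + image_len)

Downstream, the forward passes shown in this patch forward the tensor unchanged into every FluxTransformerBlock and FluxSingleTransformerBlock, which in turn hand it to their attention modules.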
@@ -241,6 +291,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, + discretization_embeds: bool = False, axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() @@ -250,12 +301,23 @@ def __init__( self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( - CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings + CombinedTimestepGuidanceTextProjEmbeddings + if guidance_embeds + else CombinedTimestepTextProjEmbeddings ) + self.time_text_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) + if discretization_embeds: + self.disc_proj = Timesteps( + num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + self.disc_embedder = TimestepEmbedding( + in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False + ) + self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = nn.Linear(in_channels, self.inner_dim) @@ -281,8 +343,12 @@ def __init__( ] ) - self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6) - self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) + self.norm_out = AdaLayerNormContinuous( + self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6 + ) + self.proj_out = nn.Linear( + self.inner_dim, patch_size * patch_size * self.out_channels, bias=True + ) self.gradient_checkpointing = False @@ -297,7 +363,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() @@ -312,7 +382,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): r""" Sets the attention processor to use to compute attention. @@ -362,7 +434,9 @@ def fuse_qkv_projections(self): for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." + ) self.original_attn_processors = self.attn_processors @@ -395,11 +469,13 @@ def forward( img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, + discretization: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, return_dict: bool = True, controlnet_blocks_repeat: bool = False, + attention_mask: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. 
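Illustration (not part of the patch): the new `disc_proj` / `disc_embedder` pair registered in `__init__` above (renamed to `additional_time_proj` / `additional_timestep_embedder` in PATCH 2/4) maps an extra scalar conditioning value through a 256-channel sinusoidal projection and a small two-layer embedder to the transformer's inner dimension, and the forward pass then adds the result onto `temb`. A minimal sketch of that computation with made-up values, assuming the stock 24-head x 128-dim configuration (inner_dim = 3072), which the diff does not show:

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    inner_dim = 3072                             # assumption: 24 attention heads * 128 head_dim
    proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
    # The patch additionally passes sample_proj_bias=False; omitted here so the
    # sketch also runs on diffusers versions without that keyword.
    embedder = TimestepEmbedding(in_channels=256, time_embed_dim=inner_dim)

    extra = torch.tensor([4.0, 12.0])            # one extra scalar per batch element (made up)
    extra_emb = embedder(proj(extra).to(torch.float32))   # shape (2, 3072)
    # In the model this tensor is added to the time/text embedding: temb = temb + extra_emb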
@@ -437,7 +513,10 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: - if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + if ( + joint_attention_kwargs is not None + and joint_attention_kwargs.get("scale", None) is not None + ): logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) @@ -455,6 +534,14 @@ def forward( if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) + + if discretization is not None: + discretization = self.disc_proj( + discretization + ) # * (discretization > 0).unsqueeze(1) + discretization = self.disc_embedder(discretization.to(dtype=temb.dtype)) + temb = temb + discretization + encoder_hidden_states = self.context_embedder(encoder_hidden_states) if txt_ids.ndim == 3: @@ -473,19 +560,27 @@ def forward( ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) - if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: - ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + if ( + joint_attention_kwargs is not None + and "ip_adapter_image_embeds" in joint_attention_kwargs + ): + ip_adapter_image_embeds = joint_attention_kwargs.pop( + "ip_adapter_image_embeds" + ) ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: - encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( - block, - hidden_states, - encoder_hidden_states, - temb, - image_rotary_emb, + encoder_hidden_states, hidden_states = ( + self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + attention_mask=attention_mask, + ) ) else: @@ -495,19 +590,28 @@ def forward( temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + attention_mask=attention_mask, ) # controlnet residual if controlnet_block_samples is not None: - interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = len(self.transformer_blocks) / len( + controlnet_block_samples + ) interval_control = int(np.ceil(interval_control)) # For Xlabs ControlNet. 
if controlnet_blocks_repeat: hidden_states = ( - hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + hidden_states + + controlnet_block_samples[ + index_block % len(controlnet_block_samples) + ] ) else: - hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = ( + hidden_states + + controlnet_block_samples[index_block // interval_control] + ) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) for index_block, block in enumerate(self.single_transformer_blocks): @@ -517,6 +621,7 @@ def forward( hidden_states, temb, image_rotary_emb, + attention_mask=attention_mask, ) else: @@ -525,11 +630,14 @@ def forward( temb=temb, image_rotary_emb=image_rotary_emb, joint_attention_kwargs=joint_attention_kwargs, + attention_mask=attention_mask, ) # controlnet residual if controlnet_single_block_samples is not None: - interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = len(self.single_transformer_blocks) / len( + controlnet_single_block_samples + ) interval_control = int(np.ceil(interval_control)) hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( hidden_states[:, encoder_hidden_states.shape[1] :, ...] diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 2fd15f6f91e0..5a16a4374479 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -235,6 +235,7 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, + additional_class_embeddings: bool = False, ): super().__init__() @@ -293,6 +294,17 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) + if additional_class_embeddings: + self.additional_class_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=None, + ) + else: + self.additional_class_embedding = None + self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, @@ -1170,6 +1182,7 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + additional_class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1296,6 +1309,16 @@ def forward( else: emb = emb + class_emb + if additional_class_labels is not None: + additional_class_emb_proj = self.time_proj(additional_class_labels).to( + dtype=emb.dtype + ) + additional_class_emb = self.additional_class_embedding( + additional_class_emb_proj + ) + emb = emb + additional_class_emb + print("additional_class_emb", emb.shape) + aug_emb = self.get_aug_embed( emb=emb, encoder_hidden_states=encoder_hidden_states, From 1e174f1f16fa51f3835f11088242a16f5e667aaa Mon Sep 17 00:00:00 2001 From: clementchadebec Date: Tue, 1 Jul 2025 09:27:48 -0700 Subject: [PATCH 2/4] Add requested changes --- .../models/transformers/transformer_flux.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index a36f4b5088f7..b64920c374f4 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ 
b/src/diffusers/models/transformers/transformer_flux.py @@ -291,7 +291,7 @@ def __init__( joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, - discretization_embeds: bool = False, + additional_timestep_embeds: bool = False, axes_dims_rope: Tuple[int] = (16, 56, 56), ): super().__init__() @@ -310,11 +310,11 @@ def __init__( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) - if discretization_embeds: - self.disc_proj = Timesteps( + if additional_timestep_embeds: + self.additional_time_proj = Timesteps( num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0 ) - self.disc_embedder = TimestepEmbedding( + self.additional_timestep_embedder = TimestepEmbedding( in_channels=256, time_embed_dim=self.inner_dim, sample_proj_bias=False ) @@ -469,7 +469,7 @@ def forward( img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, - discretization: torch.Tensor = None, + additional_timestep_embeddings: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, controlnet_block_samples=None, controlnet_single_block_samples=None, @@ -535,12 +535,14 @@ def forward( else self.time_text_embed(timestep, guidance, pooled_projections) ) - if discretization is not None: - discretization = self.disc_proj( - discretization - ) # * (discretization > 0).unsqueeze(1) - discretization = self.disc_embedder(discretization.to(dtype=temb.dtype)) - temb = temb + discretization + if additional_timestep_embeddings is not None: + additional_timestep_embed_proj = self.additional_time_proj( + additional_timestep_embeddings + ) + additional_class_emb = self.additional_timestep_embedder( + additional_timestep_embed_proj.to(dtype=temb.dtype) + ) + temb = temb + additional_class_emb encoder_hidden_states = self.context_embedder(encoder_hidden_states) From 693b52bfe4e7876a73db5311d1d5416e8baa2cf0 Mon Sep 17 00:00:00 2001 From: clementchadebec Date: Tue, 1 Jul 2025 09:33:10 -0700 Subject: [PATCH 3/4] Revert unet 2d condition --- .../models/unets/unet_2d_condition.py | 358 ++++-------------- 1 file changed, 71 insertions(+), 287 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5a16a4374479..5674d8ba26ec 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -21,14 +21,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import ( - USE_PEFT_BACKEND, - BaseOutput, - deprecate, - logging, - scale_lora_layers, - unscale_lora_layers, -) +from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers from ..activations import get_activation from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -52,7 +45,12 @@ Timesteps, ) from ..modeling_utils import ModelMixin -from .unet_2d_blocks import get_down_block, get_mid_block, get_up_block +from .unet_2d_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -71,11 +69,7 @@ class UNet2DConditionOutput(BaseOutput): class UNet2DConditionModel( - ModelMixin, - ConfigMixin, - FromOriginalModelMixin, - UNet2DConditionLoadersMixin, - PeftAdapterMixin, + ModelMixin, ConfigMixin, FromOriginalModelMixin, 
UNet2DConditionLoadersMixin, PeftAdapterMixin ): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample @@ -190,12 +184,7 @@ def __init__( "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - "CrossAttnUpBlock2D", - ), + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, @@ -235,7 +224,6 @@ def __init__( mid_block_only_cross_attention: Optional[bool] = None, cross_attention_norm: Optional[str] = None, addition_embed_type_num_heads: int = 64, - additional_class_embeddings: bool = False, ): super().__init__() @@ -271,10 +259,7 @@ def __init__( # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( - in_channels, - block_out_channels[0], - kernel_size=conv_in_kernel, - padding=conv_in_padding, + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding ) # time @@ -294,17 +279,6 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - if additional_class_embeddings: - self.additional_class_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=None, - ) - else: - self.additional_class_embedding = None - self._set_encoder_hid_proj( encoder_hid_dim_type, cross_attention_dim=cross_attention_dim, @@ -363,9 +337,7 @@ def __init__( layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len( - down_block_types - ) + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. 
The dimension of the @@ -405,11 +377,7 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, - attention_head_dim=( - attention_head_dim[i] - if attention_head_dim[i] is not None - else output_channel - ), + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.down_blocks.append(down_block) @@ -459,9 +427,7 @@ def __init__( prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[ - min(i + 1, len(block_out_channels) - 1) - ] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] # add upsample block for all BUT final layer if not is_final_block: @@ -494,11 +460,7 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, - attention_head_dim=( - attention_head_dim[i] - if attention_head_dim[i] is not None - else output_channel - ), + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, dropout=dropout, ) self.up_blocks.append(up_block) @@ -506,9 +468,7 @@ def __init__( # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], - num_groups=norm_num_groups, - eps=norm_eps, + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps ) self.conv_act = get_activation(act_fn) @@ -519,15 +479,10 @@ def __init__( conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( - block_out_channels[0], - out_channels, - kernel_size=conv_out_kernel, - padding=conv_out_padding, + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding ) - self._set_pos_net_if_use_gligen( - attention_type=attention_type, cross_attention_dim=cross_attention_dim - ) + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) def _check_config( self, @@ -552,49 +507,34 @@ def _check_config( f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - if not isinstance(only_cross_attention, bool) and len( - only_cross_attention - ) != len(down_block_types): + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( - down_block_types - ): + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( - down_block_types - ): + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
) - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( - down_block_types - ): + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) - if not isinstance(layers_per_block, int) and len(layers_per_block) != len( - down_block_types - ): + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - if ( - isinstance(transformer_layers_per_block, list) - and reverse_transformer_layers_per_block is None - ): + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): - raise ValueError( - "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." - ) + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") def _set_time_proj( self, @@ -607,22 +547,15 @@ def _set_time_proj( if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: - raise ValueError( - f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." - ) + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, - set_W_to_weight=False, - log=False, - flip_sin_to_cos=flip_sin_to_cos, + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_proj = Timesteps( - block_out_channels[0], flip_sin_to_cos, freq_shift - ) + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] else: raise ValueError( @@ -640,9 +573,7 @@ def _set_encoder_hid_proj( if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info( - "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." - ) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( @@ -685,9 +616,7 @@ def _set_class_embedding( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding( - timestep_input_dim, time_embed_dim, act_fn=act_fn - ) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -702,17 +631,13 @@ def _set_class_embedding( # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding( - projection_class_embeddings_input_dim, time_embed_dim - ) + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) - self.class_embedding = nn.Linear( - projection_class_embeddings_input_dim, time_embed_dim - ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None @@ -735,36 +660,24 @@ def _set_add_embedding( text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, - time_embed_dim, - num_heads=addition_embed_type_num_heads, + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, - image_embed_dim=cross_attention_dim, - time_embed_dim=time_embed_dim, + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim ) elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps( - addition_time_embed_dim, flip_sin_to_cos, freq_shift - ) - self.add_embedding = TimestepEmbedding( - projection_class_embeddings_input_dim, time_embed_dim - ) + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) elif addition_embed_type == "image": # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding( - image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim - ) + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding( - image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim - ) + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) elif addition_embed_type is not None: raise ValueError( f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." 
@@ -780,9 +693,7 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, - out_dim=cross_attention_dim, - feature_type=feature_type, + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type ) @property @@ -795,11 +706,7 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors( - name: str, - module: torch.nn.Module, - processors: Dict[str, AttentionProcessor], - ): + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() @@ -813,9 +720,7 @@ def fn_recursive_add_processors( return processors - def set_attn_processor( - self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] - ): + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): r""" Sets the attention processor to use to compute attention. @@ -853,15 +758,9 @@ def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. """ - if all( - proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS - for proc in self.attn_processors.values() - ): + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnAddedKVProcessor() - elif all( - proc.__class__ in CROSS_ATTENTION_PROCESSORS - for proc in self.attn_processors.values() - ): + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): processor = AttnProcessor() else: raise ValueError( @@ -907,11 +806,7 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): # make smallest slice possible slice_size = num_sliceable_layers * [1] - slice_size = ( - num_sliceable_layers * [slice_size] - if not isinstance(slice_size, list) - else slice_size - ) + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( @@ -928,9 +823,7 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message - def fn_recursive_set_attention_slice( - module: torch.nn.Module, slice_size: List[int] - ): + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) @@ -970,10 +863,7 @@ def disable_freeu(self): freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: - if ( - hasattr(upsample_block, k) - or getattr(upsample_block, k, None) is not None - ): + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: setattr(upsample_block, k, None) def fuse_qkv_projections(self): @@ -991,9 +881,7 @@ def fuse_qkv_projections(self): for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): - raise ValueError( - "`fuse_qkv_projections()` is not supported for models having added KV projections." 
- ) + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") self.original_attn_processors = self.attn_processors @@ -1043,15 +931,11 @@ def get_time_embed( t_emb = t_emb.to(dtype=sample.dtype) return t_emb - def get_class_embed( - self, sample: torch.Tensor, class_labels: Optional[torch.Tensor] - ) -> Optional[torch.Tensor]: + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: class_emb = None if self.class_embedding is not None: if class_labels is None: - raise ValueError( - "class_labels should be provided when num_class_embeds > 0" - ) + raise ValueError("class_labels should be provided when num_class_embeds > 0") if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) @@ -1064,10 +948,7 @@ def get_class_embed( return class_emb def get_aug_embed( - self, - emb: torch.Tensor, - encoder_hidden_states: torch.Tensor, - added_cond_kwargs: Dict[str, Any], + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] ) -> Optional[torch.Tensor]: aug_emb = None if self.config.addition_embed_type == "text": @@ -1109,10 +990,7 @@ def get_aug_embed( aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet - style - if ( - "image_embeds" not in added_cond_kwargs - or "hint" not in added_cond_kwargs - ): + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) @@ -1124,15 +1002,9 @@ def get_aug_embed( def process_encoder_hidden_states( self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] ) -> torch.Tensor: - if ( - self.encoder_hid_proj is not None - and self.config.encoder_hid_dim_type == "text_proj" - ): + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif ( - self.encoder_hid_proj is not None - and self.config.encoder_hid_dim_type == "text_image_proj" - ): + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( @@ -1140,13 +1012,8 @@ def process_encoder_hidden_states( ) image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj( - encoder_hidden_states, image_embeds - ) - elif ( - self.encoder_hid_proj is not None - and self.config.encoder_hid_dim_type == "image_proj" - ): + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( @@ -1154,22 +1021,14 @@ def process_encoder_hidden_states( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif ( - self.encoder_hid_proj is not None - and self.config.encoder_hid_dim_type == "ip_image_proj" - ): + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` 
set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" ) - if ( - hasattr(self, "text_encoder_hid_proj") - and self.text_encoder_hid_proj is not None - ): - encoder_hidden_states = self.text_encoder_hid_proj( - encoder_hidden_states - ) + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) @@ -1182,7 +1041,6 @@ def forward( timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, - additional_class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -1192,8 +1050,6 @@ def forward( down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, - return_post_down_blocks: bool = False, - return_post_mid_blocks: bool = False, ) -> Union[UNet2DConditionOutput, Tuple]: r""" The [`UNet2DConditionModel`] forward method. @@ -1233,22 +1089,6 @@ def forward( return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks for - example from ControlNet side model(s) - mid_block_additional_residual (`torch.Tensor`, *optional*): - additional residual to be added to UNet mid block output, for example from ControlNet side model - down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) - return_post_down_blocks (`bool`, *optional*, defaults to `False`): - Whether or not to return the post down blocks. - return_post_mid_blocks (`bool`, *optional*, defaults to `False`): - Whether or not to return the post mid blocks. Returns: [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: @@ -1289,9 +1129,7 @@ def forward( # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None: - encoder_attention_mask = ( - 1 - encoder_attention_mask.to(sample.dtype) - ) * -10000.0 + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 encoder_attention_mask = encoder_attention_mask.unsqueeze(1) # 0. 
center input if necessary @@ -1309,20 +1147,8 @@ def forward( else: emb = emb + class_emb - if additional_class_labels is not None: - additional_class_emb_proj = self.time_proj(additional_class_labels).to( - dtype=emb.dtype - ) - additional_class_emb = self.additional_class_embedding( - additional_class_emb_proj - ) - emb = emb + additional_class_emb - print("additional_class_emb", emb.shape) - aug_emb = self.get_aug_embed( - emb=emb, - encoder_hidden_states=encoder_hidden_states, - added_cond_kwargs=added_cond_kwargs, + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) if self.config.addition_embed_type == "image_hint": aug_emb, hint = aug_emb @@ -1334,23 +1160,17 @@ def forward( emb = self.time_embed_act(emb) encoder_hidden_states = self.process_encoder_hidden_states( - encoder_hidden_states=encoder_hidden_states, - added_cond_kwargs=added_cond_kwargs, + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs ) # 2. pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net - if ( - cross_attention_kwargs is not None - and cross_attention_kwargs.get("gligen", None) is not None - ): + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = { - "objs": self.position_net(**gligen_args) - } + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} # 3. down # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated @@ -1365,20 +1185,13 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) - is_controlnet = ( - mid_block_additional_residual is not None - and down_block_additional_residuals is not None - ) + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other - if ( - not is_adapter - and mid_block_additional_residual is None - and down_block_additional_residuals is not None - ): + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", @@ -1392,16 +1205,11 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: - if ( - hasattr(downsample_block, "has_cross_attention") - and downsample_block.has_cross_attention - ): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = ( - down_intrablock_additional_residuals.pop(0) - ) + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) sample, res_samples = downsample_block( hidden_states=sample, @@ -1425,27 +1233,14 @@ def forward( for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): - down_block_res_sample = ( 
- down_block_res_sample + down_block_additional_residual - ) - new_down_block_res_samples = new_down_block_res_samples + ( - down_block_res_sample, - ) + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) down_block_res_samples = new_down_block_res_samples - if return_post_down_blocks: - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - # 4. mid if self.mid_block is not None: - if ( - hasattr(self.mid_block, "has_cross_attention") - and self.mid_block.has_cross_attention - ): + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: sample = self.mid_block( sample, emb, @@ -1468,30 +1263,19 @@ def forward( if is_controlnet: sample = sample + mid_block_additional_residual - if return_post_mid_blocks: - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) - # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[ - : -len(upsample_block.resnets) - ] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] - if ( - hasattr(upsample_block, "has_cross_attention") - and upsample_block.has_cross_attention - ): + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: sample = upsample_block( hidden_states=sample, temb=emb, From cec91d7e3482a936be4f39f76fcbb3874263f5d7 Mon Sep 17 00:00:00 2001 From: clementchadebec Date: Tue, 1 Jul 2025 09:34:18 -0700 Subject: [PATCH 4/4] revert unet 2d condition --- .../models/unets/unet_2d_condition.py | 335 ++++++++++++++---- 1 file changed, 264 insertions(+), 71 deletions(-) diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 5674d8ba26ec..2fd15f6f91e0 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -21,7 +21,14 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin from ...loaders.single_file_model import FromOriginalModelMixin -from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from ...utils import ( + USE_PEFT_BACKEND, + BaseOutput, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) from ..activations import get_activation from ..attention_processor import ( ADDED_KV_ATTENTION_PROCESSORS, @@ -45,12 +52,7 @@ Timesteps, ) from ..modeling_utils import ModelMixin -from .unet_2d_blocks import ( - get_down_block, - get_mid_block, - get_up_block, -) - +from .unet_2d_blocks import get_down_block, get_mid_block, get_up_block logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -69,7 +71,11 @@ class UNet2DConditionOutput(BaseOutput): class UNet2DConditionModel( - ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin + ModelMixin, + ConfigMixin, + FromOriginalModelMixin, + UNet2DConditionLoadersMixin, + PeftAdapterMixin, ): r""" A conditional 2D UNet model that takes a noisy sample, 
conditional state, and a timestep and returns a sample @@ -184,7 +190,12 @@ def __init__( "DownBlock2D", ), mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + up_block_types: Tuple[str] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), layers_per_block: Union[int, Tuple[int]] = 2, @@ -259,7 +270,10 @@ def __init__( # input conv_in_padding = (conv_in_kernel - 1) // 2 self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + in_channels, + block_out_channels[0], + kernel_size=conv_in_kernel, + padding=conv_in_padding, ) # time @@ -337,7 +351,9 @@ def __init__( layers_per_block = [layers_per_block] * len(down_block_types) if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + transformer_layers_per_block = [transformer_layers_per_block] * len( + down_block_types + ) if class_embeddings_concat: # The time embeddings are concatenated with the class embeddings. The dimension of the @@ -377,7 +393,11 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + attention_head_dim=( + attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel + ), dropout=dropout, ) self.down_blocks.append(down_block) @@ -427,7 +447,9 @@ def __init__( prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + input_channel = reversed_block_out_channels[ + min(i + 1, len(block_out_channels) - 1) + ] # add upsample block for all BUT final layer if not is_final_block: @@ -460,7 +482,11 @@ def __init__( resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + attention_head_dim=( + attention_head_dim[i] + if attention_head_dim[i] is not None + else output_channel + ), dropout=dropout, ) self.up_blocks.append(up_block) @@ -468,7 +494,9 @@ def __init__( # out if norm_num_groups is not None: self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + num_channels=block_out_channels[0], + num_groups=norm_num_groups, + eps=norm_eps, ) self.conv_act = get_activation(act_fn) @@ -479,10 +507,15 @@ def __init__( conv_out_padding = (conv_out_kernel - 1) // 2 self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + block_out_channels[0], + out_channels, + kernel_size=conv_out_kernel, + padding=conv_out_padding, ) - self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + self._set_pos_net_if_use_gligen( + attention_type=attention_type, cross_attention_dim=cross_attention_dim + ) def _check_config( self, @@ -507,34 +540,49 @@ def _check_config( f"Must provide the same number of `block_out_channels` as `down_block_types`. 
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." ) - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + if not isinstance(only_cross_attention, bool) and len( + only_cross_attention + ) != len(down_block_types): raise ValueError( f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." ) - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( + down_block_types + ): raise ValueError( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( + down_block_types + ): raise ValueError( f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." ) - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( + down_block_types + ): raise ValueError( f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." ) - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + if not isinstance(layers_per_block, int) and len(layers_per_block) != len( + down_block_types + ): raise ValueError( f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + if ( + isinstance(transformer_layers_per_block, list) + and reverse_transformer_layers_per_block is None + ): for layer_number_per_block in transformer_layers_per_block: if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + raise ValueError( + "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." + ) def _set_time_proj( self, @@ -547,15 +595,22 @@ def _set_time_proj( if time_embedding_type == "fourier": time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + raise ValueError( + f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." 
+ ) self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + time_embed_dim // 2, + set_W_to_weight=False, + log=False, + flip_sin_to_cos=flip_sin_to_cos, ) timestep_input_dim = time_embed_dim elif time_embedding_type == "positional": time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos, freq_shift + ) timestep_input_dim = block_out_channels[0] else: raise ValueError( @@ -573,7 +628,9 @@ def _set_encoder_hid_proj( if encoder_hid_dim_type is None and encoder_hid_dim is not None: encoder_hid_dim_type = "text_proj" self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + logger.info( + "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." + ) if encoder_hid_dim is None and encoder_hid_dim_type is not None: raise ValueError( @@ -616,7 +673,9 @@ def _set_class_embedding( if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + self.class_embedding = TimestepEmbedding( + timestep_input_dim, time_embed_dim, act_fn=act_fn + ) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) elif class_embed_type == "projection": @@ -631,13 +690,17 @@ def _set_class_embedding( # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) elif class_embed_type == "simple_projection": if projection_class_embeddings_input_dim is None: raise ValueError( "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + self.class_embedding = nn.Linear( + projection_class_embeddings_input_dim, time_embed_dim + ) else: self.class_embedding = None @@ -660,24 +723,36 @@ def _set_add_embedding( text_time_embedding_from_dim = cross_attention_dim self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + text_time_embedding_from_dim, + time_embed_dim, + num_heads=addition_embed_type_num_heads, ) elif addition_embed_type == "text_image": # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)` self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + text_embed_dim=cross_attention_dim, + image_embed_dim=cross_attention_dim, + time_embed_dim=time_embed_dim, ) elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + self.add_time_proj = Timesteps( + addition_time_embed_dim, flip_sin_to_cos, freq_shift + ) + self.add_embedding = TimestepEmbedding( + projection_class_embeddings_input_dim, time_embed_dim + ) elif addition_embed_type == "image": # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + self.add_embedding = ImageTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) elif addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + self.add_embedding = ImageHintTimeEmbedding( + image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim + ) elif addition_embed_type is not None: raise ValueError( f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'." @@ -693,7 +768,9 @@ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: i feature_type = "text-only" if attention_type == "gated" else "text-image" self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + positive_len=positive_len, + out_dim=cross_attention_dim, + feature_type=feature_type, ) @property @@ -706,7 +783,11 @@ def attn_processors(self) -> Dict[str, AttentionProcessor]: # set recursively processors = {} - def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + def fn_recursive_add_processors( + name: str, + module: torch.nn.Module, + processors: Dict[str, AttentionProcessor], + ): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() @@ -720,7 +801,9 @@ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: return processors - def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]] + ): r""" Sets the attention processor to use to compute attention. @@ -758,9 +841,15 @@ def set_default_attn_processor(self): """ Disables custom attention processors and sets the default attention implementation. 
""" - if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + if all( + proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): processor = AttnAddedKVProcessor() - elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + elif all( + proc.__class__ in CROSS_ATTENTION_PROCESSORS + for proc in self.attn_processors.values() + ): processor = AttnProcessor() else: raise ValueError( @@ -806,7 +895,11 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): # make smallest slice possible slice_size = num_sliceable_layers * [1] - slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + slice_size = ( + num_sliceable_layers * [slice_size] + if not isinstance(slice_size, list) + else slice_size + ) if len(slice_size) != len(sliceable_head_dims): raise ValueError( @@ -823,7 +916,9 @@ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): # Recursively walk through all the children. # Any children which exposes the set_attention_slice method # gets the message - def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + def fn_recursive_set_attention_slice( + module: torch.nn.Module, slice_size: List[int] + ): if hasattr(module, "set_attention_slice"): module.set_attention_slice(slice_size.pop()) @@ -863,7 +958,10 @@ def disable_freeu(self): freeu_keys = {"s1", "s2", "b1", "b2"} for i, upsample_block in enumerate(self.up_blocks): for k in freeu_keys: - if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + if ( + hasattr(upsample_block, k) + or getattr(upsample_block, k, None) is not None + ): setattr(upsample_block, k, None) def fuse_qkv_projections(self): @@ -881,7 +979,9 @@ def fuse_qkv_projections(self): for _, attn_processor in self.attn_processors.items(): if "Added" in str(attn_processor.__class__.__name__): - raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + raise ValueError( + "`fuse_qkv_projections()` is not supported for models having added KV projections." 
+ ) self.original_attn_processors = self.attn_processors @@ -931,11 +1031,15 @@ def get_time_embed( t_emb = t_emb.to(dtype=sample.dtype) return t_emb - def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + def get_class_embed( + self, sample: torch.Tensor, class_labels: Optional[torch.Tensor] + ) -> Optional[torch.Tensor]: class_emb = None if self.class_embedding is not None: if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") + raise ValueError( + "class_labels should be provided when num_class_embeds > 0" + ) if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) @@ -948,7 +1052,10 @@ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Ten return class_emb def get_aug_embed( - self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + self, + emb: torch.Tensor, + encoder_hidden_states: torch.Tensor, + added_cond_kwargs: Dict[str, Any], ) -> Optional[torch.Tensor]: aug_emb = None if self.config.addition_embed_type == "text": @@ -990,7 +1097,10 @@ def get_aug_embed( aug_emb = self.add_embedding(image_embs) elif self.config.addition_embed_type == "image_hint": # Kandinsky 2.2 ControlNet - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + if ( + "image_embeds" not in added_cond_kwargs + or "hint" not in added_cond_kwargs + ): raise ValueError( f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" ) @@ -1002,9 +1112,15 @@ def get_aug_embed( def process_encoder_hidden_states( self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] ) -> torch.Tensor: - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + ): encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_image_proj" + ): # Kandinsky 2.1 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( @@ -1012,8 +1128,13 @@ def process_encoder_hidden_states( ) image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + encoder_hidden_states = self.encoder_hid_proj( + encoder_hidden_states, image_embeds + ) + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "image_proj" + ): # Kandinsky 2.2 - style if "image_embeds" not in added_cond_kwargs: raise ValueError( @@ -1021,14 +1142,22 @@ def process_encoder_hidden_states( ) image_embeds = added_cond_kwargs.get("image_embeds") encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + elif ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "ip_image_proj" + ): if "image_embeds" not in added_cond_kwargs: raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the 
keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
-            if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
-                encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
+            if (
+                hasattr(self, "text_encoder_hid_proj")
+                and self.text_encoder_hid_proj is not None
+            ):
+                encoder_hidden_states = self.text_encoder_hid_proj(
+                    encoder_hidden_states
+                )

             image_embeds = added_cond_kwargs.get("image_embeds")
             image_embeds = self.encoder_hid_proj(image_embeds)
@@ -1050,6 +1179,8 @@ def forward(
         down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
+        return_post_down_blocks: bool = False,
+        return_post_mid_blocks: bool = False,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
         The [`UNet2DConditionModel`] forward method.
@@ -1089,6 +1220,22 @@ def forward(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
                 tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
+            added_cond_kwargs (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
+            down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added to UNet long skip connections from down blocks to up blocks, for
+                example from ControlNet side model(s)
+            mid_block_additional_residual (`torch.Tensor`, *optional*):
+                additional residual to be added to UNet mid block output, for example from ControlNet side model
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+            return_post_down_blocks (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the sample right after the down blocks, skipping the mid and up blocks.
+            return_post_mid_blocks (`bool`, *optional*, defaults to `False`):
+                Whether or not to return the sample right after the mid block, skipping the up blocks.

         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
@@ -1129,7 +1276,9 @@ def forward(

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None:
-            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = (
+                1 - encoder_attention_mask.to(sample.dtype)
+            ) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

         # 0. center input if necessary
@@ -1148,7 +1297,9 @@ def forward(
             emb = emb + class_emb

         aug_emb = self.get_aug_embed(
-            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+            emb=emb,
+            encoder_hidden_states=encoder_hidden_states,
+            added_cond_kwargs=added_cond_kwargs,
         )
         if self.config.addition_embed_type == "image_hint":
             aug_emb, hint = aug_emb
@@ -1160,17 +1311,23 @@ def forward(
             emb = self.time_embed_act(emb)

         encoder_hidden_states = self.process_encoder_hidden_states(
-            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+            encoder_hidden_states=encoder_hidden_states,
+            added_cond_kwargs=added_cond_kwargs,
         )

         # 2.
pre-process sample = self.conv_in(sample) # 2.5 GLIGEN position net - if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + if ( + cross_attention_kwargs is not None + and cross_attention_kwargs.get("gligen", None) is not None + ): cross_attention_kwargs = cross_attention_kwargs.copy() gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + cross_attention_kwargs["gligen"] = { + "objs": self.position_net(**gligen_args) + } # 3. down # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated @@ -1185,13 +1342,20 @@ def forward( # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_controlnet = ( + mid_block_additional_residual is not None + and down_block_additional_residuals is not None + ) # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets is_adapter = down_intrablock_additional_residuals is not None # maintain backward compatibility for legacy usage, where # T2I-Adapter and ControlNet both use down_block_additional_residuals arg # but can only use one or the other - if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + if ( + not is_adapter + and mid_block_additional_residual is None + and down_block_additional_residuals is not None + ): deprecate( "T2I should not use down_block_additional_residuals", "1.3.0", @@ -1205,11 +1369,16 @@ def forward( down_block_res_samples = (sample,) for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + if ( + hasattr(downsample_block, "has_cross_attention") + and downsample_block.has_cross_attention + ): # For t2i-adapter CrossAttnDownBlock2D additional_residuals = {} if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + additional_residuals["additional_residuals"] = ( + down_intrablock_additional_residuals.pop(0) + ) sample, res_samples = downsample_block( hidden_states=sample, @@ -1233,14 +1402,27 @@ def forward( for down_block_res_sample, down_block_additional_residual in zip( down_block_res_samples, down_block_additional_residuals ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + down_block_res_sample = ( + down_block_res_sample + down_block_additional_residual + ) + new_down_block_res_samples = new_down_block_res_samples + ( + down_block_res_sample, + ) down_block_res_samples = new_down_block_res_samples + if return_post_down_blocks: + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + # 4. mid if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + if ( + hasattr(self.mid_block, "has_cross_attention") + and self.mid_block.has_cross_attention + ): sample = self.mid_block( sample, emb, @@ -1263,19 +1445,30 @@ def forward( if is_controlnet: sample = sample + mid_block_additional_residual + if return_post_mid_blocks: + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) + # 5. 
up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + down_block_res_samples = down_block_res_samples[ + : -len(upsample_block.resnets) + ] # if we have not reached the final block and need to forward the # upsample size, we do it here if not is_final_block and forward_upsample_size: upsample_size = down_block_res_samples[-1].shape[2:] - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + if ( + hasattr(upsample_block, "has_cross_attention") + and upsample_block.has_cross_attention + ): sample = upsample_block( hidden_states=sample, temb=emb,