Commit afc60a0

apply suggestions from review; update docs

a-r-r-o-w authored and yiyixuxu committed Oct 21, 2024
1 parent e96d5ad commit afc60a0

Showing 6 changed files with 102 additions and 26 deletions.
@@ -54,6 +54,11 @@ image = pipe(
image.save("sd3_hello_world.png")
```

**Note:** Stable Diffusion 3.5 can also be run with the SD3 pipeline, and all of the optimizations and techniques mentioned here apply to it as well. In total, there are three official models in the SD3 family (a short loading sketch follows the list):
- [`stabilityai/stable-diffusion-3-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers)
- [`stabilityai/stable-diffusion-3.5-medium-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium-diffusers)
- [`stabilityai/stable-diffusion-3.5-large-diffusers`](https://huggingface.co/stabilityai/stable-diffusion-3.5-large-diffusers)
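
For illustration (a minimal sketch, not part of this changeset): loading one of the 3.5 checkpoints through the same `StableDiffusion3Pipeline`. The checkpoint name is taken from the list above; the prompt and settings are placeholders.

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Any of the three checkpoints listed above can be passed here; the pipeline
# class is the same for SD3 and SD3.5.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large-diffusers",
    torch_dtype=torch.bfloat16,
).to("cuda")

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_5_hello_world.png")
```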

## Memory Optimisations for SD3

SD3 uses three text encoders, one of which is the very large T5-XXL model. This makes it challenging to run the model on GPUs with less than 24GB of VRAM, even when using `fp16` precision. The following section outlines a few memory optimizations in Diffusers that make it easier to run SD3 on low-resource hardware.
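
As one concrete example of the optimizations referred to above (a sketch, not part of this changeset): model CPU offloading keeps submodules on the CPU and moves each one to the GPU only while it is needed, trading some speed for a much smaller peak VRAM footprint. Dropping the T5-XXL encoder by passing `text_encoder_3=None` is another option.

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
# Offload submodules to CPU and load them onto the GPU one at a time.
pipe.enable_model_cpu_offload()

image = pipe(
    prompt="a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
```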
22 changes: 12 additions & 10 deletions scripts/convert_sd3_to_diffusers.py
@@ -40,7 +40,7 @@ def swap_scale_shift(weight, dim):


def convert_sd3_transformer_checkpoint_to_diffusers(
- original_state_dict, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
+ original_state_dict, num_layers, caption_projection_dim, dual_attention_layers, has_qk_norm
):
converted_state_dict = {}

@@ -142,7 +142,7 @@ def convert_sd3_transformer_checkpoint_to_diffusers(
)

# attn2
- if i in add_attn2_layers:
+ if i in dual_attention_layers:
# Q, K, V
sample_q2, sample_k2, sample_v2 = torch.chunk(
original_state_dict.pop(f"joint_blocks.{i}.x_block.attn2.qkv.weight"), 3, dim=0
@@ -244,14 +244,14 @@ def is_vae_in_checkpoint(original_state_dict):
)


- def get_add_attn2_layers(state_dict):
- add_attn2_layers = []
+ def get_attn2_layers(state_dict):
+ attn2_layers = []
for key in state_dict.keys():
if "attn2." in key:
# Extract the layer number from the key
layer_num = int(key.split(".")[1])
- add_attn2_layers.append(layer_num)
- return tuple(sorted(set(add_attn2_layers)))
+ attn2_layers.append(layer_num)
+ return tuple(sorted(set(attn2_layers)))


def get_pos_embed_max_size(state_dict):
@@ -284,14 +284,16 @@ def main(args):
raise ValueError(f"Unsupported dtype: {args.dtype}")

if dtype != original_dtype:
print(f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution.")
print(
f"Checkpoint dtype {original_dtype} does not match requested dtype {dtype}. This can lead to unexpected results, proceed with caution."
)

num_layers = list(set(int(k.split(".", 2)[1]) for k in original_ckpt if "joint_blocks" in k))[-1] + 1 # noqa: C401

caption_projection_dim = get_caption_projection_dim(original_ckpt)

# () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
- add_attn2_layers = get_add_attn2_layers(original_ckpt)
+ attn2_layers = get_attn2_layers(original_ckpt)

# sd3.5 uses qk norm ("rms_norm")
has_qk_norm = any("ln_q" in key for key in original_ckpt.keys())
@@ -300,7 +302,7 @@ def main(args):
pos_embed_max_size = get_pos_embed_max_size(original_ckpt)

converted_transformer_state_dict = convert_sd3_transformer_checkpoint_to_diffusers(
- original_ckpt, num_layers, caption_projection_dim, add_attn2_layers, has_qk_norm
+ original_ckpt, num_layers, caption_projection_dim, attn2_layers, has_qk_norm
)

with CTX():
@@ -314,7 +316,7 @@ def main(args):
num_attention_heads=num_layers,
pos_embed_max_size=pos_embed_max_size,
qk_norm="rms_norm" if has_qk_norm else None,
- add_attn2_layers=add_attn2_layers,
+ dual_attention_layers=attn2_layers,
)
if is_accelerate_available():
load_model_dict_into_meta(transformer, converted_transformer_state_dict)
20 changes: 14 additions & 6 deletions src/diffusers/models/attention.py
@@ -101,15 +101,21 @@ class JointTransformerBlock(nn.Module):
"""

def __init__(
- self, dim, num_attention_heads, attention_head_dim, context_pre_only=False, qk_norm=None, add_attn2=False
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ context_pre_only: bool = False,
+ qk_norm: Optional[str] = None,
+ use_dual_attention: bool = False,
):
super().__init__()

- self.add_attn2 = add_attn2
+ self.use_dual_attention = use_dual_attention
self.context_pre_only = context_pre_only
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"

- if add_attn2:
+ if use_dual_attention:
self.norm1 = SD35AdaLayerNormZeroX(dim)
else:
self.norm1 = AdaLayerNormZero(dim)
@@ -124,12 +130,14 @@ def __init__(
raise ValueError(
f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
)

if hasattr(F, "scaled_dot_product_attention"):
processor = JointAttnProcessor2_0()
else:
raise ValueError(
"The current PyTorch version does not support the `scaled_dot_product_attention` function."
)

self.attn = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -144,7 +152,7 @@ def __init__(
eps=1e-6,
)

- if add_attn2:
+ if use_dual_attention:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=None,
@@ -182,7 +190,7 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
):
- if self.add_attn2:
+ if self.use_dual_attention:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
hidden_states, emb=temb
)
@@ -205,7 +213,7 @@ def forward(
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

- if self.add_attn2:
+ if self.use_dual_attention:
attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
hidden_states = hidden_states + attn_output2
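
To show how the renamed flag is exercised end to end, a rough sketch with made-up sizes (not part of the diff; it assumes `JointTransformerBlock` is importable from `diffusers.models.attention` and that the block returns the text and image streams in that order):

```python
import torch
from diffusers.models.attention import JointTransformerBlock

# use_dual_attention=True switches norm1 to SD35AdaLayerNormZeroX and adds the
# parallel self-attention branch (attn2) gated by gate_msa2.
block = JointTransformerBlock(
    dim=64,
    num_attention_heads=4,
    attention_head_dim=16,
    context_pre_only=False,
    qk_norm="rms_norm",
    use_dual_attention=True,
)

hidden_states = torch.randn(2, 16, 64)         # image tokens
encoder_hidden_states = torch.randn(2, 8, 64)  # text tokens
temb = torch.randn(2, 64)                      # conditioning embedding

encoder_hidden_states, hidden_states = block(hidden_states, encoder_hidden_states, temb)
print(hidden_states.shape)  # torch.Size([2, 16, 64])
```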
16 changes: 8 additions & 8 deletions src/diffusers/models/normalization.py
@@ -99,14 +99,14 @@ def forward(self, inputs: torch.Tensor) -> torch.Tensor:

class SD35AdaLayerNormZeroX(nn.Module):
r"""
- Norm layer adaptive layer norm zero (adaLN-Zero).
+ Norm layer adaptive layer norm zero (AdaLN-Zero).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""

- def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
+ def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
super().__init__()

self.silu = nn.SiLU()
@@ -118,17 +118,17 @@ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):

def forward(
self,
- x: torch.Tensor,
+ hidden_states: torch.Tensor,
emb: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ ) -> Tuple[torch.Tensor, ...]:
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
9, dim=1
)
- normed_x = self.norm(x)
- x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
- x2 = normed_x * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
- return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, x2, gate_msa2
+ norm_hidden_states = self.norm(hidden_states)
+ hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+ norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+ return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2


class AdaLayerNormZero(nn.Module):
6 changes: 4 additions & 2 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -69,7 +69,9 @@ def __init__(
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
- add_attn2_layers: Tuple[int, ...] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+ dual_attention_layers: Tuple[
+ int, ...
+ ] = (), # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
qk_norm: Optional[str] = None,
):
super().__init__()
@@ -100,7 +102,7 @@ def __init__(
attention_head_dim=self.config.attention_head_dim,
context_pre_only=i == num_layers - 1,
qk_norm=qk_norm,
- add_attn2=True if i in add_attn2_layers else False,
+ use_dual_attention=True if i in dual_attention_layers else False,
)
for i in range(self.config.num_layers)
]
59 changes: 59 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -73,6 +73,65 @@ def prepare_init_args_and_inputs_for_common(self):
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
"pos_embed_max_size": 96,
"dual_attention_layers": (),
"qk_norm": None,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
def test_set_attn_processor_for_determinism(self):
pass


class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = SD3Transformer2DModel
main_input_name = "hidden_states"

@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154

hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"pooled_projections": pooled_prompt_embeds,
"timestep": timestep,
}

@property
def input_shape(self):
return (4, 32, 32)

@property
def output_shape(self):
return (4, 32, 32)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
"num_layers": 2,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
"pos_embed_max_size": 96,
"dual_attention_layers": (0,),
"qk_norm": "rms_norm",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
