From ed502290d2ebd31f8cc25f4ae716b4912647b552 Mon Sep 17 00:00:00 2001 From: William Held Date: Wed, 21 Dec 2022 14:55:59 -0500 Subject: [PATCH 1/2] Width was typod as weight --- src/diffusers/models/attention.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 9fe6a8034c22..2a62fbbc9eb2 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -204,17 +204,17 @@ def forward( """ # 1. Input if self.is_input_continuous: - batch, channel, height, weight = hidden_states.shape + batch, channel, height, width = hidden_states.shape residual = hidden_states hidden_states = self.norm(hidden_states) if not self.use_linear_projection: hidden_states = self.proj_in(hidden_states) inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) else: inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) hidden_states = self.proj_in(hidden_states) elif self.is_input_vectorized: hidden_states = self.latent_image_embedding(hidden_states) @@ -232,13 +232,13 @@ def forward( if self.is_input_continuous: if not self.use_linear_projection: hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() ) hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() ) output = hidden_states + residual From adae7a45407ebd3b690707a564fb79cfd59961ad Mon Sep 17 00:00:00 2001 From: Helw150 Date: Wed, 21 Dec 2022 15:22:14 -0500 Subject: [PATCH 2/2] Run Black --- src/diffusers/models/attention.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 2a62fbbc9eb2..91c450d4a581 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -231,15 +231,11 @@ def forward( # 3. Output if self.is_input_continuous: if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() hidden_states = self.proj_out(hidden_states) else: hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() - ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() output = hidden_states + residual elif self.is_input_vectorized: