This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Reformating Unet#53

Merged
Warvito merged 32 commits into main from 31-fix-torchscript-error-in-latent-diffusion-models-unet-network
Nov 30, 2022

Conversation

@Warvito

@Warvito Warvito commented Nov 8, 2022

Collaborator

Fix #31

@Warvito Warvito linked an issue Nov 8, 2022 that may be closed by this pull request
@Warvito

Warvito commented Nov 9, 2022

Collaborator Author

The current version significantly increases memory consumption. This needs to be checked.

@Warvito

Warvito commented Nov 9, 2022

Collaborator Author

The current version significantly increases memory consumption. This needs to be checked.

Fixed.

@Warvito

Warvito commented Nov 9, 2022

Collaborator Author

@ericspod sorry, but I was not able to solve a few TorchScript problems. Could you please help me, or point me to someone who could help us?

Right now, I am stuck on two kinds of errors:

  1. When adding two tuples at:
    output_states += (hidden_states,)

I am getting

RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Tuple[()]'.
  
  aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Tuple[()]'.
  
  aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Tuple[()]'.
  
  aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Tuple[()]'.
  
  aten::add.t(t[] a, t[] b) -> (t[]):
  Could not match type Tuple[()] to List[t] in argument 'a': Cannot match List[t] to Tuple[()].
  
  aten::add.str(str a, str b) -> (str):
  Expected a value of type 'str' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.int(int a, int b) -> (int):
  Expected a value of type 'int' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.complex(complex a, complex b) -> (complex):
  Expected a value of type 'complex' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.float(float a, float b) -> (float):
  Expected a value of type 'float' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.int_complex(int a, complex b) -> (complex):
  Expected a value of type 'int' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.complex_int(complex a, int b) -> (complex):
  Expected a value of type 'complex' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.float_complex(float a, complex b) -> (complex):
  Expected a value of type 'float' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.complex_float(complex a, float b) -> (complex):
  Expected a value of type 'complex' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.int_float(int a, float b) -> (float):
  Expected a value of type 'int' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add.float_int(float a, int b) -> (float):
  Expected a value of type 'float' for argument 'a' but instead found type 'Tuple[()]'.
  
  aten::add(Scalar a, Scalar b) -> (Scalar):
  Expected a value of type 'number' for argument 'a' but instead found type 'Tuple[()]'.
  
  add(float a, Tensor b) -> (Tensor):
  Expected a value of type 'float' for argument 'a' but instead found type 'Tuple[()]'.
  
  add(int a, Tensor b) -> (Tensor):
  Expected a value of type 'int' for argument 'a' but instead found type 'Tuple[()]'.
  
  add(complex a, Tensor b) -> (Tensor):
  Expected a value of type 'complex' for argument 'a' but instead found type 'Tuple[()]'.

The original call is:
  File "/media/walter/Storage/Projects/GenerativeModels/generative/networks/nets/diffusion_model_unet.py", line 680
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, temb)
            output_states += (hidden_states,)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE

I found a few mentions of this problem (https://discuss.pytorch.org/t/torchscript-jit-cant-handle-tuple-tensor-tensor-tuple-tensor-statement/115838), but I could not solve it.

  2. The second type of error is:
RuntimeError: Can't redefine method: forward on class: __torch__.generative.networks.nets.diffusion_model_unet.DownBlock (of Python compilation unit at: 0x4a94600)

This one does not report the line of the error. =/

Both errors should be reproducible by running the tests.

@marksgraham

Collaborator

Hi Walter,

The solution to the problem of adding two tuples is to use lists instead, with the append() method, e.g.:

    output_states = []

    for resnet in self.resnets:
        hidden_states = resnet(hidden_states, temb)
        output_states.append(hidden_states)


    if self.downsampler is not None:
        hidden_states = self.downsampler(hidden_states)
        output_states.append(hidden_states)

There are quite a few places in the code this needs replacing.
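
As a minimal sketch of the list-based pattern above, the following compiles under torch.jit.script. TinyDownBlock, its layers, and the tensor sizes are hypothetical stand-ins for illustration, not the repository's DownBlock:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class TinyDownBlock(nn.Module):
    """Hypothetical stand-in for a down block; not the repository's DownBlock."""

    def __init__(self, channels: int = 4, num_layers: int = 2) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            [nn.Conv2d(channels, channels, kernel_size=3, padding=1) for _ in range(num_layers)]
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        # Annotating the empty list gives TorchScript a static element type.
        output_states: List[torch.Tensor] = []
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            # append() replaces the tuple concatenation `output_states += (hidden_states,)`.
            output_states.append(hidden_states)
        return hidden_states, output_states


scripted = torch.jit.script(TinyDownBlock())
out, states = scripted(torch.randn(1, 4, 8, 8))
```

The scripted module returns the final hidden state together with one stored state per layer.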

I'm still getting other TorchScript errors, though. A big one is that you can't use einops.rearrange, because it takes a variable number of keyword arguments (**axes_lengths). I've looked at how MONAI solves this; here they use the Rearrange layer from einops. The problem with this approach is that these layers need to be created in the class init, so to replace something like:

x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)

with a layer we would need to know the values of h,w during class init.

Another option would just be to remove the einops dependency and replace it with code that is messier but will be torchscript compatible, as MONAI trialled here
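
A minimal sketch of that second option, assuming the pattern "b (h w) c -> b c h w" from the example above. The helper name tokens_to_image is hypothetical; h and w are passed at call time, so nothing has to be known during class init:

```python
import torch


def tokens_to_image(x: torch.Tensor, h: int, w: int) -> torch.Tensor:
    """Plain-PyTorch equivalent of rearrange(x, "b (h w) c -> b c h w", h=h, w=w)."""
    b = x.shape[0]
    c = x.shape[2]
    # Move channels in front of the flattened spatial axis, then unflatten it.
    # reshape() is used because the permuted tensor is no longer contiguous.
    return x.permute(0, 2, 1).reshape(b, c, h, w)


scripted_tokens_to_image = torch.jit.script(tokens_to_image)

x = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)
image = scripted_tokens_to_image(x, 2, 3)
```

Element image[b, c, i, j] corresponds to x[b, i * w + j, c], which is the index mapping the einops pattern describes.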

@Warvito

Warvito commented Nov 24, 2022

Collaborator Author

Almost there... =/
Getting a problem with appending/extending this list

Error
Traceback (most recent call last):
  File "/media/walter/Storage/Projects/GenerativeModels/tests/test_diffusion_model_unet.py", line 156, in test_script_conditioned_2d_models
    test_script_save(
  File "/media/walter/Storage/Projects/GenerativeModels/MONAI/tests/utils.py", line 701, in test_script_save
    convert_to_torchscript(
  File "/media/walter/Storage/Projects/GenerativeModels/MONAI/monai/networks/utils.py", line 593, in convert_to_torchscript
    script_module = torch.jit.script(model, **kwargs)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/torch/jit/_script.py", line 942, in script
    return torch.jit._recursive.create_script_module(
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 391, in create_script_module
    return create_script_module_impl(nn_module, concrete_type, stubs_fn)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 452, in create_script_module_impl
    create_methods_and_properties_from_stubs(concrete_type, method_stubs, property_stubs)
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/torch/jit/_recursive.py", line 335, in create_methods_and_properties_from_stubs
    concrete_type._create_methods_and_properties(property_defs, property_rcbs, method_defs, method_rcbs, method_defaults)
RuntimeError: 

aten::extend.t(t[](a!) self, t[] other) -> ():
Could not match type Any to List[t] in argument 'other': Cannot match List[t] to Any.
:
  File "/media/walter/Storage/Projects/GenerativeModels/generative/networks/nets/diffusion_model_unet.py", line 1505
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            down_block_res_samples.extend(res_samples)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
        # 4. mid

@ericspod

Member

TorchScript will use type annotations, so down_block_res_samples should be given a type like List[nn.Module] and res_samples should be assigned the same type. This should let TorchScript deduce what static type the list would have in a statically typed environment, e.g. in C++.
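
A minimal sketch of the annotation idea. It assumes the residual samples are tensors (the traceback suggests they are hidden states), so the element type used here is torch.Tensor. Block and Encoder are hypothetical stand-ins, not the classes in diffusion_model_unet.py:

```python
from typing import List, Tuple

import torch
import torch.nn as nn


class Block(nn.Module):
    """Hypothetical block that returns its output plus a list of residual samples."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)

    # The return annotation states the list's element type explicitly.
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        hidden_states = self.conv(hidden_states)
        return hidden_states, [hidden_states]


class Encoder(nn.Module):
    """Hypothetical parent module that collects residuals from every block."""

    def __init__(self, channels: int = 4, num_blocks: int = 3) -> None:
        super().__init__()
        self.down_blocks = nn.ModuleList([Block(channels) for _ in range(num_blocks)])

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        h = x
        # Both sides of extend() now have the static type List[torch.Tensor].
        down_block_res_samples: List[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(h)
            down_block_res_samples.extend(res_samples)
        return down_block_res_samples


scripted_encoder = torch.jit.script(Encoder())
residuals = scripted_encoder(torch.randn(1, 4, 8, 8))
```

With three blocks, the returned list holds the input plus one residual per block.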

@Warvito

Warvito commented Nov 25, 2022

Collaborator Author

Fixed all TorchScript errors, but the performance in the tutorials is worse than before. Trying to find the source of the problem =/

@Warvito

Warvito commented Nov 25, 2022

Collaborator Author

Surprisingly, the zero_module function makes a big difference (it is not included in the diffuser implementation).
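
For context, here is a sketch of what a helper named zero_module commonly does in diffusion codebases: it zeroes every parameter of a layer and returns it. This assumes the helper in this PR follows that common definition, which the thread does not confirm:

```python
import torch
import torch.nn as nn


def zero_module(module: nn.Module) -> nn.Module:
    """Zero out all parameters of a module in place and return it (assumed definition)."""
    for p in module.parameters():
        p.detach().zero_()
    return module


# Hypothetical usage: a convolution whose weights and bias start at zero.
conv = zero_module(nn.Conv2d(3, 8, kernel_size=3, padding=1))
y = conv(torch.randn(2, 3, 16, 16))
```

A layer initialised this way outputs zeros at the start of training, so a residual block that ends with it initially passes its input through unchanged.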

@Warvito

Warvito commented Nov 25, 2022

Collaborator Author

I tried a version of the attention blocks that uses the Rearrange layer and MONAI's SABlock here


class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other. Uses three q, k, v linear layers to
    compute attention.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of channels in the input and output.
        num_head_channels: number of channels in each attention head.
        norm_num_groups: number of groups to use for group norm.
        norm_eps: epsilon value to use for group norm.
    """

    def __init__(
        self,
        spatial_dims: int,
        num_channels: int,
        num_head_channels: Optional[int] = None,
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.spatial_dims = spatial_dims
        self.num_channels = num_channels

        self.num_heads = num_channels // num_head_channels if num_head_channels is not None else 1
        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=num_channels, eps=norm_eps, affine=True)

        self.attention = SABlock(hidden_size=num_channels, num_heads=self.num_heads, qkv_bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x

        batch = channel = height = width = depth = -1
        if self.spatial_dims == 2:
            batch, channel, height, width = x.shape
        if self.spatial_dims == 3:
            batch, channel, height, width, depth = x.shape

        x = self.norm(x)

        if self.spatial_dims == 2:
            x = x.view(batch, channel, height * width).transpose(1, 2)
        if self.spatial_dims == 3:
            x = x.view(batch, channel, height * width * depth).transpose(1, 2)

        x = self.attention(x)

        if self.spatial_dims == 2:
            x = x.transpose(-1, -2).reshape(batch, channel, height, width)
        if self.spatial_dims == 3:
            x = x.transpose(-1, -2).reshape(batch, channel, height, width, depth)

        return x + residual

and


class CrossAttention(nn.Module):
    """
    A cross attention layer.

    Args:
        query_dim: number of channels in the query.
        cross_attention_dim: number of channels in the context.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each head.
        dropout: dropout probability to use.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        num_attention_heads: int = 8,
        num_head_channels: int = 64,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        inner_dim = num_head_channels * num_attention_heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim

        self.scale = num_head_channels**-0.5
        self.heads = num_attention_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False)

        self.input_rearrange = Rearrange("b n (h d) -> (b h) n d", h=num_attention_heads)
        self.out_rearrange = Rearrange("(b h) n d -> b n (h d)", h=num_attention_heads)

        self.out_proj = nn.Linear(inner_dim, query_dim)
        self.drop_output = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        query = self.to_q(x)
        context = context if context is not None else x
        key = self.to_k(context)
        value = self.to_v(context)

        query = self.input_rearrange(query)
        key = self.input_rearrange(key)
        value = self.input_rearrange(value)

        sim = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
        attn = sim.softmax(dim=-1)
        out = torch.einsum("b i j, b j d -> b i d", attn, value)
        out = self.out_rearrange(out)

        out = self.out_proj(out)
        out = self.drop_output(out)
        return out

but I am getting this torchscript error

RuntimeError: 
Module 'Rearrange' has no attribute '_recipe' (This attribute exists on the Python module, but we failed to convert Python type: 'einops.einops.TransformRecipe' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type TransformRecipe.. Its type was inferred; try adding a type annotation for the attribute.):
  File "/media/walter/Storage/Projects/GenerativeModels/venv/lib/python3.8/site-packages/einops/layers/torch.py", line 14
    def forward(self, input):
        return apply_for_scriptable_torch(self._recipe, input, reduction_type='rearrange')
                                          ~~~~~~~~~~~~ <--- HERE

So I am adopting the version without einops again.
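
A minimal sketch of how the two Rearrange patterns from the CrossAttention attempt above can be written without einops. The helper names split_heads and merge_heads are hypothetical and not necessarily what the merged code uses:

```python
import torch


def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Plain-PyTorch equivalent of "b n (h d) -> (b h) n d"."""
    batch = x.shape[0]
    seq_len = x.shape[1]
    head_dim = x.shape[2] // num_heads
    x = x.reshape(batch, seq_len, num_heads, head_dim)
    # Bring the head axis next to the batch axis, then merge the two.
    return x.permute(0, 2, 1, 3).reshape(batch * num_heads, seq_len, head_dim)


def merge_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """Plain-PyTorch equivalent of "(b h) n d -> b n (h d)"."""
    batch = x.shape[0] // num_heads
    seq_len = x.shape[1]
    head_dim = x.shape[2]
    x = x.reshape(batch, num_heads, seq_len, head_dim)
    # Move the head axis back next to the feature axis, then merge the two.
    return x.permute(0, 2, 1, 3).reshape(batch, seq_len, num_heads * head_dim)


scripted_split = torch.jit.script(split_heads)
scripted_merge = torch.jit.script(merge_heads)

q = torch.randn(2, 5, 8 * 4)      # batch=2, tokens=5, 8 heads of 4 channels
heads = scripted_split(q, 8)      # shape (16, 5, 4)
restored = scripted_merge(heads, 8)
```

Because both helpers are plain functions of tensors and integers, they carry no Python-only attributes such as the _recipe object that TorchScript rejected in the error above.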

@Warvito Warvito changed the title [WIP] Reformating Unet Reformating Unet Nov 26, 2022
@Warvito Warvito marked this pull request as ready for review November 26, 2022 07:27
@ericspod

Member

Could that be related to the einops version? I thought this was Torchscript compatible by now, but if your implementation works we stick with it.

@Warvito

Warvito commented Nov 28, 2022

Collaborator Author

Could that be related to the einops version? I thought this was Torchscript compatible by now, but if your implementation works we stick with it.

Yes, it is compatible with Torchscript now and ready for Review. Let's use it for version 0.1.

However, it contains an attention block that could make use of the SAB block from MONAI core (https://github.com/Project-MONAI/MONAI/blob/0c2fbac6eb55a546886921d9add66be9688d9775/monai/networks/blocks/selfattention.py#L20). Maybe we can use it in the future. For now, when I tried it (using einops version 0.6.0), I got the error mentioned above.

@marksgraham marksgraham self-requested a review November 28, 2022 22:44

@marksgraham marksgraham left a comment

Collaborator


Hi @Warvito

This looks really nice - the new unet implementation is very neat! Comments are very minor, to do with unused args and variables.

I noticed you've increased the time estimates in all the tutorials - is the new model slower? If so, do you have any idea why?

Comment thread generative/networks/nets/diffusion_model_unet.py Outdated
Comment thread generative/networks/nets/diffusion_model_unet.py
Comment thread generative/networks/nets/diffusion_model_unet.py
Comment thread generative/networks/nets/diffusion_model_unet.py
Comment thread generative/networks/nets/diffusion_model_unet.py
Comment thread generative/networks/nets/diffusion_model_unet.py Outdated
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
Signed-off-by: Walter Hugo Lopez Pinaya <ianonimato@hotmail.com>
@Warvito

Warvito commented Nov 30, 2022

Collaborator Author

I noticed you've increased the time estimates in all the tutorials - is the new model slower? If so, do you have any idea why?

In the 2d_ddpm_tutorial.ipynb I decided to make the network a little bigger to get better samples in the end. I decided to increase the number of epochs too to avoid getting hands with 4 fingers in the final sampling =/

@Warvito Warvito merged commit c294e7f into main Nov 30, 2022
@Warvito Warvito deleted the 31-fix-torchscript-error-in-latent-diffusion-models-unet-network branch December 4, 2022 10:55
Development

Successfully merging this pull request may close these issues.

Fix Torchscript error in latent diffusion models unet network

3 participants