Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ sphinxcontrib-qthelp
sphinxcontrib-serializinghtml
sphinx-autodoc-typehints==1.11.1
pandas
einops
transformers
mlflow
tensorboardX
Expand Down
4 changes: 2 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`,
`tifffile`, `imagecodecs`, respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
1 change: 0 additions & 1 deletion monai/config/deviceconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def get_optional_config_values():
output["lmdb"] = get_package_version("lmdb")
output["psutil"] = psutil_version
output["pandas"] = get_package_version("pandas")
output["einops"] = get_package_version("einops")
output["transformers"] = get_package_version("transformers")
output["mlflow"] = get_package_version("mlflow")

Expand Down
41 changes: 28 additions & 13 deletions monai/networks/blocks/patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import torch.nn as nn

from monai.networks.layers import Conv
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}


Expand Down Expand Up @@ -70,7 +69,11 @@ def __init__(
if hidden_size % num_heads != 0:
raise ValueError("hidden size should be divisible by num_heads.")

if spatial_dims not in [2, 3]:
raise ValueError("spatial_dims should be 2 or 3.")

self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
self.permute_dims = get_permute_dims(spatial_dims)

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(patch_size, spatial_dims)
Expand All @@ -79,23 +82,19 @@ def __init__(
raise ValueError("patch_size should be smaller than img_size.")
if self.pos_embed == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = in_channels * np.prod(patch_size)

img_by_patch = [im_d // p_d for im_d, p_d in zip(img_size, patch_size)]
self.n_patches = int(np.prod(img_by_patch))
self.patch_dim = int(in_channels * np.prod(patch_size))
self.reshape_spatial_dims = [x for z in zip(img_by_patch, patch_size) for x in z]

self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.pos_embed == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)"
axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)}
self.patch_embeddings = nn.Sequential(
Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size)
)
else:
self.patch_embeddings = nn.Linear(self.patch_dim, hidden_size)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
Expand Down Expand Up @@ -127,10 +126,26 @@ def norm_cdf(x):
tensor.clamp_(min=a, max=b)
return tensor

def _rearrange_input(self, x):
b, c = x.shape[:2]
reshape_size = [b, c] + self.reshape_spatial_dims
x = x.reshape(reshape_size)
x = x.permute(self.permute_dims)
return x.reshape([b, self.n_patches, self.patch_dim])

def forward(self, x):
if self.pos_embed == "perceptron":
x = self._rearrange_input(x)
x = self.patch_embeddings(x)
if self.pos_embed == "conv":
x = x.flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
return embeddings


def get_permute_dims(spatial_dims: int):
if spatial_dims == 2:
return (0, 2, 4, 3, 5, 1)
else: # spatial_dims == 3
return (0, 2, 4, 6, 3, 5, 7, 1)
12 changes: 6 additions & 6 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@
import torch
import torch.nn as nn

from monai.utils import optional_import

einops, _ = optional_import("einops")


class SABlock(nn.Module):
"""
Expand Down Expand Up @@ -49,11 +45,15 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0)
self.scale = self.head_dim ** -0.5

def forward(self, x):
q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads)
b, h, in_feats = x.shape
l = self.num_heads
d = in_feats // l
x = self.qkv(x).reshape([b, h, 3, l, d]).permute(2, 0, 3, 1, 4)
q, k, v = x[0], x[1], x[2]
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
att_mat = self.drop_weights(att_mat)
x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v)
x = einops.rearrange(x, "b h l d -> b l (h d)")
x = x.permute([0, 2, 1, 3]).reshape([b, h, l * d])
x = self.out_proj(x)
x = self.drop_output(x)
return x
4 changes: 2 additions & 2 deletions monai/networks/nets/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ def __init__(

def forward(self, x):
x = self.patch_embedding(x)
if self.classification:
if hasattr(self, "cls_token"):
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
hidden_states_out = []
for blk in self.blocks:
x = blk(x)
hidden_states_out.append(x)
x = self.norm(x)
if self.classification:
if hasattr(self, "classification_head"):
x = self.classification_head(x[:, 0])
return x, hidden_states_out
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ imagecodecs; platform_system == "Linux"
tifffile; platform_system == "Linux"
pandas
requests
einops
transformers
mlflow
matplotlib!=3.5.0
Expand Down
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ all =
tifffile
imagecodecs
pandas
einops
transformers
mlflow
matplotlib
Expand Down Expand Up @@ -82,8 +81,6 @@ imagecodecs =
imagecodecs
pandas =
pandas
einops =
einops
transformers =
transformers
mlflow =
Expand Down
5 changes: 0 additions & 5 deletions tests/test_patchembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,12 @@
# limitations under the License.

import unittest
from unittest import skipUnless

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.utils import optional_import

einops, has_einops = optional_import("einops")

TEST_CASE_PATCHEMBEDDINGBLOCK = []
for dropout_rate in (0.5,):
Expand Down Expand Up @@ -51,7 +47,6 @@

class TestPatchEmbeddingBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
net = PatchEmbeddingBlock(**input_param)
with eval_mode(net):
Expand Down
5 changes: 0 additions & 5 deletions tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,13 @@
# limitations under the License.

import unittest
from unittest import skipUnless

import numpy as np
import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
from monai.utils import optional_import

einops, has_einops = optional_import("einops")

TEST_CASE_SABLOCK = []
for dropout_rate in np.linspace(0, 1, 4):
Expand All @@ -37,7 +33,6 @@

class TestResBlock(unittest.TestCase):
@parameterized.expand(TEST_CASE_SABLOCK)
@skipUnless(has_einops, "Requires einops")
def test_shape(self, input_param, input_shape, expected_shape):
net = SABlock(**input_param)
with eval_mode(net):
Expand Down
13 changes: 13 additions & 0 deletions tests/test_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.networks import eval_mode
from monai.networks.nets.vit import ViT
from tests.utils import test_script_save

TEST_CASE_Vit = []
for dropout_rate in [0.6]:
Expand Down Expand Up @@ -133,6 +134,18 @@ def test_ill_arg(self):
dropout_rate=0.3,
)

@parameterized.expand(TEST_CASE_Vit)
def test_script(self, input_param, input_shape, _):
net = ViT(**(input_param))
net.eval()
with torch.no_grad():
torch.jit.script(net)

input_param_ = dict(input_param)
net = ViT(**(input_param_))
test_data = torch.randn(input_shape)
test_script_save(net, test_data)


if __name__ == "__main__":
unittest.main()