diff --git a/docs/requirements.txt b/docs/requirements.txt index ac9c7f38a7..8b4c00a62a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -19,7 +19,6 @@ sphinxcontrib-qthelp sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas -einops transformers mlflow tensorboardX diff --git a/docs/source/installation.md b/docs/source/installation.md index c431b3389a..50e32698d3 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,10 +174,10 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, transformers, mlflow, matplotlib, tensorboardX, tifffile, imagecodecs] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX`, `tifffile`, `imagecodecs`, respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/config/deviceconfig.py b/monai/config/deviceconfig.py index e542da14ab..b8fe4b63b9 100644 --- a/monai/config/deviceconfig.py +++ b/monai/config/deviceconfig.py @@ -72,7 +72,6 @@ def get_optional_config_values(): output["lmdb"] = get_package_version("lmdb") output["psutil"] = psutil_version output["pandas"] = get_package_version("pandas") - output["einops"] = get_package_version("einops") output["transformers"] = get_package_version("transformers") output["mlflow"] = get_package_version("mlflow") diff --git a/monai/networks/blocks/patchembedding.py b/monai/networks/blocks/patchembedding.py index 492e7bf236..e06a106fe4 100644 --- a/monai/networks/blocks/patchembedding.py +++ b/monai/networks/blocks/patchembedding.py @@ -18,10 +18,9 @@ import torch.nn as nn from monai.networks.layers import Conv -from monai.utils import ensure_tuple_rep, optional_import +from monai.utils import ensure_tuple_rep from monai.utils.module import look_up_option -Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange") SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"} @@ -70,7 +69,11 @@ def __init__( if hidden_size % num_heads != 0: raise ValueError("hidden size should be divisible by num_heads.") + if spatial_dims not in [2, 3]: + raise ValueError("spatial_dims should be 2 or 3.") + self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES) + self.permute_dims = get_permute_dims(spatial_dims) img_size = ensure_tuple_rep(img_size, spatial_dims) patch_size = ensure_tuple_rep(patch_size, spatial_dims) @@ -79,23 +82,19 @@ def __init__( raise ValueError("patch_size should be smaller than img_size.") if self.pos_embed == "perceptron" and m % p != 0: raise ValueError("patch_size should be divisible by img_size for perceptron.") - self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)]) - self.patch_dim = in_channels * np.prod(patch_size) + + img_by_patch = [im_d // p_d for im_d, p_d in zip(img_size, patch_size)] + self.n_patches = int(np.prod(img_by_patch)) + self.patch_dim = int(in_channels * np.prod(patch_size)) + self.reshape_spatial_dims = [x for z in zip(img_by_patch, patch_size) for x in z] self.patch_embeddings: nn.Module if self.pos_embed == "conv": self.patch_embeddings = Conv[Conv.CONV, spatial_dims]( in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size ) - elif self.pos_embed == "perceptron": - # for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)" - chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims] - from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars) - to_chars = f"b ({' '.join([c[0] for c in chars])}) ({' '.join([c[1] for c in chars])} c)" - axes_len = {f"p{i+1}": p for i, p in enumerate(patch_size)} - self.patch_embeddings = nn.Sequential( - Rearrange(f"{from_chars} -> {to_chars}", **axes_len), nn.Linear(self.patch_dim, hidden_size) - ) + else: + self.patch_embeddings = nn.Linear(self.patch_dim, hidden_size) self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size)) self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size)) self.dropout = nn.Dropout(dropout_rate) @@ -127,10 +126,26 @@ def norm_cdf(x): tensor.clamp_(min=a, max=b) return tensor + def _rearrange_input(self, x): + b, c = x.shape[:2] + reshape_size = [b, c] + self.reshape_spatial_dims + x = x.reshape(reshape_size) + x = x.permute(self.permute_dims) + return x.reshape([b, self.n_patches, self.patch_dim]) + def forward(self, x): + if self.pos_embed == "perceptron": + x = self._rearrange_input(x) x = self.patch_embeddings(x) if self.pos_embed == "conv": x = x.flatten(2).transpose(-1, -2) embeddings = x + self.position_embeddings embeddings = self.dropout(embeddings) return embeddings + + +def get_permute_dims(spatial_dims: int): + if spatial_dims == 2: + return (0, 2, 4, 3, 5, 1) + else: # spatial_dims == 3 + return (0, 2, 4, 6, 3, 5, 7, 1) diff --git a/monai/networks/blocks/selfattention.py b/monai/networks/blocks/selfattention.py index 932475b06c..4c81243437 100644 --- a/monai/networks/blocks/selfattention.py +++ b/monai/networks/blocks/selfattention.py @@ -12,10 +12,6 @@ import torch import torch.nn as nn -from monai.utils import optional_import - -einops, _ = optional_import("einops") - class SABlock(nn.Module): """ @@ -49,11 +45,15 @@ def __init__(self, hidden_size: int, num_heads: int, dropout_rate: float = 0.0) self.scale = self.head_dim ** -0.5 def forward(self, x): - q, k, v = einops.rearrange(self.qkv(x), "b h (qkv l d) -> qkv b l h d", qkv=3, l=self.num_heads) + b, h, in_feats = x.shape + l = self.num_heads + d = in_feats // l + x = self.qkv(x).reshape([b, h, 3, l, d]).permute(2, 0, 3, 1, 4) + q, k, v = x[0], x[1], x[2] att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1) att_mat = self.drop_weights(att_mat) x = torch.einsum("bhxy,bhyd->bhxd", att_mat, v) - x = einops.rearrange(x, "b h l d -> b l (h d)") + x = x.permute([0, 2, 1, 3]).reshape([b, h, l * d]) x = self.out_proj(x) x = self.drop_output(x) return x diff --git a/monai/networks/nets/vit.py b/monai/networks/nets/vit.py index 2707e5ad1d..8e499dd56a 100644 --- a/monai/networks/nets/vit.py +++ b/monai/networks/nets/vit.py @@ -99,7 +99,7 @@ def __init__( def forward(self, x): x = self.patch_embedding(x) - if self.classification: + if hasattr(self, "cls_token"): cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1) hidden_states_out = [] @@ -107,6 +107,6 @@ def forward(self, x): x = blk(x) hidden_states_out.append(x) x = self.norm(x) - if self.classification: + if hasattr(self, "classification_head"): x = self.classification_head(x[:, 0]) return x, hidden_states_out diff --git a/requirements-dev.txt b/requirements-dev.txt index f47eb14bbd..b52bd7c5e8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -37,7 +37,6 @@ imagecodecs; platform_system == "Linux" tifffile; platform_system == "Linux" pandas requests -einops transformers mlflow matplotlib!=3.5.0 diff --git a/setup.cfg b/setup.cfg index c58e683b12..d037834192 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,6 @@ all = tifffile imagecodecs pandas - einops transformers mlflow matplotlib @@ -82,8 +81,6 @@ imagecodecs = imagecodecs pandas = pandas -einops = - einops transformers = transformers mlflow = diff --git a/tests/test_patchembedding.py b/tests/test_patchembedding.py index 6c9ac78a99..b3ef0723fc 100644 --- a/tests/test_patchembedding.py +++ b/tests/test_patchembedding.py @@ -10,16 +10,12 @@ # limitations under the License. import unittest -from unittest import skipUnless import torch from parameterized import parameterized from monai.networks import eval_mode from monai.networks.blocks.patchembedding import PatchEmbeddingBlock -from monai.utils import optional_import - -einops, has_einops = optional_import("einops") TEST_CASE_PATCHEMBEDDINGBLOCK = [] for dropout_rate in (0.5,): @@ -51,7 +47,6 @@ class TestPatchEmbeddingBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_PATCHEMBEDDINGBLOCK) - @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = PatchEmbeddingBlock(**input_param) with eval_mode(net): diff --git a/tests/test_selfattention.py b/tests/test_selfattention.py index 559e86487b..e5bd80cb7d 100644 --- a/tests/test_selfattention.py +++ b/tests/test_selfattention.py @@ -10,7 +10,6 @@ # limitations under the License. import unittest -from unittest import skipUnless import numpy as np import torch @@ -18,9 +17,6 @@ from monai.networks import eval_mode from monai.networks.blocks.selfattention import SABlock -from monai.utils import optional_import - -einops, has_einops = optional_import("einops") TEST_CASE_SABLOCK = [] for dropout_rate in np.linspace(0, 1, 4): @@ -37,7 +33,6 @@ class TestResBlock(unittest.TestCase): @parameterized.expand(TEST_CASE_SABLOCK) - @skipUnless(has_einops, "Requires einops") def test_shape(self, input_param, input_shape, expected_shape): net = SABlock(**input_param) with eval_mode(net): diff --git a/tests/test_vit.py b/tests/test_vit.py index cdf0888222..07af1b61f8 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -16,6 +16,7 @@ from monai.networks import eval_mode from monai.networks.nets.vit import ViT +from tests.utils import test_script_save TEST_CASE_Vit = [] for dropout_rate in [0.6]: @@ -133,6 +134,18 @@ def test_ill_arg(self): dropout_rate=0.3, ) + @parameterized.expand(TEST_CASE_Vit) + def test_script(self, input_param, input_shape, _): + net = ViT(**(input_param)) + net.eval() + with torch.no_grad(): + torch.jit.script(net) + + input_param_ = dict(input_param) + net = ViT(**(input_param_)) + test_data = torch.randn(input_shape) + test_script_save(net, test_data) + if __name__ == "__main__": unittest.main()