Skip to content
Merged
8 changes: 8 additions & 0 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class MedNISTDataset(Randomizable, CacheDataset):
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand All @@ -83,6 +85,7 @@ def __init__(
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand Down Expand Up @@ -120,6 +123,7 @@ def __init__(
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
as_contiguous=as_contiguous,
)

def randomize(self, data: np.ndarray) -> None:
Expand Down Expand Up @@ -205,6 +209,8 @@ class DecathlonDataset(Randomizable, CacheDataset):
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand Down Expand Up @@ -271,6 +277,7 @@ def __init__(
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand Down Expand Up @@ -322,6 +329,7 @@ def __init__(
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
as_contiguous=as_contiguous,
)

def get_indices(self) -> np.ndarray:
Expand Down
4 changes: 4 additions & 0 deletions monai/apps/pathology/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ class SmartCachePatchWSIDataset(SmartCacheDataset):
default to `True`. if the random transforms don't modify the cache content
or every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.

"""

Expand All @@ -144,6 +146,7 @@ def __init__(
num_replace_workers: Optional[int] = None,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
):
patch_wsi_dataset = PatchWSIDataset(
data=data,
Expand All @@ -163,6 +166,7 @@ def __init__(
progress=progress,
shuffle=False,
copy_cache=copy_cache,
as_contiguous=as_contiguous,
)


Expand Down
15 changes: 13 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch.utils.data import Subset

from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform, convert_to_contiguous
from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

Expand Down Expand Up @@ -671,6 +671,7 @@ def __init__(
num_workers: Optional[int] = None,
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
) -> None:
"""
Args:
Expand All @@ -688,12 +689,16 @@ def __init__(
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
self.num_workers = num_workers
if self.num_workers is not None:
Expand Down Expand Up @@ -740,6 +745,8 @@ def _load_cache_item(self, idx: int):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item = apply_transform(_xform, item)
if self.as_contiguous:
item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
return item

def _transform(self, index: int):
Expand Down Expand Up @@ -829,6 +836,9 @@ class SmartCacheDataset(Randomizable, CacheDataset):
default to `True`. if the random transforms don't modify the cache content
or every cache item is only used once in a `multi-processing` environment,
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.

"""

def __init__(
Expand All @@ -844,14 +854,15 @@ def __init__(
shuffle: bool = True,
seed: int = 0,
copy_cache: bool = True,
as_contiguous: bool = True,
) -> None:
if shuffle:
self.set_random_state(seed=seed)
data = copy(data)
self.randomize(data)
self.shuffle = shuffle

super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache)
super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous)
if self._cache is None:
self._cache = self._fill_cache()
if self.cache_num >= len(data):
Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@
compute_divisible_spatial_size,
convert_inverse_interp_mode,
convert_pad_mode,
convert_to_contiguous,
copypaste_arrays,
create_control_grid,
create_grid,
Expand Down Expand Up @@ -552,6 +553,7 @@
)
from .utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
clip,
concatenate,
cumsum,
Expand Down
21 changes: 21 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from monai.transforms.transform import MapTransform, Transform, apply_transform
from monai.transforms.utils_pytorch_numpy_unification import (
any_np_pt,
ascontiguousarray,
cumsum,
isfinite,
nonzero,
Expand Down Expand Up @@ -98,6 +99,7 @@
"get_transform_backends",
"print_transform_backends",
"convert_pad_mode",
"convert_to_contiguous",
]


Expand Down Expand Up @@ -1496,5 +1498,24 @@ def convert_pad_mode(dst: NdarrayOrTensor, mode: Union[NumpyPadMode, PytorchPadM
raise ValueError(f"unsupported data type: {type(dst)}.")


def convert_to_contiguous(data, **kwargs):
"""
Check and ensure the numpy array or PyTorch Tensor in data to be contuguous in memory.

Args:
data: input data to convert, will recursively convert the numpy array or PyTorch Tensor in dict and sequence.
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.

"""
if isinstance(data, (np.ndarray, torch.Tensor)):
return ascontiguousarray(data, **kwargs)
if isinstance(data, dict):
return {k: convert_to_contiguous(v, **kwargs) for k, v in data.items()}
if isinstance(data, (list, tuple)):
return [convert_to_contiguous(i, **kwargs) for i in data]
return data


if __name__ == "__main__":
print_transform_backends()
15 changes: 15 additions & 0 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"searchsorted",
"repeat",
"isnan",
"ascontiguousarray",
]


Expand Down Expand Up @@ -321,3 +322,17 @@ def isnan(x: NdarrayOrTensor):
if isinstance(x, np.ndarray):
return np.isnan(x)
return torch.isnan(x)


def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
"""`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).

Args:
x: array/tensor
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.

"""
if isinstance(x, np.ndarray):
return np.ascontiguousarray(x)
return x.contiguous(**kwargs)
Binary file added testdata.nrrd
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_cachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_shape(self, transform, expected_shape):
"extra": os.path.join(tempdir, "test_extra2.nii.gz"),
},
]
dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5)
dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True)
data1 = dataset[0]
data2 = dataset[1]
data3 = dataset[0:-1]
Expand Down