diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 2cd2961a3a..426f9856fe 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -676,6 +676,8 @@ def __init__( progress: bool = True, copy_cache: bool = True, as_contiguous: bool = True, + hash_as_key: bool = False, + hash_func: Callable[..., bytes] = pickle_hashing, ) -> None: """ Args: @@ -695,19 +697,29 @@ def __init__( may set `copy=False` for better performance. as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous. it may help improve the performance of following logic. + hash_as_key: whether to compute hash value of input data as the key to save cache, + if key exists, avoid saving duplicated content. it can help save memory when + the dataset has duplicated items or augmented dataset. + hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached. + defaults to `monai.data.utils.pickle_hashing`. """ if not isinstance(transform, Compose): transform = Compose(transform) super().__init__(data=data, transform=transform) + self.set_num = cache_num # tracking the user-provided `cache_num` option + self.set_rate = cache_rate # tracking the user-provided `cache_rate` option self.progress = progress self.copy_cache = copy_cache self.as_contiguous = as_contiguous - self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) + self.hash_as_key = hash_as_key + self.hash_func = hash_func self.num_workers = num_workers if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) - self._cache: List = self._fill_cache() + self.cache_num = 0 + self._cache: Union[List, Dict] = [] + self.set_data(data) def set_data(self, data: Sequence): """ @@ -718,8 +730,21 @@ def set_data(self, data: Sequence): generated cache content. """ - self.data = data - self._cache = self._fill_cache() + + def _compute_cache(): + self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) + return self._fill_cache() + + if self.hash_as_key: + # only compute cache for the unique items of dataset + mapping = {self.hash_func(v): v for v in data} + self.data = list(mapping.values()) + cache_ = _compute_cache() + self._cache = dict(zip(list(mapping)[: self.cache_num], cache_)) + self.data = data + else: + self.data = data + self._cache = _compute_cache() def _fill_cache(self) -> List: if self.cache_num <= 0: @@ -754,14 +779,21 @@ def _load_cache_item(self, idx: int): return item def _transform(self, index: int): - if index % len(self) >= self.cache_num: # support negative index + index_: Any = index + if self.hash_as_key: + key = self.hash_func(self.data[index]) + if key in self._cache: + # if existing in cache, get the index + index_ = key # if using hash as cache keys, set the key + + if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index # no cache for this index, execute all the transforms directly - return super()._transform(index) + return super()._transform(index_) # load data from cache and execute from the first random transform start_run = False if self._cache is None: self._cache = self._fill_cache() - data = self._cache[index] + data = self._cache[index_] if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: @@ -862,10 +894,14 @@ def __init__( ) -> None: if shuffle: self.set_random_state(seed=seed) - data = copy(data) - self.randomize(data) self.shuffle = shuffle + self._start_pos: int = 0 + self._update_lock: threading.Lock = threading.Lock() + self._round: int = 1 + self._replace_done: bool = False + self._replace_mgr: Optional[threading.Thread] = None + super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous) if self._cache is None: self._cache = self._fill_cache() @@ -884,13 +920,6 @@ def __init__( self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) self._replacements: List[Any] = [None for _ in range(self._replace_num)] self._replace_data_idx: List[int] = list(range(self._replace_num)) - - self._start_pos: int = 0 - self._update_lock: threading.Lock = threading.Lock() - self._round: int = 1 - self._replace_done: bool = False - self._replace_mgr: Optional[threading.Thread] = None - self._compute_data_idx() def set_data(self, data: Sequence): diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index a742f5889a..7227f53e04 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -42,24 +42,12 @@ class TestCacheDataset(unittest.TestCase): def test_shape(self, transform, expected_shape): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: - nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz")) - nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz")) - test_data = [ - { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), - }, - { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), - }, - ] + test_data = [] + for i in ["1", "2"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True) data1 = dataset[0] data2 = dataset[1] @@ -68,9 +56,9 @@ def test_shape(self, transform, expected_shape): self.assertEqual(len(data3), 1) if transform is None: - self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data4["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz")) else: self.assertTupleEqual(data1["image"].shape, expected_shape) self.assertTupleEqual(data1["label"].shape, expected_shape) @@ -195,6 +183,46 @@ def test_thread_safe(self, persistent_workers, cache_workers, loader_workers): self.assertListEqual(expected, [y.item() for y in loader]) self.assertListEqual(expected, [y.item() for y in loader]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_hash_as_key(self, transform, expected_shape): + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + test_data = [] + for i in ["1", "2", "2", "3", "3"]: + for k in ["image", "label", "extra"]: + nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz")) + test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]}) + + dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True) + self.assertEqual(len(dataset), 5) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 3) + self.assertEqual(dataset.cache_num, 3) + data1 = dataset[0] + data2 = dataset[1] + data3 = dataset[-1] + # test slice indices + data4 = dataset[0:-1] + self.assertEqual(len(data4), 4) + + if transform is None: + self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz")) + self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz")) + self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz")) + else: + self.assertTupleEqual(data1["image"].shape, expected_shape) + self.assertTupleEqual(data2["label"].shape, expected_shape) + self.assertTupleEqual(data3["image"].shape, expected_shape) + for d in data4: + self.assertTupleEqual(d["image"].shape, expected_shape) + + test_data2 = test_data[:3] + dataset.set_data(data=test_data2) + self.assertEqual(len(dataset), 3) + # ensure no duplicated cache content + self.assertEqual(len(dataset._cache), 2) + self.assertEqual(dataset.cache_num, 2) + if __name__ == "__main__": unittest.main()