From 64faa84b53c4b737a147ce2344bb25e03e2bf8b1 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 24 Jun 2021 16:58:10 +0800 Subject: [PATCH 1/8] [DLMED] add dynamic data list in CacheDataset Signed-off-by: Nic Ma --- monai/data/dataset.py | 26 ++++++++++++++++++ tests/test_cachedataset.py | 27 ++++++++++++++++++- tests/test_smartcachedataset.py | 48 +++++++++++++++++++++++++++++++-- 3 files changed, 98 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 74b9726081..b4e64f05be 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -555,6 +555,18 @@ def __init__( self.num_workers = max(int(self.num_workers), 1) self._cache: List = self._fill_cache() + def update_data(self, data: Sequence): + """ + Update the input data and re-run deterministic transforms to generate cache content. + + Note: should call this func after an entire epoch and must set `persisten_workers=False` + in PyTorch DataLoader, because it needs to create new worker processes based on new + generated cache content. + + """ + self.data = data + self._cache = self._fill_cache() + def _fill_cache(self) -> List: if self.cache_num <= 0: return [] @@ -709,6 +721,18 @@ def __init__( self._compute_data_idx() + def update_data(self, data: Sequence): + """ + Update the input data and re-run deterministic transforms to generate cache content. + + Note: should call `shutdown()` before calling this func. + + """ + if self.is_started(): + warnings.warn("SmartCacheDataset is not shutdown yet, shutdown it directly.") + self.shutdown() + return super().update_data(data) + def randomize(self, data: Sequence) -> None: try: self.R.shuffle(data) @@ -796,6 +820,8 @@ def _try_shutdown(self): with self._update_lock: if self._replace_done: self._round = 0 + self._start_pos = 0 + self._compute_data_idx() self._replace_done = False return True return False diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 34524ba5e8..055e07884c 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.data import CacheDataset, DataLoader, PersistentDataset, SmartCacheDataset -from monai.transforms import Compose, LoadImaged, ThreadUnsafe, Transform +from monai.transforms import Compose, Lambda, LoadImaged, ThreadUnsafe, Transform from monai.utils import get_torch_version_tuple TEST_CASE_1 = [Compose([LoadImaged(keys=["image", "label", "extra"])]), (128, 128, 128)] @@ -81,6 +81,31 @@ def test_shape(self, transform, expected_shape): for d in data3: self.assertTupleEqual(d["image"].shape, expected_shape) + def test_update_data(self): + data_list1 = list(range(10)) + + transform = Lambda(func=lambda x: np.array([x * 10])) + + dataset = CacheDataset( + data=data_list1, + transform=transform, + cache_rate=1.0, + num_workers=4, + progress=True, + ) + + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1, persistent_workers=False) + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10]], d) + + # update the datalist and fill the cache content + data_list2 = list(range(-10, 0)) + dataset.update_data(data=data_list2) + # rerun with updated cache content + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list2[i] * 10]], d) + class _StatefulTransform(Transform, ThreadUnsafe): """ diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 1499854c56..cd2daa7c9c 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -11,6 +11,7 @@ import copy import os +import sys import tempfile import unittest @@ -18,8 +19,8 @@ import numpy as np from parameterized import parameterized -from monai.data import SmartCacheDataset -from monai.transforms import Compose, LoadImaged +from monai.data import SmartCacheDataset, DataLoader +from monai.transforms import Compose, Lambda, LoadImaged TEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=["image", "label", "extra"])])] @@ -126,6 +127,49 @@ def test_shuffle(self): dataset.shutdown() + def test_update_data(self): + data_list1 = list(range(10)) + + transform = Lambda(func=lambda x: np.array([x * 10])) + + dataset = SmartCacheDataset( + data=data_list1, + transform=transform, + cache_rate=0.5, + replace_rate=0.4, + num_init_workers=4, + num_replace_workers=2, + shuffle=False, + progress=True, + ) + + num_workers = 2 if sys.platform == "linux" else 0 + dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1, persistent_workers=False) + + dataset.start() + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i] * 10]], d) + # replace cache content, move forward 2(5 * 0.4) items + dataset.update_cache() + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list1[i + 2] * 10]], d) + # shutdown to update data + dataset.shutdown() + # update the datalist and fill the cache content + data_list2 = list(range(-10, 0)) + dataset.update_data(data=data_list2) + # restart the dataset + dataset.start() + # rerun with updated cache content + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list2[i] * 10]], d) + # replace cache content, move forward 2(5 * 0.4) items + dataset.update_cache() + for i, d in enumerate(dataloader): + np.testing.assert_allclose([[data_list2[i + 2] * 10]], d) + # finally shutdown the dataset + dataset.shutdown() + if __name__ == "__main__": unittest.main() From 1855d65e2b624dcd2d7fa7199ef72268dbdaf4d3 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 24 Jun 2021 17:28:30 +0800 Subject: [PATCH 2/8] [DLMED] add support to PersistentDataset Signed-off-by: Nic Ma --- monai/data/dataset.py | 20 ++++++++++++++++++++ tests/test_persistentdataset.py | 20 ++++++++++++++++++++ tests/test_smartcachedataset.py | 2 +- 3 files changed, 41 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index b4e64f05be..785388f2c6 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -130,6 +130,9 @@ class PersistentDataset(Dataset): Subsequent uses of a dataset directly read pre-processed results from `cache_dir` followed by applying the random dependant parts of transform processing. + If want to update the input data and recompute cache content during training, just call `update_data()` + after several epochs. + Note: The input data must be a list of file paths and will hash them as cache keys. @@ -173,6 +176,17 @@ def __init__( if not self.cache_dir.is_dir(): raise ValueError("cache_dir must be a directory.") + def update_data(self, data: Sequence): + """ + Update the input data and delete all the out-dated cache content. + + """ + self.data = data + + if self.cache_dir.exists(): + shutil.rmtree(self.cache_dir) + self.cache_dir.mkdir(parents=True, exist_ok=True) + def _pre_transform(self, item_transformed): """ Process the data from original state up to the first random element. @@ -515,6 +529,9 @@ class CacheDataset(Dataset): ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and the outcome not cached. + If want to update the input data and recompute cache content during training, just call `update_data()` + after several epochs, note that it requires `persistent_workers=False` in the PyTorch DataLoader. + Note: `CacheDataset` executes non-random transforms and prepares cache content in the main process before the first epoch, then all the subprocesses of DataLoader will read the same cache content in the main process @@ -651,6 +668,9 @@ class SmartCacheDataset(Randomizable, CacheDataset): 3. Call `update_cache()` before every epoch to replace training items. 4. Call `shutdown()` when training ends. + If want to update the input data and recompute cache content during training, just call + `shutdown()` and `update_data()` to stop and update, then call `start()` to restart. + Note: This replacement will not work for below cases: 1. Set the `multiprocessing_context` of DataLoader to `spawn`. diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 09488b1214..5577c742ff 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -123,6 +123,26 @@ def test_shape(self, transform, expected_shape): for d in data3_postcached: self.assertTupleEqual(d["image"].shape, expected_shape) + # update the data to cache + test_data_ex = [ + { + "image": os.path.join(tempdir, "test_image2.nii.gz"), + "label": os.path.join(tempdir, "test_label2.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image1.nii.gz"), + "label": os.path.join(tempdir, "test_label1.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + }, + ] + dataset_postcached.update_data(data=test_data_ex) + # test new exchanged cache content + if transform is None: + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image2.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra1.nii.gz")) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index cd2daa7c9c..0c587c9248 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -19,7 +19,7 @@ import numpy as np from parameterized import parameterized -from monai.data import SmartCacheDataset, DataLoader +from monai.data import DataLoader, SmartCacheDataset from monai.transforms import Compose, Lambda, LoadImaged TEST_CASE_1 = [0.1, 0, Compose([LoadImaged(keys=["image", "label", "extra"])])] From 96dae50bf5fd0afbccd9bf07caed2c74cde9ac9d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 24 Jun 2021 17:56:23 +0800 Subject: [PATCH 3/8] [DLMED] fix flake8 issue Signed-off-by: Nic Ma --- monai/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 785388f2c6..9c5a153850 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -183,7 +183,7 @@ def update_data(self, data: Sequence): """ self.data = data - if self.cache_dir.exists(): + if self.cache_dir is not None and self.cache_dir.exists(): shutil.rmtree(self.cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) From 22e9bc3d4afe434ce49f7416fa5af62dc4546c17 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 24 Jun 2021 19:34:45 +0800 Subject: [PATCH 4/8] [DLMED] fix CI tests Signed-off-by: Nic Ma --- tests/test_cachedataset.py | 2 +- tests/test_smartcachedataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 055e07884c..0c256aacb0 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -95,7 +95,7 @@ def test_update_data(self): ) num_workers = 2 if sys.platform == "linux" else 0 - dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1, persistent_workers=False) + dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) for i, d in enumerate(dataloader): np.testing.assert_allclose([[data_list1[i] * 10]], d) diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 0c587c9248..4aaa6c7301 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -144,7 +144,7 @@ def test_update_data(self): ) num_workers = 2 if sys.platform == "linux" else 0 - dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1, persistent_workers=False) + dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=1) dataset.start() for i, d in enumerate(dataloader): From ad88557fb138f470ad2225208fbb6390bb0794cc Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 25 Jun 2021 15:57:29 +0800 Subject: [PATCH 5/8] [DLMED] enhance dataset Signed-off-by: Nic Ma --- monai/data/dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 7ae841bf8e..55a1d9e18c 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -715,6 +715,7 @@ def __init__( self.set_random_state(seed=seed) data = copy(data) self.randomize(data) + self.shuffle = shuffle super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) if self._cache is None: @@ -753,6 +754,10 @@ def update_data(self, data: Sequence): if self.is_started(): warnings.warn("SmartCacheDataset is not shutdown yet, shutdown it directly.") self.shutdown() + + if self.shuffle: + data = copy(data) + self.randomize(data) return super().update_data(data) def randomize(self, data: Sequence) -> None: From d95414568f36a478ffc04f883a3535dd7197f140 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 25 Jun 2021 21:02:38 +0800 Subject: [PATCH 6/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/dataset.py | 26 ++++++++++----- tests/test_cachedataset.py | 4 +-- tests/test_lmdbdataset.py | 58 ++++++++++++++++++++++----------- tests/test_persistentdataset.py | 22 ++++++------- tests/test_smartcachedataset.py | 4 +-- 5 files changed, 71 insertions(+), 43 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 55a1d9e18c..4a801f2e20 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -176,15 +176,14 @@ def __init__( if not self.cache_dir.is_dir(): raise ValueError("cache_dir must be a directory.") - def update_data(self, data: Sequence): + def set_data(self, data: Sequence): """ - Update the input data and delete all the out-dated cache content. + Set the input data and delete all the out-dated cache content. """ self.data = data - if self.cache_dir is not None and self.cache_dir.exists(): - shutil.rmtree(self.cache_dir) + shutil.rmtree(self.cache_dir, ignore_errors=True) self.cache_dir.mkdir(parents=True, exist_ok=True) def _pre_transform(self, item_transformed): @@ -418,6 +417,14 @@ def __init__( self._read_env = None print(f"Accessing lmdb file: {self.db_file.absolute()}.") + def set_data(self, data: Sequence): + """ + Set the input data and delete all the out-dated cache content. + + """ + super().set_data(data=data) + self._read_env = None + def _fill_cache_start_reader(self): # create cache self.lmdb_kwargs["readonly"] = False @@ -472,6 +479,7 @@ def _cachecheck(self, item_transformed): if self._read_env is None: self._read_env = self._fill_cache_start_reader() with self._read_env.begin(write=False) as txn: + print("!!!!!!!!", item_transformed) data = txn.get(self.hash_func(item_transformed)) if data is None: warnings.warn("LMDBDataset: cache key not found, running fallback caching.") @@ -572,9 +580,9 @@ def __init__( self.num_workers = max(int(self.num_workers), 1) self._cache: List = self._fill_cache() - def update_data(self, data: Sequence): + def set_data(self, data: Sequence): """ - Update the input data and re-run deterministic transforms to generate cache content. + Set the input data and run deterministic transforms to generate cache content. Note: should call this func after an entire epoch and must set `persisten_workers=False` in PyTorch DataLoader, because it needs to create new worker processes based on new @@ -744,9 +752,9 @@ def __init__( self._compute_data_idx() - def update_data(self, data: Sequence): + def set_data(self, data: Sequence): """ - Update the input data and re-run deterministic transforms to generate cache content. + Set the input data and run deterministic transforms to generate cache content. Note: should call `shutdown()` before calling this func. @@ -758,7 +766,7 @@ def update_data(self, data: Sequence): if self.shuffle: data = copy(data) self.randomize(data) - return super().update_data(data) + super().set_data(data) def randomize(self, data: Sequence) -> None: try: diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index 0c256aacb0..bbb8143631 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -81,7 +81,7 @@ def test_shape(self, transform, expected_shape): for d in data3: self.assertTupleEqual(d["image"].shape, expected_shape) - def test_update_data(self): + def test_set_data(self): data_list1 = list(range(10)) transform = Lambda(func=lambda x: np.array([x * 10])) @@ -101,7 +101,7 @@ def test_update_data(self): # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) - dataset.update_data(data=data_list2) + dataset.set_data(data=data_list2) # rerun with updated cache content for i, d in enumerate(dataloader): np.testing.assert_allclose([[data_list2[i] * 10]], d) diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 7ae8e57e7a..3e3aed709f 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -158,25 +158,45 @@ def test_shape(self, transform, expected_shape, kwargs=None): data1_postcached = dataset_postcached[0] data2_postcached = dataset_postcached[1] - if transform is None: - self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) - self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) - else: - self.assertTupleEqual(data1_precached["image"].shape, expected_shape) - self.assertTupleEqual(data1_precached["label"].shape, expected_shape) - self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) - self.assertTupleEqual(data2_precached["image"].shape, expected_shape) - self.assertTupleEqual(data2_precached["label"].shape, expected_shape) - self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) - - self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) - self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) - self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) - self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) - self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) - self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + if transform is None: + self.assertEqual(data1_precached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_precached["label"], os.path.join(tempdir, "test_label2.nii.gz")) + self.assertEqual(data1_postcached["image"], os.path.join(tempdir, "test_image1.nii.gz")) + self.assertEqual(data2_postcached["extra"], os.path.join(tempdir, "test_extra2.nii.gz")) + else: + self.assertTupleEqual(data1_precached["image"].shape, expected_shape) + self.assertTupleEqual(data1_precached["label"].shape, expected_shape) + self.assertTupleEqual(data1_precached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_precached["image"].shape, expected_shape) + self.assertTupleEqual(data2_precached["label"].shape, expected_shape) + self.assertTupleEqual(data2_precached["extra"].shape, expected_shape) + + self.assertTupleEqual(data1_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data1_postcached["extra"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["image"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["label"].shape, expected_shape) + self.assertTupleEqual(data2_postcached["extra"].shape, expected_shape) + + # update the data to cache + test_data_new = [ + { + "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), + }, + { + "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), + }, + ] + dataset_postcached.set_data(data=test_data_new) + # test new exchanged cache content + if transform is None: + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) @skip_if_windows diff --git a/tests/test_persistentdataset.py b/tests/test_persistentdataset.py index 5577c742ff..8446f566ef 100644 --- a/tests/test_persistentdataset.py +++ b/tests/test_persistentdataset.py @@ -124,24 +124,24 @@ def test_shape(self, transform, expected_shape): self.assertTupleEqual(d["image"].shape, expected_shape) # update the data to cache - test_data_ex = [ + test_data_new = [ { - "image": os.path.join(tempdir, "test_image2.nii.gz"), - "label": os.path.join(tempdir, "test_label2.nii.gz"), - "extra": os.path.join(tempdir, "test_extra2.nii.gz"), + "image": os.path.join(tempdir, "test_image1_new.nii.gz"), + "label": os.path.join(tempdir, "test_label1_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra1_new.nii.gz"), }, { - "image": os.path.join(tempdir, "test_image1.nii.gz"), - "label": os.path.join(tempdir, "test_label1.nii.gz"), - "extra": os.path.join(tempdir, "test_extra1.nii.gz"), + "image": os.path.join(tempdir, "test_image2_new.nii.gz"), + "label": os.path.join(tempdir, "test_label2_new.nii.gz"), + "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), }, ] - dataset_postcached.update_data(data=test_data_ex) + dataset_postcached.set_data(data=test_data_new) # test new exchanged cache content if transform is None: - self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image2.nii.gz")) - self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label2.nii.gz")) - self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra1.nii.gz")) + self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) + self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) + self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) if __name__ == "__main__": diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index a451c19688..e2675f4d8c 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -127,7 +127,7 @@ def test_shuffle(self): dataset.shutdown() - def test_update_data(self): + def test_set_data(self): data_list1 = list(range(10)) transform = Lambda(func=lambda x: np.array([x * 10])) @@ -157,7 +157,7 @@ def test_update_data(self): dataset.shutdown() # update the datalist and fill the cache content data_list2 = list(range(-10, 0)) - dataset.update_data(data=data_list2) + dataset.set_data(data=data_list2) # restart the dataset dataset.start() # rerun with updated cache content From 320c625f4f96f1324f46e06c6f93aa53e4a7fb7e Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sat, 26 Jun 2021 07:35:23 +0800 Subject: [PATCH 7/8] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/dataset.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 4a801f2e20..f91c8e261b 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -130,8 +130,7 @@ class PersistentDataset(Dataset): Subsequent uses of a dataset directly read pre-processed results from `cache_dir` followed by applying the random dependant parts of transform processing. - If want to update the input data and recompute cache content during training, just call `update_data()` - after several epochs. + During training call update_data() to update input data and recompute cache content. Note: The input data must be a list of file paths and will hash them as cache keys. @@ -479,7 +478,6 @@ def _cachecheck(self, item_transformed): if self._read_env is None: self._read_env = self._fill_cache_start_reader() with self._read_env.begin(write=False) as txn: - print("!!!!!!!!", item_transformed) data = txn.get(self.hash_func(item_transformed)) if data is None: warnings.warn("LMDBDataset: cache key not found, running fallback caching.") @@ -537,8 +535,8 @@ class CacheDataset(Dataset): ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and the outcome not cached. - If want to update the input data and recompute cache content during training, just call `update_data()` - after several epochs, note that it requires `persistent_workers=False` in the PyTorch DataLoader. + During training call update_data() to update input data and recompute cache content, note that it requires + `persistent_workers=False` in the PyTorch DataLoader. Note: `CacheDataset` executes non-random transforms and prepares cache content in the main process before @@ -676,8 +674,8 @@ class SmartCacheDataset(Randomizable, CacheDataset): 3. Call `update_cache()` before every epoch to replace training items. 4. Call `shutdown()` when training ends. - If want to update the input data and recompute cache content during training, just call - `shutdown()` and `update_data()` to stop and update, then call `start()` to restart. + During training call update_data() to update input data and recompute cache content, note to call + `shutdown()` to stop first, then update data and call `start()` to restart. Note: This replacement will not work for below cases: From 90fb35acb9c5e16e4acc8e9c51d875085eab49e4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 29 Jun 2021 00:30:16 +0800 Subject: [PATCH 8/8] [DLMED] update_data -> set_data Signed-off-by: Nic Ma --- monai/data/dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f91c8e261b..67eb00af5c 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -130,7 +130,7 @@ class PersistentDataset(Dataset): Subsequent uses of a dataset directly read pre-processed results from `cache_dir` followed by applying the random dependant parts of transform processing. - During training call update_data() to update input data and recompute cache content. + During training call `set_data()` to update input data and recompute cache content. Note: The input data must be a list of file paths and will hash them as cache keys. @@ -535,7 +535,7 @@ class CacheDataset(Dataset): ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform and the outcome not cached. - During training call update_data() to update input data and recompute cache content, note that it requires + During training call `set_data()` to update input data and recompute cache content, note that it requires `persistent_workers=False` in the PyTorch DataLoader. Note: @@ -674,7 +674,7 @@ class SmartCacheDataset(Randomizable, CacheDataset): 3. Call `update_cache()` before every epoch to replace training items. 4. Call `shutdown()` when training ends. - During training call update_data() to update input data and recompute cache content, note to call + During training call `set_data()` to update input data and recompute cache content, note to call `shutdown()` to stop first, then update data and call `start()` to restart. Note: