diff --git a/monai/data/dataset.py b/monai/data/dataset.py index bb5a98ba1e..c032e65af6 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -52,16 +52,15 @@ class Dataset(_TorchDataset): }, }, }] """ - def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None: + def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None: """ Args: data: input data to load and transform to generate dataset for model. transform: a callable data transform on input data. - progress: whether to display a progress bar. + """ self.data = data self.transform = transform - self.progress = progress def __len__(self) -> int: return len(self.data) @@ -118,7 +117,6 @@ def __init__( transform: Union[Sequence[Callable], Callable], cache_dir: Optional[Union[Path, str]] = None, hash_func: Callable[..., bytes] = pickle_hashing, - progress: bool = True, ) -> None: """ Args: @@ -133,11 +131,11 @@ def __init__( If the cache_dir doesn't exist, will automatically create it. hash_func: a callable to compute hash from data items to be cached. defaults to `monai.data.utils.pickle_hashing`. - progress: whether to display a progress bar. + """ if not isinstance(transform, Compose): transform = Compose(transform) - super().__init__(data=data, transform=transform, progress=progress) + super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None self.hash_func = hash_func if self.cache_dir is not None: @@ -350,7 +348,8 @@ def __init__( lmdb_kwargs: additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class """ - super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress) + super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func) + self.progress = progress if not self.cache_dir: raise ValueError("cache_dir must be specified.") self.db_file = self.cache_dir / f"{db_name}.lmdb" @@ -490,7 +489,8 @@ def __init__( """ if not isinstance(transform, Compose): transform = Compose(transform) - super().__init__(data=data, transform=transform, progress=progress) + super().__init__(data=data, transform=transform) + self.progress = progress self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data)) self.num_workers = num_workers if self.num_workers is not None: @@ -591,7 +591,7 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_init_workers: Optional[int] = None, - num_replace_workers: int = 0, + num_replace_workers: Optional[int] = None, progress: bool = True, ) -> None: """ @@ -606,8 +606,8 @@ def __init__( num_init_workers: the number of worker threads to initialize the cache for first epoch. If num_init_workers is None then the number returned by os.cpu_count() is used. num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. - if 0, run in main thread, no separate thread will open. - progress: whether to display a progress bar. + If num_replace_workers is None then the number returned by os.cpu_count() is used. + progress: whether to display a progress bar when caching for the first epoch. """ super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress) @@ -617,7 +617,10 @@ def __init__( warnings.warn("cache_num is greater or equal than dataset length, fall back to regular CacheDataset.") if replace_rate <= 0: raise ValueError("replace_rate must be greater than 0, otherwise, please use CacheDataset.") - self.num_replace_workers: int = num_replace_workers + + self.num_replace_workers: Optional[int] = num_replace_workers + if self.num_replace_workers is not None: + self.num_replace_workers = max(int(self.num_replace_workers), 1) self._total_num: int = len(data) self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num) @@ -747,12 +750,9 @@ def _compute_replacements(self): It can support multi-threads to accelerate the computation progress. """ - if self.num_replace_workers > 0: - with ThreadPool(self.num_replace_workers) as p: - p.map(self._replace_cache_thread, list(range(self._replace_num))) - else: - for i in range(self._replace_num): - self._replace_cache_thread(i) + with ThreadPool(self.num_replace_workers) as p: + p.map(self._replace_cache_thread, list(range(self._replace_num))) + self._replace_done = True def _try_manage_replacement(self, check_round): diff --git a/tests/test_smartcachedataset.py b/tests/test_smartcachedataset.py index 3d1a051a83..7ebb2858d2 100644 --- a/tests/test_smartcachedataset.py +++ b/tests/test_smartcachedataset.py @@ -24,13 +24,15 @@ TEST_CASE_2 = [0.1, 4, Compose([LoadImaged(keys=["image", "label", "extra"])])] -TEST_CASE_3 = [0.1, 4, None] +TEST_CASE_3 = [0.1, None, Compose([LoadImaged(keys=["image", "label", "extra"])])] -TEST_CASE_4 = [0.5, 2, Compose([LoadImaged(keys=["image", "label", "extra"])])] +TEST_CASE_4 = [0.1, 4, None] + +TEST_CASE_5 = [0.5, 2, Compose([LoadImaged(keys=["image", "label", "extra"])])] class TestSmartCacheDataset(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) def test_shape(self, replace_rate, num_replace_workers, transform): test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]), np.eye(4)) with tempfile.TemporaryDirectory() as tempdir: