Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,15 @@ class Dataset(_TorchDataset):
}, }, }]
"""

def __init__(self, data: Sequence, transform: Optional[Callable] = None, progress: bool = True) -> None:
def __init__(self, data: Sequence, transform: Optional[Callable] = None) -> None:
"""
Args:
data: input data to load and transform to generate dataset for model.
transform: a callable data transform on input data.
progress: whether to display a progress bar.

"""
self.data = data
self.transform = transform
self.progress = progress

def __len__(self) -> int:
return len(self.data)
Expand Down Expand Up @@ -118,7 +117,6 @@ def __init__(
transform: Union[Sequence[Callable], Callable],
cache_dir: Optional[Union[Path, str]] = None,
hash_func: Callable[..., bytes] = pickle_hashing,
progress: bool = True,
) -> None:
"""
Args:
Expand All @@ -133,11 +131,11 @@ def __init__(
If the cache_dir doesn't exist, will automatically create it.
hash_func: a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
progress: whether to display a progress bar.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform, progress=progress)
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
self.hash_func = hash_func
if self.cache_dir is not None:
Expand Down Expand Up @@ -350,7 +348,8 @@ def __init__(
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, progress=progress)
super().__init__(data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func)
self.progress = progress
if not self.cache_dir:
raise ValueError("cache_dir must be specified.")
self.db_file = self.cache_dir / f"{db_name}.lmdb"
Expand Down Expand Up @@ -490,7 +489,8 @@ def __init__(
"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform, progress=progress)
super().__init__(data=data, transform=transform)
self.progress = progress
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
self.num_workers = num_workers
if self.num_workers is not None:
Expand Down Expand Up @@ -591,7 +591,7 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_init_workers: Optional[int] = None,
num_replace_workers: int = 0,
num_replace_workers: Optional[int] = None,
progress: bool = True,
) -> None:
"""
Expand All @@ -606,8 +606,8 @@ def __init__(
num_init_workers: the number of worker threads to initialize the cache for first epoch.
If num_init_workers is None then the number returned by os.cpu_count() is used.
num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
if 0, run in main thread, no separate thread will open.
progress: whether to display a progress bar.
If num_replace_workers is None then the number returned by os.cpu_count() is used.
progress: whether to display a progress bar when caching for the first epoch.

"""
super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress)
Expand All @@ -617,7 +617,10 @@ def __init__(
warnings.warn("cache_num is greater or equal than dataset length, fall back to regular CacheDataset.")
if replace_rate <= 0:
raise ValueError("replace_rate must be greater than 0, otherwise, please use CacheDataset.")
self.num_replace_workers: int = num_replace_workers

self.num_replace_workers: Optional[int] = num_replace_workers
if self.num_replace_workers is not None:
self.num_replace_workers = max(int(self.num_replace_workers), 1)

self._total_num: int = len(data)
self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num)
Expand Down Expand Up @@ -747,12 +750,9 @@ def _compute_replacements(self):
It can support multi-threads to accelerate the computation progress.

"""
if self.num_replace_workers > 0:
with ThreadPool(self.num_replace_workers) as p:
p.map(self._replace_cache_thread, list(range(self._replace_num)))
else:
for i in range(self._replace_num):
self._replace_cache_thread(i)
with ThreadPool(self.num_replace_workers) as p:
p.map(self._replace_cache_thread, list(range(self._replace_num)))

self._replace_done = True

def _try_manage_replacement(self, check_round):
Expand Down
8 changes: 5 additions & 3 deletions tests/test_smartcachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@

TEST_CASE_2 = [0.1, 4, Compose([LoadImaged(keys=["image", "label", "extra"])])]

TEST_CASE_3 = [0.1, 4, None]
TEST_CASE_3 = [0.1, None, Compose([LoadImaged(keys=["image", "label", "extra"])])]

TEST_CASE_4 = [0.5, 2, Compose([LoadImaged(keys=["image", "label", "extra"])])]
TEST_CASE_4 = [0.1, 4, None]

TEST_CASE_5 = [0.5, 2, Compose([LoadImaged(keys=["image", "label", "extra"])])]


class TestSmartCacheDataset(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_shape(self, replace_rate, num_replace_workers, transform):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[8, 8, 8]), np.eye(4))
with tempfile.TemporaryDirectory() as tempdir:
Expand Down