diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 22e8bdb610..2c263b3e32 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -789,7 +789,8 @@ def __init__( if self.num_workers is not None: self.num_workers = max(int(self.num_workers), 1) self.cache_num = 0 - self._cache: Union[List, Dict] = [] + self._cache: List = [] + self._hash_keys: List = [] self.set_data(data) def set_data(self, data: Sequence): @@ -801,37 +802,41 @@ def set_data(self, data: Sequence): generated cache content. """ + self.data = data - def _compute_cache(): - self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data)) - return self._fill_cache() + def _compute_cache_num(data_len: int): + self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len) if self.hash_as_key: - # only compute cache for the unique items of dataset - mapping = {self.hash_func(v): v for v in data} - self.data = list(mapping.values()) - cache_ = _compute_cache() - self._cache = dict(zip(list(mapping)[: self.cache_num], cache_)) - self.data = data + # only compute cache for the unique items of dataset, and record the last index for duplicated items + mapping = {self.hash_func(v): i for i, v in enumerate(data)} + _compute_cache_num(len(mapping)) + self._hash_keys = list(mapping)[: self.cache_num] + indices = list(mapping.values())[: self.cache_num] else: - self.data = data - self._cache = _compute_cache() + _compute_cache_num(len(self.data)) + indices = list(range(self.cache_num)) + self._cache = self._fill_cache(indices) + + def _fill_cache(self, indices=None) -> List: + """ + Compute and fill the cache content from data source. - def _fill_cache(self) -> List: + Args: + indices: target indices in the `self.data` source to compute cache. + if None, use the first `cache_num` items. + + """ if self.cache_num <= 0: return [] + if indices is None: + indices = list(range(self.cache_num)) if self.progress and not has_tqdm: warnings.warn("tqdm is not installed, will not show the caching progress bar.") with ThreadPool(self.num_workers) as p: if self.progress and has_tqdm: - return list( - tqdm( - p.imap(self._load_cache_item, range(self.cache_num)), - total=self.cache_num, - desc="Loading dataset", - ) - ) - return list(p.imap(self._load_cache_item, range(self.cache_num))) + return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset")) + return list(p.imap(self._load_cache_item, indices)) def _load_cache_item(self, idx: int): """ @@ -850,21 +855,24 @@ def _load_cache_item(self, idx: int): return item def _transform(self, index: int): - index_: Any = index + cache_index = None if self.hash_as_key: key = self.hash_func(self.data[index]) - if key in self._cache: - # if existing in cache, get the index - index_ = key # if using hash as cache keys, set the key + if key in self._hash_keys: + # if existing in cache, try to get the index in cache + cache_index = self._hash_keys.index(key) + elif index % len(self) < self.cache_num: # support negative index + cache_index = index - if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index + if cache_index is None: # no cache for this index, execute all the transforms directly - return super()._transform(index_) + return super()._transform(index) + + if self._cache is None: + raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.") + data = self._cache[cache_index] # load data from cache and execute from the first random transform start_run = False - if self._cache is None: - self._cache = self._fill_cache() - data = self._cache[index_] if not isinstance(self.transform, Compose): raise ValueError("transform must be an instance of monai.transforms.Compose.") for _transform in self.transform.transforms: diff --git a/tests/test_cachedataset.py b/tests/test_cachedataset.py index e30a34b335..86ebced9f3 100644 --- a/tests/test_cachedataset.py +++ b/tests/test_cachedataset.py @@ -202,6 +202,7 @@ def test_hash_as_key(self, transform, expected_shape): self.assertEqual(len(dataset), 5) # ensure no duplicated cache content self.assertEqual(len(dataset._cache), 3) + self.assertEqual(len(dataset._hash_keys), 3) self.assertEqual(dataset.cache_num, 3) data1 = dataset[0] data2 = dataset[1]