Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
ec463d6
Merge pull request #397 from Project-MONAI/dev
Nic-Ma Apr 4, 2022
ca62306
Merge pull request #429 from Project-MONAI/dev
Nic-Ma Jul 8, 2022
af77f46
Merge pull request #450 from Project-MONAI/dev
Nic-Ma Oct 25, 2022
30d022a
[DLMED] unify cache structure
Nic-Ma Oct 25, 2022
19652e9
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 25, 2022
3b78916
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 26, 2022
3df0f48
[DLMED] update according to comments
Nic-Ma Oct 26, 2022
1f8b351
[DLMED] update according to comments
Nic-Ma Oct 27, 2022
fdb629d
[DLMED] add more test
Nic-Ma Oct 27, 2022
dc98d20
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 28, 2022
1068e5d
[DLMED] update according to comments
Nic-Ma Oct 28, 2022
5d9534e
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 28, 2022
13801ee
[DLMED] update according to comments
Nic-Ma Oct 28, 2022
7fcaa86
Revert "[DLMED] update according to comments"
Nic-Ma Oct 28, 2022
6215ade
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 28, 2022
34bf5ae
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 29, 2022
02abed4
Merge branch 'dev' into 5390-remove-dict-cache
Nic-Ma Oct 31, 2022
82f106f
[DLMED] avoid changing self.data
Nic-Ma Oct 31, 2022
ff5bb25
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2022
9315f68
[MONAI] code formatting
monai-bot Oct 31, 2022
aa162ef
[DLMED] fix typo
Nic-Ma Oct 31, 2022
86cb054
[DLMED] fix doc
Nic-Ma Oct 31, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 38 additions & 30 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,8 @@ def __init__(
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self.cache_num = 0
self._cache: Union[List, Dict] = []
self._cache: List = []
self._hash_keys: List = []
self.set_data(data)

def set_data(self, data: Sequence):
Expand All @@ -801,37 +802,41 @@ def set_data(self, data: Sequence):
generated cache content.

"""
self.data = data

def _compute_cache():
self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data))
return self._fill_cache()
def _compute_cache_num(data_len: int):
self.cache_num = min(int(self.set_num), int(data_len * self.set_rate), data_len)

if self.hash_as_key:
# only compute cache for the unique items of dataset
mapping = {self.hash_func(v): v for v in data}
self.data = list(mapping.values())
cache_ = _compute_cache()
self._cache = dict(zip(list(mapping)[: self.cache_num], cache_))
self.data = data
# only compute the cache for the unique items of the dataset; for duplicated items, record the index of the last occurrence
mapping = {self.hash_func(v): i for i, v in enumerate(data)}
_compute_cache_num(len(mapping))
self._hash_keys = list(mapping)[: self.cache_num]
Comment thread
Nic-Ma marked this conversation as resolved.
indices = list(mapping.values())[: self.cache_num]
else:
self.data = data
self._cache = _compute_cache()
_compute_cache_num(len(self.data))
indices = list(range(self.cache_num))
self._cache = self._fill_cache(indices)

def _fill_cache(self, indices=None) -> List:
"""
Compute and fill the cache content from data source.

def _fill_cache(self) -> List:
Args:
indices: target indices in the `self.data` source for which to compute the cache.
If None, use the first `cache_num` items.

"""
if self.cache_num <= 0:
return []
if indices is None:
indices = list(range(self.cache_num))
if self.progress and not has_tqdm:
warnings.warn("tqdm is not installed, will not show the caching progress bar.")
with ThreadPool(self.num_workers) as p:
if self.progress and has_tqdm:
return list(
tqdm(
p.imap(self._load_cache_item, range(self.cache_num)),
total=self.cache_num,
desc="Loading dataset",
)
)
return list(p.imap(self._load_cache_item, range(self.cache_num)))
return list(tqdm(p.imap(self._load_cache_item, indices), total=len(indices), desc="Loading dataset"))
return list(p.imap(self._load_cache_item, indices))

def _load_cache_item(self, idx: int):
"""
Expand All @@ -850,21 +855,24 @@ def _load_cache_item(self, idx: int):
return item

def _transform(self, index: int):
index_: Any = index
cache_index = None
if self.hash_as_key:
key = self.hash_func(self.data[index])
if key in self._cache:
# if existing in cache, get the index
index_ = key # if using hash as cache keys, set the key
if key in self._hash_keys:
# if the key is already cached, look up its position in the cache
cache_index = self._hash_keys.index(key)
elif index % len(self) < self.cache_num: # support negative index
cache_index = index

if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index
if cache_index is None:
# no cache for this index, execute all the transforms directly
return super()._transform(index_)
return super()._transform(index)
Comment thread
Nic-Ma marked this conversation as resolved.

if self._cache is None:
raise RuntimeError("cache buffer is not initialized, please call `set_data()` first.")
data = self._cache[cache_index]
# load data from cache and execute from the first random transform
start_run = False
if self._cache is None:
self._cache = self._fill_cache()
data = self._cache[index_]
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
Expand Down
1 change: 1 addition & 0 deletions tests/test_cachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def test_hash_as_key(self, transform, expected_shape):
self.assertEqual(len(dataset), 5)
# ensure no duplicated cache content
self.assertEqual(len(dataset._cache), 3)
self.assertEqual(len(dataset._hash_keys), 3)
self.assertEqual(dataset.cache_num, 3)
data1 = dataset[0]
data2 = dataset[1]
Expand Down