Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,8 @@ def __init__(
progress: bool = True,
copy_cache: bool = True,
as_contiguous: bool = True,
hash_as_key: bool = False,
hash_func: Callable[..., bytes] = pickle_hashing,
) -> None:
"""
Args:
Expand All @@ -695,19 +697,29 @@ def __init__(
may set `copy=False` for better performance.
as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
it may help improve the performance of following logic.
hash_as_key: whether to compute hash value of input data as the key to save cache,
if key exists, avoid saving duplicated content. it can help save memory when
the dataset has duplicated items or augmented dataset.
hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.

"""
if not isinstance(transform, Compose):
transform = Compose(transform)
super().__init__(data=data, transform=transform)
self.set_num = cache_num # tracking the user-provided `cache_num` option
self.set_rate = cache_rate # tracking the user-provided `cache_rate` option
self.progress = progress
self.copy_cache = copy_cache
self.as_contiguous = as_contiguous
self.cache_num = min(int(cache_num), int(len(data) * cache_rate), len(data))
self.hash_as_key = hash_as_key
self.hash_func = hash_func
self.num_workers = num_workers
if self.num_workers is not None:
self.num_workers = max(int(self.num_workers), 1)
self._cache: List = self._fill_cache()
self.cache_num = 0
self._cache: Union[List, Dict] = []
self.set_data(data)

def set_data(self, data: Sequence):
"""
Expand All @@ -718,8 +730,21 @@ def set_data(self, data: Sequence):
generated cache content.

"""
self.data = data
self._cache = self._fill_cache()

def _compute_cache():
self.cache_num = min(int(self.set_num), int(len(self.data) * self.set_rate), len(self.data))
return self._fill_cache()

if self.hash_as_key:
# only compute cache for the unique items of dataset
mapping = {self.hash_func(v): v for v in data}
self.data = list(mapping.values())
cache_ = _compute_cache()
self._cache = dict(zip(list(mapping)[: self.cache_num], cache_))
self.data = data
else:
self.data = data
self._cache = _compute_cache()

def _fill_cache(self) -> List:
if self.cache_num <= 0:
Expand Down Expand Up @@ -754,14 +779,21 @@ def _load_cache_item(self, idx: int):
return item

def _transform(self, index: int):
if index % len(self) >= self.cache_num: # support negative index
index_: Any = index
if self.hash_as_key:
key = self.hash_func(self.data[index])
if key in self._cache:
# if existing in cache, get the index
index_ = key # if using hash as cache keys, set the key

if isinstance(index_, int) and index_ % len(self) >= self.cache_num: # support negative index
# no cache for this index, execute all the transforms directly
return super()._transform(index)
return super()._transform(index_)
# load data from cache and execute from the first random transform
start_run = False
if self._cache is None:
self._cache = self._fill_cache()
data = self._cache[index]
data = self._cache[index_]
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
Expand Down Expand Up @@ -862,10 +894,14 @@ def __init__(
) -> None:
if shuffle:
self.set_random_state(seed=seed)
data = copy(data)
self.randomize(data)
self.shuffle = shuffle

self._start_pos: int = 0
self._update_lock: threading.Lock = threading.Lock()
self._round: int = 1
self._replace_done: bool = False
self._replace_mgr: Optional[threading.Thread] = None

super().__init__(data, transform, cache_num, cache_rate, num_init_workers, progress, copy_cache, as_contiguous)
if self._cache is None:
self._cache = self._fill_cache()
Expand All @@ -884,13 +920,6 @@ def __init__(
self._replace_num: int = min(math.ceil(self.cache_num * replace_rate), len(data) - self.cache_num)
self._replacements: List[Any] = [None for _ in range(self._replace_num)]
self._replace_data_idx: List[int] = list(range(self._replace_num))

self._start_pos: int = 0
self._update_lock: threading.Lock = threading.Lock()
self._round: int = 1
self._replace_done: bool = False
self._replace_mgr: Optional[threading.Thread] = None

self._compute_data_idx()

def set_data(self, data: Sequence):
Expand Down
70 changes: 49 additions & 21 deletions tests/test_cachedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,12 @@ class TestCacheDataset(unittest.TestCase):
def test_shape(self, transform, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
with tempfile.TemporaryDirectory() as tempdir:
nib.save(test_image, os.path.join(tempdir, "test_image1.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_label1.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_extra1.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_image2.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_label2.nii.gz"))
nib.save(test_image, os.path.join(tempdir, "test_extra2.nii.gz"))
test_data = [
{
"image": os.path.join(tempdir, "test_image1.nii.gz"),
"label": os.path.join(tempdir, "test_label1.nii.gz"),
"extra": os.path.join(tempdir, "test_extra1.nii.gz"),
},
{
"image": os.path.join(tempdir, "test_image2.nii.gz"),
"label": os.path.join(tempdir, "test_label2.nii.gz"),
"extra": os.path.join(tempdir, "test_extra2.nii.gz"),
},
]
test_data = []
for i in ["1", "2"]:
for k in ["image", "label", "extra"]:
nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz"))
test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]})

dataset = CacheDataset(data=test_data, transform=transform, cache_rate=0.5, as_contiguous=True)
data1 = dataset[0]
data2 = dataset[1]
Expand All @@ -68,9 +56,9 @@ def test_shape(self, transform, expected_shape):
self.assertEqual(len(data3), 1)

if transform is None:
self.assertEqual(data1["image"], os.path.join(tempdir, "test_image1.nii.gz"))
self.assertEqual(data2["label"], os.path.join(tempdir, "test_label2.nii.gz"))
self.assertEqual(data4["image"], os.path.join(tempdir, "test_image2.nii.gz"))
self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz"))
self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz"))
self.assertEqual(data4["image"], os.path.join(tempdir, "image2.nii.gz"))
else:
self.assertTupleEqual(data1["image"].shape, expected_shape)
self.assertTupleEqual(data1["label"].shape, expected_shape)
Expand Down Expand Up @@ -195,6 +183,46 @@ def test_thread_safe(self, persistent_workers, cache_workers, loader_workers):
self.assertListEqual(expected, [y.item() for y in loader])
self.assertListEqual(expected, [y.item() for y in loader])

@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
def test_hash_as_key(self, transform, expected_shape):
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]), np.eye(4))
with tempfile.TemporaryDirectory() as tempdir:
test_data = []
for i in ["1", "2", "2", "3", "3"]:
for k in ["image", "label", "extra"]:
nib.save(test_image, os.path.join(tempdir, f"{k}{i}.nii.gz"))
test_data.append({k: os.path.join(tempdir, f"{k}{i}.nii.gz") for k in ["image", "label", "extra"]})

dataset = CacheDataset(data=test_data, transform=transform, cache_num=4, num_workers=2, hash_as_key=True)
self.assertEqual(len(dataset), 5)
# ensure no duplicated cache content
self.assertEqual(len(dataset._cache), 3)
self.assertEqual(dataset.cache_num, 3)
data1 = dataset[0]
data2 = dataset[1]
data3 = dataset[-1]
# test slice indices
data4 = dataset[0:-1]
self.assertEqual(len(data4), 4)

if transform is None:
self.assertEqual(data1["image"], os.path.join(tempdir, "image1.nii.gz"))
self.assertEqual(data2["label"], os.path.join(tempdir, "label2.nii.gz"))
self.assertEqual(data3["image"], os.path.join(tempdir, "image3.nii.gz"))
else:
self.assertTupleEqual(data1["image"].shape, expected_shape)
self.assertTupleEqual(data2["label"].shape, expected_shape)
self.assertTupleEqual(data3["image"].shape, expected_shape)
for d in data4:
self.assertTupleEqual(d["image"].shape, expected_shape)

test_data2 = test_data[:3]
dataset.set_data(data=test_data2)
self.assertEqual(len(dataset), 3)
# ensure no duplicated cache content
self.assertEqual(len(dataset._cache), 2)
self.assertEqual(dataset.cache_num, 2)


if __name__ == "__main__":
unittest.main()