Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,10 @@ def __init__(
self.lmdb_kwargs = lmdb_kwargs or {}
if not self.lmdb_kwargs.get("map_size", 0):
self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size
self._read_env = None
self._env = None
# lmdb is single-writer multi-reader by default
# the cache is created without multi-threading
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
print(f"Accessing lmdb file: {self.db_file.absolute()}.")

def set_data(self, data: Sequence):
Expand All @@ -422,43 +425,56 @@ def set_data(self, data: Sequence):

"""
super().set_data(data=data)
self._read_env = None
self._env = None
self._read_env = self._fill_cache_start_reader(show_progress=self.progress)
Comment thread
wyli marked this conversation as resolved.

def _fill_cache_start_reader(self):
def _fill_cache_start_reader(self, show_progress=True):
"""
Check the LMDB cache and write the cache if needed. py-lmdb does not have good support for concurrent writes.
This method can be used with multiple processes, but doing so may have a negative impact on performance.

Args:
show_progress: whether to show the progress bar if possible.
"""
# create cache
self.lmdb_kwargs["readonly"] = False
env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
if self.progress and not has_tqdm:
if self._env is None:
self._env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs)
env = self._env
if show_progress and not has_tqdm:
warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.")
for item in tqdm(self.data) if has_tqdm and self.progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
try:
with env.begin(write=True) as txn:
with txn.cursor() as cursor:
with env.begin(write=False) as search_txn:
for item in tqdm(self.data) if has_tqdm and show_progress else self.data:
key = self.hash_func(item)
done, retry, val = False, 5, None
while not done and retry > 0:
try:
with search_txn.cursor() as cursor:
done = cursor.set_key(key)
if done:
continue
if done:
continue
if val is None:
val = self._pre_transform(deepcopy(item)) # keep the original hashed
val = pickle.dumps(val, protocol=self.pickle_protocol)
txn.put(key, val)
done = True
except lmdb.MapFullError:
done, retry = False, retry - 1
with env.begin(write=True) as txn:
txn.put(key, val)
done = True
except lmdb.MapFullError:
done, retry = False, retry - 1
size = env.info()["map_size"]
new_size = size * 2
warnings.warn(
f"Resizing the cache database from {int(size) >> 20}MB" f" to {int(new_size) >> 20}MB."
)
env.set_mapsize(new_size)
except lmdb.MapResizedError:
# the mapsize is increased by another process
# set_mapsize with a size of 0 to adopt the new size
env.set_mapsize(0)
if not done: # still has the map full error
size = env.info()["map_size"]
new_size = size * 2
warnings.warn(f"Resizing the cache database from {int(size) >> 20}MB to {int(new_size) >> 20}MB.")
env.set_mapsize(new_size)
except lmdb.MapResizedError:
# the mapsize is increased by another process
# set_mapsize with a size of 0 to adopt the new size
env.set_mapsize(0)
if not done: # still has the map full error
size = env.info()["map_size"]
env.close()
raise ValueError(f"LMDB map size reached, increase size above current size of {size}.")
env.close()
raise ValueError(f"LMDB map size reached, increase size above current size of {size}.")
size = env.info()["map_size"]
env.close()
# read-only database env
Expand All @@ -476,7 +492,7 @@ def _cachecheck(self, item_transformed):

"""
if self._read_env is None:
self._read_env = self._fill_cache_start_reader()
self._read_env = self._fill_cache_start_reader(show_progress=False)
with self._read_env.begin(write=False) as txn:
data = txn.get(self.hash_func(item_transformed))
if data is None:
Expand Down
5 changes: 4 additions & 1 deletion tests/test_lmdbdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,15 @@ def test_shape(self, transform, expected_shape, kwargs=None):
"extra": os.path.join(tempdir, "test_extra2_new.nii.gz"),
},
]
dataset_postcached.set_data(data=test_data_new)
# test new exchanged cache content
if transform is None:
dataset_postcached.set_data(data=test_data_new)
self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz"))
self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz"))
self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz"))
else:
with self.assertRaises(RuntimeError):
dataset_postcached.set_data(data=test_data_new) # filename list updated, files do not exist


@skip_if_windows
Expand Down