diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 4d18bd4e0d..b863bb58fe 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -413,7 +413,10 @@ def __init__( self.lmdb_kwargs = lmdb_kwargs or {} if not self.lmdb_kwargs.get("map_size", 0): self.lmdb_kwargs["map_size"] = 1024 ** 4 # default map_size - self._read_env = None + self._env = None + # lmdb is single-writer multi-reader by default + # the cache is created without multi-threading + self._read_env = self._fill_cache_start_reader(show_progress=self.progress) print(f"Accessing lmdb file: {self.db_file.absolute()}.") def set_data(self, data: Sequence): @@ -422,43 +425,56 @@ def set_data(self, data: Sequence): """ super().set_data(data=data) - self._read_env = None + self._env = None + self._read_env = self._fill_cache_start_reader(show_progress=self.progress) - def _fill_cache_start_reader(self): + def _fill_cache_start_reader(self, show_progress=True): + """ + Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write. + This method can be used with multiple processes, but it may have a negative impact on the performance. + + Args: + show_progress: whether to show the progress bar if possible. + """ # create cache self.lmdb_kwargs["readonly"] = False - env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) - if self.progress and not has_tqdm: + if self._env is None: + self._env = lmdb.open(path=f"{self.db_file}", subdir=False, **self.lmdb_kwargs) + env = self._env + if show_progress and not has_tqdm: warnings.warn("LMDBDataset: tqdm is not installed. not displaying the caching progress.") - for item in tqdm(self.data) if has_tqdm and self.progress else self.data: - key = self.hash_func(item) - done, retry, val = False, 5, None - while not done and retry > 0: - try: - with env.begin(write=True) as txn: - with txn.cursor() as cursor: + with env.begin(write=False) as search_txn: + for item in tqdm(self.data) if has_tqdm and show_progress else self.data: + key = self.hash_func(item) + done, retry, val = False, 5, None + while not done and retry > 0: + try: + with search_txn.cursor() as cursor: done = cursor.set_key(key) - if done: - continue + if done: + continue if val is None: val = self._pre_transform(deepcopy(item)) # keep the original hashed val = pickle.dumps(val, protocol=self.pickle_protocol) - txn.put(key, val) - done = True - except lmdb.MapFullError: - done, retry = False, retry - 1 + with env.begin(write=True) as txn: + txn.put(key, val) + done = True + except lmdb.MapFullError: + done, retry = False, retry - 1 + size = env.info()["map_size"] + new_size = size * 2 + warnings.warn( + f"Resizing the cache database from {int(size) >> 20}MB" f" to {int(new_size) >> 20}MB." + ) + env.set_mapsize(new_size) + except lmdb.MapResizedError: + # the mapsize is increased by another process + # set_mapsize with a size of 0 to adopt the new size + env.set_mapsize(0) + if not done: # still has the map full error size = env.info()["map_size"] - new_size = size * 2 - warnings.warn(f"Resizing the cache database from {int(size) >> 20}MB to {int(new_size) >> 20}MB.") - env.set_mapsize(new_size) - except lmdb.MapResizedError: - # the mapsize is increased by another process - # set_mapsize with a size of 0 to adopt the new size, - env.set_mapsize(0) - if not done: # still has the map full error - size = env.info()["map_size"] - env.close() - raise ValueError(f"LMDB map size reached, increase size above current size of {size}.") + env.close() + raise ValueError(f"LMDB map size reached, increase size above current size of {size}.") size = env.info()["map_size"] env.close() # read-only database env @@ -476,7 +492,7 @@ def _cachecheck(self, item_transformed): """ if self._read_env is None: - self._read_env = self._fill_cache_start_reader() + self._read_env = self._fill_cache_start_reader(show_progress=False) with self._read_env.begin(write=False) as txn: data = txn.get(self.hash_func(item_transformed)) if data is None: diff --git a/tests/test_lmdbdataset.py b/tests/test_lmdbdataset.py index 3e3aed709f..fbdb651297 100644 --- a/tests/test_lmdbdataset.py +++ b/tests/test_lmdbdataset.py @@ -191,12 +191,15 @@ def test_shape(self, transform, expected_shape, kwargs=None): "extra": os.path.join(tempdir, "test_extra2_new.nii.gz"), }, ] - dataset_postcached.set_data(data=test_data_new) # test new exchanged cache content if transform is None: + dataset_postcached.set_data(data=test_data_new) self.assertEqual(dataset_postcached[0]["image"], os.path.join(tempdir, "test_image1_new.nii.gz")) self.assertEqual(dataset_postcached[0]["label"], os.path.join(tempdir, "test_label1_new.nii.gz")) self.assertEqual(dataset_postcached[1]["extra"], os.path.join(tempdir, "test_extra2_new.nii.gz")) + else: + with self.assertRaises(RuntimeError): + dataset_postcached.set_data(data=test_data_new) # filename list updated, files do not exist @skip_if_windows