diff --git a/monai/data/dataset.py b/monai/data/dataset.py index bfb3d8b86d..b8ecd2b35a 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -614,8 +614,10 @@ class SmartCacheDataset(Randomizable, CacheDataset): 4. Call `shutdown()` when training ends. Note: - This replacement will not work if setting the `multiprocessing_context` of DataLoader to `spawn` - or on windows(the default multiprocessing method is `spawn`) and setting `num_workers` greater than 0. + This replacement will not work for below cases: + 1. Set the `multiprocessing_context` of DataLoader to `spawn`. + 2. Run on windows(the default multiprocessing method is `spawn`) with `num_workers` greater than 0. + 3. Set the `persistent_workers` of DataLoader to `True` with `num_workers` greater than 0. If using MONAI workflows, please add `SmartCacheHandler` to the handler list of trainer, otherwise, please make sure to call `start()`, `update_cache()`, `shutdown()` during training. diff --git a/tests/test_handler_smartcache.py b/tests/test_handler_smartcache.py index cfe68e98e2..b67f1226cd 100644 --- a/tests/test_handler_smartcache.py +++ b/tests/test_handler_smartcache.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys import unittest import torch @@ -16,8 +17,10 @@ from monai.data import SmartCacheDataset from monai.handlers import SmartCacheHandler +from tests.utils import SkipIfBeforePyTorchVersion +@SkipIfBeforePyTorchVersion((1, 7)) class TestHandlerSmartCache(unittest.TestCase): def test_content(self): data = [0, 1, 2, 3, 4, 5, 6, 7, 8] @@ -37,7 +40,8 @@ def _train_func(engine, batch): # set up testing handler dataset = SmartCacheDataset(data, transform=None, replace_rate=0.2, cache_num=5, shuffle=False) - data_loader = torch.utils.data.DataLoader(dataset, batch_size=5) + workers = 2 if sys.platform == "linux" else 0 + data_loader = torch.utils.data.DataLoader(dataset, batch_size=5, num_workers=workers, persistent_workers=False) SmartCacheHandler(dataset).attach(engine) engine.run(data_loader, max_epochs=5)