Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,10 @@ class SmartCacheDataset(Randomizable, CacheDataset):
4. Call `shutdown()` when training ends.

Note:
This replacement will not work if setting the `multiprocessing_context` of DataLoader to `spawn`
or on windows(the default multiprocessing method is `spawn`) and setting `num_workers` greater than 0.
This replacement will not work for below cases:
1. Set the `multiprocessing_context` of DataLoader to `spawn`.
2. Run on windows(the default multiprocessing method is `spawn`) with `num_workers` greater than 0.
3. Set the `persistent_workers` of DataLoader to `True` with `num_workers` greater than 0.

If using MONAI workflows, please add `SmartCacheHandler` to the handler list of trainer,
otherwise, please make sure to call `start()`, `update_cache()`, `shutdown()` during training.
Expand Down
6 changes: 5 additions & 1 deletion tests/test_handler_smartcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

import torch
from ignite.engine import Engine

from monai.data import SmartCacheDataset
from monai.handlers import SmartCacheHandler
from tests.utils import SkipIfBeforePyTorchVersion


@SkipIfBeforePyTorchVersion((1, 7))
class TestHandlerSmartCache(unittest.TestCase):
def test_content(self):
data = [0, 1, 2, 3, 4, 5, 6, 7, 8]
Expand All @@ -37,7 +40,8 @@ def _train_func(engine, batch):

# set up testing handler
dataset = SmartCacheDataset(data, transform=None, replace_rate=0.2, cache_num=5, shuffle=False)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)
workers = 2 if sys.platform == "linux" else 0
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5, num_workers=workers, persistent_workers=False)
SmartCacheHandler(dataset).attach(engine)

engine.run(data_loader, max_epochs=5)
Expand Down