
CacheDataset problem when using WeightedRandomSampler together with Randomizable InvertibleTransform #2282

Description

@dlotfi

Describe the bug
Using CacheDataset together with a random data sampler that draws with replacement and one or more transforms that are both Randomizable and InvertibleTransform causes an error when collating the 'x_transforms' keys:

RuntimeError: each element in list of batch should be of equal size
Collate error on the key 'img_transforms' of dictionary data.

To Reproduce
In order to reproduce this error:

  • The dataset must be a CacheDataset (probably also one of its variants, such as SmartCacheDataset)
  • The list of transforms must contain:
    • At least one transform that is an InvertibleTransform but not Randomizable (e.g. Rotate90d), placed before any Randomizable transform
    • At least one transform that is both an InvertibleTransform and Randomizable (e.g. RandFlipd)
  • The DataLoader must use a random sampler with replacement=True
import os
import tempfile
from PIL.Image import fromarray
from monai.data import create_test_image_2d, CacheDataset, DataLoader
from monai.transforms import Compose, LoadImaged, AddChanneld, Rotate90d, RandFlipd
from torch.utils.data import RandomSampler

with tempfile.TemporaryDirectory() as tempdir:
    data_files = []
    for i in range(40):
        im, seg = create_test_image_2d(128, 128, num_seg_classes=1)
        # join with a path separator so the PNGs are written inside tempdir
        file_path = os.path.join(tempdir, f"img{i:d}.png")
        fromarray(im.astype("uint8")).save(file_path)
        data_files.append({"img": file_path})

    transforms = Compose(
        [
            LoadImaged(keys=["img"]),
            AddChanneld(keys=["img"]),
            Rotate90d(keys=["img"]),  # InvertibleTransform
            RandFlipd(keys=["img"])   # Randomizable and InvertibleTransform
        ]
    )
    dataset = CacheDataset(data=data_files, transform=transforms, num_workers=0)
    data_loader = DataLoader(
        dataset,
        batch_size=4,
        sampler=RandomSampler(data_source=dataset, replacement=True),
        num_workers=0
    )

    # raises the RuntimeError shown above while collating a batch
    for batch_data in data_loader:
        img = batch_data["img"]

What causes this error?
CacheDataset caches the result of the non-random transforms and applies the Randomizable transforms to the cached item each time it is requested. If such a Randomizable transform is also an InvertibleTransform, it records itself by calling push_transform, which adds an entry under the 'x_transforms' key. When that key already exists (because a non-Randomizable InvertibleTransform was applied while the cache was being built), the new entry is appended to the existing list, which is the list held in the cache. With a sampler that draws with replacement, a given item may be requested any number of times, and each request appends one more entry. The 'x_transforms' lists of different items therefore end up with different lengths, and list_data_collate cannot collate them into a batch.
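The mechanism can be modelled without MONAI. The following is a simplified sketch under stated assumptions: `push_transform`, `cached_stage`, `per_request_stage`, `get_item`, `collate_lengths`, `load_batches` and the `deep_copy` flag are hypothetical names, not MONAI APIs, and each stage is assumed to take only a shallow copy of the dictionary, so the list stored in the cache is the one that gets appended to. Drawing item 0 again in a second batch gives it a longer list than item 4, and the equal-length check fails; deep-copying the cached item before the per-request stage keeps every list at the same length.

```python
import copy

# Hypothetical stand-ins that model the mechanism described above; not MONAI code.
TRACE_KEY = "img_transforms"  # plays the role of the '<key>_transforms' entry


def push_transform(data, name):
    # Append a record, creating the list only if the key is missing.
    data.setdefault(TRACE_KEY, []).append(name)


def cached_stage(item):
    # Applied once while the cache is built (Rotate90d in the report).
    d = dict(item)
    push_transform(d, "Rotate90d")
    return d


def per_request_stage(item):
    # Applied on every request (RandFlipd in the report).
    # dict() is a shallow copy, so the list is still the one held by the cache.
    d = dict(item)
    push_transform(d, "RandFlipd")
    return d


def get_item(cache, index, deep_copy=False):
    item = cache[index]
    if deep_copy:
        # Deep-copy so the per-request append cannot reach the cached list.
        item = copy.deepcopy(item)
    return per_request_stage(item)


def collate_lengths(batch):
    # Model of collation: the list entries must have equal length across the batch.
    lengths = {len(d[TRACE_KEY]) for d in batch}
    if len(lengths) != 1:
        raise RuntimeError("each element in list of batch should be of equal size")
    return lengths.pop()


def load_batches(batches, n_items=5, deep_copy=False):
    cache = [cached_stage({"img": i}) for i in range(n_items)]
    return [
        collate_lengths([get_item(cache, i, deep_copy) for i in indices])
        for indices in batches
    ]


if __name__ == "__main__":
    # Item 0 is drawn again in the second batch, as sampling with replacement allows.
    batches = [[0, 1, 2, 3], [0, 4]]
    try:
        load_batches(batches)
    except RuntimeError as err:
        print("shared cache:", err)
    print("copied cache:", load_batches(batches, deep_copy=True))  # [2, 2]
```

In the shared-cache run, the first batch collates (every list has two entries), but in the second batch item 0 has three entries while item 4 has two, which is the same kind of mismatch reported above.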

Labels: question (Further information is requested)
