diff --git a/monai/data/samplers.py b/monai/data/samplers.py index 5fea6959de..8bba79c9b0 100644 --- a/monai/data/samplers.py +++ b/monai/data/samplers.py @@ -70,9 +70,6 @@ class DistributedWeightedRandomSampler(DistributedSampler): num_samples_per_rank: number of samples to draw for every rank, sample from the distributed subset of dataset. if None, default to the length of dataset split by DistributedSampler. - replacement: if ``True``, samples are drawn with replacement, otherwise, they are - drawn without replacement, which means that when a sample index is drawn for a row, - it cannot be drawn again for that row, default to True. generator: PyTorch Generator used in sampling. even_divisible: if False, different ranks can have different data length. for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].' @@ -90,7 +87,6 @@ def __init__( dataset: Dataset, weights: Sequence[float], num_samples_per_rank: Optional[int] = None, - replacement: bool = True, generator: Optional[torch.Generator] = None, even_divisible: bool = True, num_replicas: Optional[int] = None, @@ -107,16 +103,17 @@ def __init__( **kwargs, ) self.weights = weights - self.num_samples_per_rank = num_samples_per_rank - self.replacement = replacement + self.num_samples_per_rank = num_samples_per_rank if num_samples_per_rank is not None else self.num_samples self.generator = generator def __iter__(self): indices = list(super().__iter__()) - num_samples = self.num_samples_per_rank if self.num_samples_per_rank is not None else self.num_samples weights = torch.as_tensor([self.weights[i] for i in indices], dtype=torch.double) # sample based on the provided weights - rand_tensor = torch.multinomial(weights, num_samples, self.replacement, generator=self.generator) + rand_tensor = torch.multinomial(weights, self.num_samples_per_rank, True, generator=self.generator) for i in rand_tensor: yield indices[i] + + def __len__(self): + return self.num_samples_per_rank diff --git a/tests/test_distributed_weighted_random_sampler.py b/tests/test_distributed_weighted_random_sampler.py index 6e27e78d4c..b8e088fdcf 100644 --- a/tests/test_distributed_weighted_random_sampler.py +++ b/tests/test_distributed_weighted_random_sampler.py @@ -21,12 +21,11 @@ class DistributedWeightedRandomSamplerTest(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) - def test_replacement(self): + def test_sampling(self): data = [1, 2, 3, 4, 5] weights = [1, 2, 3, 4, 5] sampler = DistributedWeightedRandomSampler( weights=weights, - replacement=True, dataset=data, shuffle=False, generator=torch.Generator().manual_seed(0), @@ -46,7 +45,6 @@ def test_num_samples(self): sampler = DistributedWeightedRandomSampler( weights=weights, num_samples_per_rank=5, - replacement=True, dataset=data, shuffle=False, generator=torch.Generator().manual_seed(123),