Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions monai/data/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ class DistributedWeightedRandomSampler(DistributedSampler):
num_samples_per_rank: number of samples to draw on every rank, sampled from
the rank's distributed subset of the dataset.
if None, defaults to the length of the dataset split by DistributedSampler.
replacement: if ``True``, samples are drawn with replacement, otherwise, they are
drawn without replacement, which means that when a sample index is drawn for a row,
it cannot be drawn again for that row, default to True.
generator: PyTorch Generator used in sampling.
even_divisible: if False, different ranks can have different data length.
for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
Expand All @@ -90,7 +87,6 @@ def __init__(
dataset: Dataset,
weights: Sequence[float],
num_samples_per_rank: Optional[int] = None,
replacement: bool = True,
generator: Optional[torch.Generator] = None,
even_divisible: bool = True,
num_replicas: Optional[int] = None,
Expand All @@ -107,16 +103,17 @@ def __init__(
**kwargs,
)
self.weights = weights
self.num_samples_per_rank = num_samples_per_rank
self.replacement = replacement
self.num_samples_per_rank = num_samples_per_rank if num_samples_per_rank is not None else self.num_samples
self.generator = generator

def __iter__(self):
indices = list(super().__iter__())
num_samples = self.num_samples_per_rank if self.num_samples_per_rank is not None else self.num_samples
weights = torch.as_tensor([self.weights[i] for i in indices], dtype=torch.double)
# sample based on the provided weights
rand_tensor = torch.multinomial(weights, num_samples, self.replacement, generator=self.generator)
rand_tensor = torch.multinomial(weights, self.num_samples_per_rank, True, generator=self.generator)

for i in rand_tensor:
yield indices[i]

def __len__(self):
return self.num_samples_per_rank
Comment thread
Nic-Ma marked this conversation as resolved.
4 changes: 1 addition & 3 deletions tests/test_distributed_weighted_random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@

class DistributedWeightedRandomSamplerTest(DistTestCase):
@DistCall(nnodes=1, nproc_per_node=2)
def test_replacement(self):
def test_sampling(self):
data = [1, 2, 3, 4, 5]
weights = [1, 2, 3, 4, 5]
sampler = DistributedWeightedRandomSampler(
weights=weights,
replacement=True,
dataset=data,
shuffle=False,
generator=torch.Generator().manual_seed(0),
Expand All @@ -46,7 +45,6 @@ def test_num_samples(self):
sampler = DistributedWeightedRandomSampler(
weights=weights,
num_samples_per_rank=5,
replacement=True,
dataset=data,
shuffle=False,
generator=torch.Generator().manual_seed(123),
Expand Down