Commit f642baf

backport dataloader thread from #4693 (#4727)
1 parent 17f7da2 commit f642baf

File tree: 2 files changed (+81, -11 lines)

test/test_train_mp_imagenet.py
torch_xla/distributed/parallel_loader.py

test/test_train_mp_imagenet.py

Lines changed: 72 additions & 7 deletions
@@ -38,7 +38,25 @@
     },
     '--profile': {
         'action': 'store_true',
-    }
+    },
+    '--persistent_workers': {
+        'action': 'store_true',
+    },
+    '--prefetch_factor': {
+        'type': int,
+    },
+    '--loader_prefetch_size': {
+        'type': int,
+    },
+    '--device_prefetch_size': {
+        'type': int,
+    },
+    '--host_to_device_transfer_threads': {
+        'type': int,
+    },
+    '--use_optimized_kwargs': {
+        'type': str,
+    },
 }
 
 FLAGS = args_parse.parse_common_options(
@@ -81,13 +99,45 @@
     momentum=0.9,
     lr=0.1,
     target_accuracy=0.0,
+    persistent_workers=False,
+    prefetch_factor=16,
+    loader_prefetch_size=8,
+    device_prefetch_size=4,
+    num_workers=8,
+    host_to_device_transfer_threads=1,
 )
+
+# Best config to achieve peak performance based on TPU version
+# 1. It is recommended to use this config in conjunction with XLA_USE_BF16=1 Flag.
+# 2. Hyperparameters can be tuned to further improve the accuracy.
+# usage: python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 \
+#   --fake_data --num_epochs=10 --log_steps=300 \
+#   --profile --use_optimized_kwargs=tpuv4 --drop_last
+OPTIMIZED_KWARGS = {
+    'tpuv4':
+        dict(
+            batch_size=128,
+            test_set_batch_size=128,
+            num_epochs=18,
+            momentum=0.9,
+            lr=0.1,
+            target_accuracy=0.0,
+            persistent_workers=True,
+            prefetch_factor=32,
+            loader_prefetch_size=128,
+            device_prefetch_size=1,
+            num_workers=16,
+            host_to_device_transfer_threads=4,
+        )
+}
+
 MODEL_SPECIFIC_DEFAULTS = {
-    # Override some of the args in DEFAULT_KWARGS, or add them to the dict
+    # Override some of the args in DEFAULT_KWARGS/OPTIMIZED_KWARGS, or add them to the dict
     # if they don't exist.
     'resnet50':
         dict(
-            DEFAULT_KWARGS, **{
+            OPTIMIZED_KWARGS.get(FLAGS.use_optimized_kwargs, DEFAULT_KWARGS),
+            **{
                 'lr': 0.5,
                 'lr_scheduler_divide_every_n_epochs': 20,
                 'lr_scheduler_divisor': 5,
@@ -192,14 +242,18 @@ def train_imagenet():
       sampler=train_sampler,
       drop_last=FLAGS.drop_last,
       shuffle=False if train_sampler else True,
-      num_workers=FLAGS.num_workers)
+      num_workers=FLAGS.num_workers,
+      persistent_workers=FLAGS.persistent_workers,
+      prefetch_factor=FLAGS.prefetch_factor)
   test_loader = torch.utils.data.DataLoader(
       test_dataset,
       batch_size=FLAGS.test_set_batch_size,
       sampler=test_sampler,
       drop_last=FLAGS.drop_last,
       shuffle=False,
-      num_workers=FLAGS.num_workers)
+      num_workers=FLAGS.num_workers,
+      persistent_workers=FLAGS.persistent_workers,
+      prefetch_factor=FLAGS.prefetch_factor)
 
   torch.manual_seed(42)
 
@@ -273,8 +327,19 @@ def test_loop_fn(loader, epoch):
     accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean)
     return accuracy
 
-  train_device_loader = pl.MpDeviceLoader(train_loader, device)
-  test_device_loader = pl.MpDeviceLoader(test_loader, device)
+  train_device_loader = pl.MpDeviceLoader(
+      train_loader,
+      device,
+      loader_prefetch_size=FLAGS.loader_prefetch_size,
+      device_prefetch_size=FLAGS.device_prefetch_size,
+      host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads)
+  test_device_loader = pl.MpDeviceLoader(
+      test_loader,
+      device,
+      loader_prefetch_size=FLAGS.loader_prefetch_size,
+      device_prefetch_size=FLAGS.device_prefetch_size,
+      host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads)
+
   accuracy, max_accuracy = 0.0, 0.0
   for epoch in range(1, FLAGS.num_epochs + 1):
     xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now()))
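Beyond adding the new flags, the functional change in this file is how the resnet50 defaults are resolved: OPTIMIZED_KWARGS.get(FLAGS.use_optimized_kwargs, DEFAULT_KWARGS) falls back to the regular defaults whenever --use_optimized_kwargs is unset or names an unknown config, and the model-specific overrides are then applied on top. A minimal sketch of that lookup, with stand-in values rather than the real hyperparameters:

# Minimal sketch of the defaults fallback introduced by this commit.
# The dict values here are stand-ins, not the real training hyperparameters.
DEFAULT_KWARGS = dict(batch_size=128, prefetch_factor=16, num_workers=8)
OPTIMIZED_KWARGS = {
    'tpuv4': dict(batch_size=128, prefetch_factor=32, num_workers=16),
}

def resolve_resnet50_defaults(use_optimized_kwargs):
  # Unknown or empty --use_optimized_kwargs values fall back to DEFAULT_KWARGS;
  # model-specific overrides (e.g. lr) are applied on top of the chosen base.
  base = OPTIMIZED_KWARGS.get(use_optimized_kwargs, DEFAULT_KWARGS)
  return dict(base, **{'lr': 0.5})

print(resolve_resnet50_defaults('tpuv4'))  # optimized config, lr overridden
print(resolve_resnet50_defaults(''))       # falls back to DEFAULT_KWARGS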

torch_xla/distributed/parallel_loader.py

Lines changed: 9 additions & 4 deletions
@@ -69,6 +69,9 @@ class ParallelLoader(object):
       where the worker threads deposit tensors which have already been sent to
       devices.
       Default: 4
+    host_to_device_transfer_threads (int, optional): The number of threads that
+      work in parallel to transfer data from loader queue to device queue.
+      Default: 1
   """
 
   def __init__(self,
@@ -77,7 +80,8 @@ def __init__(self,
                batchdim=0,
                batches_per_execution=1,
                loader_prefetch_size=8,
-               device_prefetch_size=4):
+               device_prefetch_size=4,
+               host_to_device_transfer_threads=1):
     self._loader = loader
     self._devices = [torch.device(x) for x in devices]
     self._batchdim = batchdim
@@ -91,9 +95,10 @@ def __init__(self,
       thread.daemon = True
       thread.start()
     for dqueue in self._queues.values():
-      thread = threading.Thread(target=self._worker, args=(dqueue,))
-      thread.daemon = True
-      thread.start()
+      for i in range(host_to_device_transfer_threads):
+        thread = threading.Thread(target=self._worker, args=(dqueue,))
+        thread.daemon = True
+        thread.start()
 
   def per_device_loader(self, device):
     """Retrieves the loader iterator object for the given device.
