|
38 | 38 | }, |
39 | 39 | '--profile': { |
40 | 40 | 'action': 'store_true', |
41 | | - } |
| 41 | + }, |
| 42 | + '--persistent_workers': { |
| 43 | + 'action': 'store_true', |
| 44 | + }, |
| 45 | + '--prefetch_factor': { |
| 46 | + 'type': int, |
| 47 | + }, |
| 48 | + '--loader_prefetch_size': { |
| 49 | + 'type': int, |
| 50 | + }, |
| 51 | + '--device_prefetch_size': { |
| 52 | + 'type': int, |
| 53 | + }, |
| 54 | + '--host_to_device_transfer_threads': { |
| 55 | + 'type': int, |
| 56 | + }, |
| 57 | + '--use_optimized_kwargs': { |
| 58 | + 'type': str, |
| 59 | + }, |
42 | 60 | } |
43 | 61 |
|
44 | 62 | FLAGS = args_parse.parse_common_options( |
|
81 | 99 | momentum=0.9, |
82 | 100 | lr=0.1, |
83 | 101 | target_accuracy=0.0, |
| 102 | + persistent_workers=False, |
| 103 | + prefetch_factor=16, |
| 104 | + loader_prefetch_size=8, |
| 105 | + device_prefetch_size=4, |
| 106 | + num_workers=8, |
| 107 | + host_to_device_transfer_threads=1, |
84 | 108 | ) |
| 109 | + |
| 110 | +# Best config to achieve peak performance based on TPU version |
| 111 | +# 1. It is recommended to use this config in conjunction with the XLA_USE_BF16=1 flag. |
| 112 | +# 2. Hyperparameters can be tuned to further improve the accuracy. |
| 113 | +# usage: python3 /usr/share/pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 \ |
| 114 | +# --fake_data --num_epochs=10 --log_steps=300 \ |
| 115 | +# --profile --use_optimized_kwargs=tpuv4 --drop_last |
| 116 | +OPTIMIZED_KWARGS = { |
| 117 | + 'tpuv4': |
| 118 | + dict( |
| 119 | + batch_size=128, |
| 120 | + test_set_batch_size=128, |
| 121 | + num_epochs=18, |
| 122 | + momentum=0.9, |
| 123 | + lr=0.1, |
| 124 | + target_accuracy=0.0, |
| 125 | + persistent_workers=True, |
| 126 | + prefetch_factor=32, |
| 127 | + loader_prefetch_size=128, |
| 128 | + device_prefetch_size=1, |
| 129 | + num_workers=16, |
| 130 | + host_to_device_transfer_threads=4, |
| 131 | + ) |
| 132 | +} |
| 133 | + |
85 | 134 | MODEL_SPECIFIC_DEFAULTS = { |
86 | | - # Override some of the args in DEFAULT_KWARGS, or add them to the dict |
| 135 | + # Override some of the args in DEFAULT_KWARGS/OPTIMIZED_KWARGS, or add them to the dict |
87 | 136 | # if they don't exist. |
88 | 137 | 'resnet50': |
89 | 138 | dict( |
90 | | - DEFAULT_KWARGS, **{ |
| 139 | + OPTIMIZED_KWARGS.get(FLAGS.use_optimized_kwargs, DEFAULT_KWARGS), |
| 140 | + **{ |
91 | 141 | 'lr': 0.5, |
92 | 142 | 'lr_scheduler_divide_every_n_epochs': 20, |
93 | 143 | 'lr_scheduler_divisor': 5, |
@@ -192,14 +242,18 @@ def train_imagenet(): |
192 | 242 | sampler=train_sampler, |
193 | 243 | drop_last=FLAGS.drop_last, |
194 | 244 | shuffle=False if train_sampler else True, |
195 | | - num_workers=FLAGS.num_workers) |
| 245 | + num_workers=FLAGS.num_workers, |
| 246 | + persistent_workers=FLAGS.persistent_workers, |
| 247 | + prefetch_factor=FLAGS.prefetch_factor) |
196 | 248 | test_loader = torch.utils.data.DataLoader( |
197 | 249 | test_dataset, |
198 | 250 | batch_size=FLAGS.test_set_batch_size, |
199 | 251 | sampler=test_sampler, |
200 | 252 | drop_last=FLAGS.drop_last, |
201 | 253 | shuffle=False, |
202 | | - num_workers=FLAGS.num_workers) |
| 254 | + num_workers=FLAGS.num_workers, |
| 255 | + persistent_workers=FLAGS.persistent_workers, |
| 256 | + prefetch_factor=FLAGS.prefetch_factor) |
203 | 257 |
|
204 | 258 | torch.manual_seed(42) |
205 | 259 |
|
@@ -273,8 +327,19 @@ def test_loop_fn(loader, epoch): |
273 | 327 | accuracy = xm.mesh_reduce('test_accuracy', accuracy, np.mean) |
274 | 328 | return accuracy |
275 | 329 |
|
276 | | - train_device_loader = pl.MpDeviceLoader(train_loader, device) |
277 | | - test_device_loader = pl.MpDeviceLoader(test_loader, device) |
| 330 | + train_device_loader = pl.MpDeviceLoader( |
| 331 | + train_loader, |
| 332 | + device, |
| 333 | + loader_prefetch_size=FLAGS.loader_prefetch_size, |
| 334 | + device_prefetch_size=FLAGS.device_prefetch_size, |
| 335 | + host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads) |
| 336 | + test_device_loader = pl.MpDeviceLoader( |
| 337 | + test_loader, |
| 338 | + device, |
| 339 | + loader_prefetch_size=FLAGS.loader_prefetch_size, |
| 340 | + device_prefetch_size=FLAGS.device_prefetch_size, |
| 341 | + host_to_device_transfer_threads=FLAGS.host_to_device_transfer_threads) |
| 342 | + |
278 | 343 | accuracy, max_accuracy = 0.0, 0.0 |
279 | 344 | for epoch in range(1, FLAGS.num_epochs + 1): |
280 | 345 | xm.master_print('Epoch {} train begin {}'.format(epoch, test_utils.now())) |
|
0 commit comments