`docs/pjrt.md` (3 additions, 7 deletions)
@@ -28,8 +28,6 @@ _New features in PyTorch/XLA r2.0_:

* New `xm.rendezvous` implementation that scales to thousands of TPU cores
* [experimental] `torch.distributed` support for TPU v2 and v3, including `pjrt://` `init_method`
-* [experimental] Single-host GPU support in PJRT. Multi-host support coming soon!

## TL;DR
@@ -192,8 +190,6 @@ for more information.

### GPU

-*Warning: GPU support is still highly experimental!*

### Single-node GPU training

To use GPUs with PJRT, simply set `PJRT_DEVICE=CUDA` and configure
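A minimal single-host invocation can be sketched as follows. This is an illustrative sketch, not the exact command from the guide: the `GPU_NUM_DEVICES` variable and the script path are assumptions, so check the current PyTorch/XLA documentation for the exact names.

```shell
# Sketch: single-node GPU training with PJRT. Assumes a CUDA build of
# PyTorch/XLA; GPU_NUM_DEVICES (local device count) and the script path
# are illustrative placeholders.
PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python3 train_script.py
```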
@@ -226,7 +222,7 @@ PJRT_DEVICE=CUDA torchrun \

- `--nnodes`: the number of GPU machines to be used.
- `--node_rank`: the index of the current GPU machine. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1.
- `--nproc_per_node`: the number of GPU devices to be used on the current machine.
- `--rdzv_endpoint`: the endpoint of the GPU machine with node_rank==0, in the form `host:port`. The `host` will be the internal IP address. The `port` can be any available port on the machine. For single-node training/inference, this parameter can be omitted.
For example, if you want to train on 2 GPU machines, machine_0 and machine_1, run the following on the first GPU machine, machine_0:
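Based on the flags described above, the pair of commands can be sketched as follows; the device count, port, and script name are illustrative placeholders, not values from the guide.

```shell
# On machine_0 (node_rank 0). MACHINE_0_IP is machine_0's internal IP;
# 12355, the device count, and the script name are illustrative.
PJRT_DEVICE=CUDA torchrun \
  --nnodes=2 \
  --node_rank=0 \
  --nproc_per_node=4 \
  --rdzv_endpoint="${MACHINE_0_IP}:12355" \
  train_script.py

# On machine_1, only --node_rank changes:
PJRT_DEVICE=CUDA torchrun \
  --nnodes=2 \
  --node_rank=1 \
  --nproc_per_node=4 \
  --rdzv_endpoint="${MACHINE_0_IP}:12355" \
  train_script.py
```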
@@ -235,7 +231,7 @@

The difference between the two commands above is `--node_rank` and, potentially, `--nproc_per_node` if you want to use a different number of GPU devices on each machine. All the rest is identical. For more information about `torchrun`, please refer to this [page](https://pytorch.org/docs/stable/elastic/run.html).
`docs/spmd.md` (42 additions, 0 deletions)
@@ -357,6 +357,48 @@ Unlike existing DDP and FSDP, under the SPMD mode, there is always a single process

There is no code change required to go from a single TPU host to a TPU Pod if you construct your mesh and partition spec based on the number of devices instead of a hardcoded constant. To run the PyTorch/XLA workload on a TPU Pod, please refer to the [Pods section](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#pods) of our PJRT guide.

### Running SPMD on GPU

PyTorch/XLA supports SPMD on NVIDIA GPUs (single-node or multi-node). The training/inference script remains the same as the one used for TPU, such as this [ResNet script](https://github.com/pytorch/xla/blob/1dc78948c0c9d018d8d0d2b4cce912552ab27083/test/spmd/test_train_spmd_imagenet.py). To execute the script using SPMD, we leverage `torchrun`:

```
PJRT_DEVICE=CUDA \
torchrun \
  --nnodes=${NUM_GPU_MACHINES} \
  --node_rank=${RANK_OF_CURRENT_MACHINE} \
  --nproc_per_node=1 \
  --rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:<PORT>" \
  training_or_inference_script_using_spmd.py
```

- `--nnodes`: the number of GPU machines to be used.
- `--node_rank`: the index of the current GPU machine. The value can be 0, 1, ..., ${NUMBER_GPU_VM}-1.
- `--nproc_per_node`: the value must be 1, due to the SPMD requirement.
- `--rdzv_endpoint`: the endpoint of the GPU machine with node_rank==0, in the form `host:port`. The `host` will be the internal IP address. The `port` can be any available port on the machine. For single-node training/inference, this parameter can be omitted.

For example, if you want to train a ResNet model on 2 GPU machines using SPMD, you can run the script below on the first machine:
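Filling in the template above for the first machine gives a sketch like the following; the port and the script arguments are illustrative assumptions, not values from the guide.

```shell
# On the first machine (node_rank == 0); --nproc_per_node must be 1 under SPMD.
# The port (12355) and the script arguments are illustrative.
PJRT_DEVICE=CUDA \
torchrun \
  --nnodes=2 \
  --node_rank=0 \
  --nproc_per_node=1 \
  --rdzv_endpoint="<MACHINE_0_IP_ADDRESS>:12355" \
  test/spmd/test_train_spmd_imagenet.py --fake_data
```

On the second machine, only `--node_rank` changes, to 1.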
0 commit comments