## Getting Started

The guides below cover two modes:
- Single process: one Python interpreter controlling a single GPU/TPU at a time
- Multi process: N Python interpreters are launched, corresponding to N GPU/TPUs
  found on the system

Another mode is SPMD, where one Python interpreter controls all N GPU/TPUs found on
the system. Multi processing is more complex and is not compatible with SPMD. This
tutorial does not dive into SPMD; for more on that, check our
[SPMD guide](https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_basic.md).

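A quick way to confirm that the single-process setup works is to create a tensor on the XLA device and check where it lives. This is a minimal sketch, assuming `torch_xla` is installed and an accelerator is attached; the exact device string printed (e.g. `xla:0`) can vary with your setup.

```python
import torch
import torch_xla  # registers the 'xla' device type with PyTorch

# Single process: this one interpreter drives a single XLA device.
t = torch.randn(2, 2).to('xla')
print(t.device)            # e.g. xla:0
print(torch_xla.device())  # the default XLA device for this process
```
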
### Simple single process

To update your existing training loop, make the following changes:

```diff
+import torch_xla

 def train(model, training_data, ...):
   ...
   for inputs, labels in train_loader:
+    with torch_xla.step():
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

+  torch_xla.sync()
   ...

 if __name__ == '__main__':
   ...
+  # Move the model parameters to your XLA device
+  model.to('xla')
   train(model, training_data, ...)
   ...
```

The changes above should get your model to train on the TPU.

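For reference, here is how those pieces fit together in a complete file. This is a minimal sketch, not the canonical example: the toy model, random data, and hyperparameters below are invented for illustration, and only the `torch_xla` pieces (`.to('xla')`, `torch_xla.step()`, `torch_xla.sync()`) mirror the diff above.

```python
import torch
import torch.nn as nn
import torch_xla
from torch.utils.data import DataLoader, TensorDataset


def train(model, train_loader, loss_fn, optimizer, epochs=1):
    for _ in range(epochs):
        for inputs, labels in train_loader:
            # Each iteration is traced and executed as one XLA graph.
            with torch_xla.step():
                inputs, labels = inputs.to('xla'), labels.to('xla')
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                loss.backward()
                optimizer.step()
        # Flush any work still pending on the device.
        torch_xla.sync()


if __name__ == '__main__':
    # Toy model and random data, only to keep the sketch self-contained.
    model = nn.Linear(16, 4).to('xla')
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 4, (128,)))
    train(model, DataLoader(dataset, batch_size=32), nn.CrossEntropyLoss(), optimizer)
```

`torch_xla.step()` marks the boundaries of each training step so XLA can compile and run it as a single graph, while `torch_xla.sync()` makes sure any remaining pending work is executed.
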
### Multi processing

To update your existing training loop, make the following changes:

```diff
-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(torch_xla.device())

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
+      # `xm.optimizer_step` all-reduces gradients across replicas, then runs the optimizer.
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # torch_xla.launch automatically selects the correct world size
+  torch_xla.launch(_mp_fn, args=())
```
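
And here is the same thing as a complete file to start from. This is a minimal sketch under the same caveats as the single-process example: the toy model and data are invented, each process builds its own random dataset, and `xm.optimizer_step(optimizer)` all-reduces gradients across replicas before running the optimizer.

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader, TensorDataset


def _mp_fn(index):
    # Each spawned process drives its own XLA device.
    device = torch_xla.device()
    model = nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Toy per-process data, only to keep the sketch self-contained.
    dataset = TensorDataset(torch.randn(128, 16), torch.randint(0, 4, (128,)))
    train_loader = DataLoader(dataset, batch_size=32)

    for inputs, labels in train_loader:
        with torch_xla.step():
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            # All-reduce gradients across replicas, then run the optimizer.
            xm.optimizer_step(optimizer)


if __name__ == '__main__':
    # Starts one copy of _mp_fn per device found on the machine.
    torch_xla.launch(_mp_fn, args=())
```

Run it like any Python script; `torch_xla.launch` starts one `_mp_fn` process per available device.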

If you're using `DistributedDataParallel`, make the following changes: