# `While_loop` optimizes memory utilization and compilation

<br>

### `while_loop`
`while_loop` replaces a pure Python `while` loop. PyTorch supports `while_loop` through
[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66),
and PyTorch/XLA provides experimental XLA backend support for it via `XLA::While`.

#### Usage:
```python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
```
- `cond_fn`: User-defined condition function.
- `body_fn`: User-defined loop body function.
- `init`: Initial values (tuple or list).

#### Simple example with `while_loop`:
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))
```
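
Unlike a Python loop, which PyTorch/XLA traces iteration by iteration, `while_loop` keeps the loop rolled as a single `XLA::While` operation in the compiled graph. As a rough sanity check, you can dump the HLO behind the result tensor; this sketch assumes the internal debugging helper `torch_xla._XLAC._get_xla_tensors_hlo` (used in PyTorch/XLA's own tests), whose exact output format may change between releases:
```bash
>>> # Continuing the session above. Expect a single `while` op in the
>>> # dumped HLO rather than ten unrolled copies of the loop body.
>>> # (_get_xla_tensors_hlo is internal; treat its output as unstable.)
>>> print(torch_xla._XLAC._get_xla_tensors_hlo([res]))
```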

<br>

## Control group test case
To better compare a pure Python `while` loop against `while_loop`, here is a control-group test case with the same logic: cumulatively adding 1 for fifty iterations:

### Control group example with a pure Python `while` loop
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
```
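
Because XLA tensors are lazy, every `iteri > 0` check in the Python loop forces a graph execution, and the traced work grows with the trip count, whereas `while_loop` compiles one fixed graph. One way to observe the difference is the metrics report from `torch_xla.debug.metrics`; this is only a sketch, and counter names can vary between releases:
```bash
>>> # Continuing either session above: print compile/execute counters and
>>> # compare the reports for the two loop styles. Exact counter names may
>>> # differ between PyTorch/XLA releases.
>>> import torch_xla.debug.metrics as met
>>> print(met.metrics_report())
```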

PyTorch/XLA includes `while_loop` support in release 2.4 with the simple test case above; support for `fori_loop` will be added after 2.4. Note that `while_loop` currently requires `body_fn` to be defined so that its outputs (return args) have the same shapes as its inputs.
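
As a minimal sketch of that constraint (the function names here are made up for illustration), a `body_fn` that keeps every carried value at a fixed shape works, while one that grows a tensor each iteration does not:
```python
import torch
import torch_xla.experimental.fori_loop  # registers the XLA backend for while_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# OK: both carried values keep the same shape across iterations.
def body_fn(iteri, x):
  return iteri - 1, x + 1

# Not supported: `x` doubles in size every iteration, so the output
# shape no longer matches the input shape.
def bad_body_fn(iteri, x):
  return iteri - 1, torch.cat([x, x])

iteri = torch.tensor(5, device=device)
x = torch.zeros(4, device=device)
_, res = while_loop(lambda i, v: i > 0, body_fn, (iteri, x))
```
If you need to accumulate results across iterations, a common workaround is to carry a fixed-shape accumulator (like `x` above) through the loop rather than a growing tensor.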