Commit 2793462

ManfeiBai and JackCaoG authored
[backport][Fori_loop|While_loop] Enable while_loop/fori_loop, add test case (#7157) (#7306)
Co-authored-by: JackCaoG <[email protected]>
1 parent a901eb8 commit 2793462

File tree

7 files changed: +289 additions, -247 deletions


docs/fori_loop.md

Lines changed: 37 additions & 79 deletions
@@ -1,114 +1,72 @@
-# Fori_loop
-`fori_loop` is a replacement of pure python for loop, PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep loop computation graph as rolled during compilation
-like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), not like currently repeat computations by enumerating all execution steps
-of each iteration. `fori_loop` might help memory utilization and might help faster compilation.
+# `While_loop` optimizes memory utilization and compilation
 
-User could use `fori_loop` like this:
-```python
-from torch_xla.experimental.fori_loop import fori_loop
-res = fori_loop(upper, lower, /*user defined*/body_fun, init)
-```
-
-current fori_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU too.
+<br>
 
-For detailed implementation:
-- for situation that loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
-like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html), PyTorch/XLA would support `while_loop` with the
-native PyTorch and the XLA backend: XLA::While. Due to `while_loop` didn't support autograd, so it would be used for inference only.
+### `while_loop`
+`while_loop` replaces the pure Python `while` loop; PyTorch supports it via
+[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66).
+PyTorch/XLA provides experimental backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.
 
-- for situation that loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
-like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` using XLA::While operator.
-This implementation would be very similar like `while_loop`. `scan` support autograd, and it could be used in both training and inference.
-
-# while_loop
-`while_loop` is a replacement of pure python while loop, PyTorch has supported `while_loop` in
-[code](https://github.com/pytorch/pytorch/blob/ca6a0e1348ba7dcade1833d983b1b4ca12a5c1e1/torch/_higher_order_ops/while_loop.py#L69).
-PyTorch/XLA want to support `while_loop` with the native PyTorch and the XLA backend: XLA::While.
-
-User could use `while_loop` like this:
+#### Usage:
 ```python
 import torch_xla.experimental.fori_loop
 from torch._higher_order_ops.while_loop import while_loop
-res = while_loop(/*user-defined*/cond_fn, /*user-defined*/body_fn, /*tuple or list*/init)
+result = while_loop(cond_fn, body_fn, init)
 ```
-current while_loop only support simple test like [link](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and user could try [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU too.
-
+- `cond_fn`: User-defined condition function.
+- `body_fn`: User-defined loop body function.
+- `init`: Initial values (tuple or list).
 
-# [WIP]scan
-like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for training and inference since it support autograd.
-`scan` is WIP.
-
-
-# Simple user guide
-User could try these three simple test case to better compare difference between `pure python for loop` and `fori_loop` and `while_loop`, these three test case have similar logic: cumulative plus 1 for ten times:
-
-### simple example with pure python for loop
-```bash
-# python
->>> import torch
->>> init = torch.tensor([0], dtype=torch.int32)
->>> one_value = torch.ones(1, dtype=torch.int32)
->>>
->>> for i in range(10):
-...   init = init + one_value
-...
->>> init
-tensor([10], dtype=torch.int32)
-```
-
-### simple example with `while_loop`:
+#### Simple example with `while_loop`:
 ```bash
 # PJRT_DEVICE=TPU python
 >>> import torch
 >>> import torch_xla
 >>> import torch_xla.experimental.fori_loop
->>> from torch_xla.experimental.fori_loop import fori_loop
 >>> from torch._higher_order_ops.while_loop import while_loop
 >>> import torch_xla.core.xla_model as xm
->>> import torch_xla.core.xla_builder as xb
 >>>
 >>> device = xm.xla_device()
 >>>
->>> def cond_fn(init, limit_value):
-...   return limit_value[0] >= init[0]
+>>> def cond_fn(iteri, x):
+...   return iteri > 0
 ...
->>> def body_fn(init, limit_value):
-...   one_value = torch.ones(1, dtype=torch.int32, device=device)
-...   return (torch.add(init, one_value), limit_value.clone())
+>>> def body_fn(iteri, x):
+...   return iteri - 1, torch.add(x, 1)
 ...
->>> init = torch.tensor([0], dtype=torch.int32, device=device)
->>> limit_value = torch.tensor([10], dtype=torch.int32, device=device)
->>> res_, limit_value_ = while_loop(cond_fn, body_fn, (init, limit_value))
->>> res_
+>>> init_val = torch.tensor(3, device=device)
+>>> iteri = torch.tensor(10, device=device)
+>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
+>>> res
 FunctionalTensor(lvl=0, value=\
-tensor([11], device='xla:0', dtype=torch.int32))
+tensor(13, device='xla:0'))
 ```
 
-### simple example with `fori_loop`:
+<br>
+
+## Control group test case
+To better compare a pure Python `while` loop with `while_loop`, there is a control test case with the same logic: cumulatively add 1 on each iteration:
+
+### Control group example with pure Python `while` loop
 ```bash
 # PJRT_DEVICE=TPU python
 >>> import torch
 >>> import torch_xla
->>> import torch_xla.experimental.fori_loop
->>> from torch_xla.experimental.fori_loop import fori_loop
->>> from torch._higher_order_ops.while_loop import while_loop
 >>> import torch_xla.core.xla_model as xm
->>> import torch_xla.core.xla_builder as xb
 >>>
 >>> device = xm.xla_device()
 >>>
->>> lower = torch.tensor([2], dtype=torch.int32, device=device)
->>> upper = torch.tensor([52], dtype=torch.int32, device=device)
->>> plus_value = torch.tensor([1], dtype=torch.int32, device=device)
->>> init_val = torch.tensor([1], dtype=torch.int32, device=device)
+>>> init_val = torch.tensor(1, device=device)
+>>> iteri = torch.tensor(50, device=device)
 >>>
->>> def body_fun(*argus):
-...   plus_value, init_val = argus
-...   return plus_value, torch.add(plus_value, init_val)
+>>> while iteri > 0:
+...   init_val = init_val + 1
+...   iteri -= 1
 ...
->>> _, _, _, res_ = fori_loop(upper, lower, body_fun, plus_value, init_val)
->>> res_
-tensor([51], device='xla:0', dtype=torch.int32)
+>>> init_val
+tensor(51, device='xla:0')
 ```
 
-For more example and detailed user guide, please read [this test file](https://github.com/pytorch/xla/blob/master/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py). PyTorch/XLA would include `while_loop` support in 2.3 for simple test case, complex test case and support for `fori_loop` and `scan` would be added after 2.3
+
+
+PyTorch/XLA will include `while_loop` support with test cases in release 2.4; `fori_loop` support will be added after 2.4. For `while_loop`, `body_fn` must currently be defined with the same `input` and `output (return args)` shapes.
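The shape constraint in the last line above deserves an illustration. Here is a minimal sketch, not part of this commit, showing a conforming `body_fn` next to a hypothetical non-conforming one: every tensor returned by `body_fn` must keep the shape (and dtype) of the corresponding input, since `XLA::While` carries a fixed-signature state across iterations.

```python
import torch
import torch_xla.experimental.fori_loop  # registers the XLA while_loop dispatch
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

device = xm.xla_device()

def cond_fn(iteri, x):
  return iteri > 0

# Conforming: (iteri, x) in -> (iteri, x) out, shapes and dtypes unchanged.
def body_fn(iteri, x):
  return iteri - 1, x * 2

# Hypothetical counter-example: torch.cat grows x on every iteration,
# so the returned shape no longer matches the input shape.
def bad_body_fn(iteri, x):
  return iteri - 1, torch.cat([x, x])

iteri = torch.tensor(3, device=device)
x = torch.ones(2, 2, device=device)
_, res = while_loop(cond_fn, body_fn, (iteri, x))  # fine: x stays (2, 2)
# while_loop(cond_fn, bad_body_fn, (iteri, x))     # violates the constraint
```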

test/run_tests.sh

Lines changed: 1 addition & 1 deletion
@@ -203,7 +203,7 @@ function run_xla_op_tests1 {
 function run_xla_op_tests2 {
   run_downcast_bf16 "$CDIR/test_data_type.py"
   run_test "$CDIR/pjrt/test_dtypes.py"
-  run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
+  run_test "$CDIR/test_while_loop.py"
   run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
 }

test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py

Lines changed: 0 additions & 106 deletions
This file was deleted.

test/test_while_loop.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+import os
+import sys
+import unittest
+from typing import Callable, Dict, List
+
+import torch
+import torch_xla
+# We need to import the underlying implementation function to register it with the dispatcher.
+import torch_xla.experimental.fori_loop
+from torch_xla.experimental.fori_loop import fori_loop
+from torch._higher_order_ops.while_loop import while_loop
+import torch_xla.core.xla_model as xm
+import torch_xla.core.xla_builder as xb
+import torch_xla.utils.utils as xu
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+
+
+def _fake_while_loop(cond_fn, body_fn, operands):
+  # `operands` must be a tuple or list with more than one element.
+  while cond_fn(*operands):
+    operands = body_fn(*operands)
+  return operands
+
+
+class WhileLoopTest(unittest.TestCase):
+
+  def test_while_loop_addition(self):
+    device = xm.xla_device()
+
+    def cond_fn(iteri, x):
+      return iteri > 0
+
+    def body_fn(iteri, x):
+      return iteri - 1, torch.add(x, 1)
+
+    init_val = torch.tensor(3, dtype=torch.int32, device=device)
+    iteri = torch.tensor(10, device=device)
+    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
+    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
+    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+  def test_while_loop_addition_nested(self):
+    device = xm.xla_device()
+
+    def cond_fn(iteri, x):
+      return iteri > 0
+
+    def body_fn(iteri, x):
+      return iteri - 1, torch.add(torch.add(x, 1), 1)
+
+    init_val = torch.tensor(2, dtype=torch.int32, device=device)
+    iteri = torch.tensor(10, device=device)
+    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
+    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
+    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+  def test_while_loop_simple_linear_inside_loop(self):
+    device = xm.xla_device()
+    torch.set_grad_enabled(False)
+
+    class SimpleLinear(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+        self.linear = torch.nn.Linear(2, 2)
+
+      def forward(self, iteri, x):
+
+        def cond_fn(iteri, x):
+          return iteri > 0
+
+        def body_fn(iteri, x):
+          return iteri - 1, self.linear(x)
+
+        return while_loop(cond_fn, body_fn, (iteri, x))
+
+      def forward_without_while_loop_op(self, iteri, x):
+        while (iteri > 0):
+          x = self.linear(x)
+          iteri -= 1
+        return iteri, x
+
+    linear_model = SimpleLinear()
+    linear_model.to(device)
+    l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
+    iteri = torch.tensor(10, dtype=torch.int32, device=device)
+    _, res_with_loop = linear_model(iteri, l_in_0)
+    _, res_without_loop = linear_model.forward_without_while_loop_op(
+        iteri, l_in_0)
+
+    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+  # ====== fori_loop ======
+  @unittest.skip("fori_loop is not supported yet due to unstable results.")
+  def test_fori_loop_addition(self):
+    device = xm.xla_device()
+
+    lower = torch.tensor(0, device=device)
+    upper = torch.tensor(50, device=device)
+    init_val = torch.tensor(1, dtype=torch.int32, device=device)
+
+    def body_fun(x):
+      return torch.add(x, 1)
+
+    _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val))
+
+    # === expected ===
+    for i in range(upper - lower):
+      init_val = torch.add(init_val, 1)
+    res_without_loop = init_val
+    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)
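The skipped `fori_loop` test above relies on the layering the deleted doc text described: for dynamic loop ranges, `fori_loop` is built on top of `while_loop`. Below is a minimal sketch of that idea, assuming a carried state of `(counter, bound, value)`; this illustrates the layering and is not the library's actual implementation.

```python
import torch
import torch_xla.experimental.fori_loop  # registers the XLA while_loop dispatch
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm

def fori_loop_sketch(lower, upper, body_fun, init_val):
  """Apply body_fun (upper - lower) times by carrying (counter, bound, value)."""

  def cond_fn(i, bound, x):
    return i < bound

  def body_fn(i, bound, x):
    # body_fun must preserve x's shape, per the while_loop constraint above.
    return i + 1, bound.clone(), body_fun(x)

  _, _, res = while_loop(cond_fn, body_fn, (lower, upper, init_val))
  return res

device = xm.xla_device()
res = fori_loop_sketch(
    torch.tensor(0, device=device),
    torch.tensor(50, device=device),
    lambda x: torch.add(x, 1),
    torch.tensor(1, dtype=torch.int32, device=device))  # expect tensor(51)
```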

test/tpu/run_tests.sh

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ python3 test/dynamo/test_dynamo.py
 python3 test/spmd/test_spmd_debugging.py
 python3 test/pjrt/test_dtypes.py
 python3 test/pjrt/test_dynamic_plugin_tpu.py
-python3 test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py
+python3 test/test_while_loop.py
 python3 test/test_pallas.py
 python3 test/test_pallas_spmd.py
 python3 test/test_input_output_aliases.py
