116 changes: 37 additions & 79 deletions docs/fori_loop.md
@@ -1,114 +1,72 @@
# `fori_loop` and `while_loop`: optimizing memory utilization and compilation
`fori_loop` is a replacement for a pure Python `for` loop. PyTorch/XLA would enable `torch_xla.experimental.fori_loop` to keep the loop computation graph rolled during compilation,
like [`jax.lax.fori_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html), instead of the current behavior of repeating the computation by enumerating every execution step
of each iteration. `fori_loop` might improve memory utilization and might speed up compilation.

A user can call `fori_loop` like this:
```python
from torch_xla.experimental.fori_loop import fori_loop
res = fori_loop(upper, lower, /*user defined*/body_fun, init)
```

The current `fori_loop` only supports simple tests like [this one](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and users can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-fori_loop) with `fori_loop` on TPU.
<br>

For the detailed implementation:
- For the case where the loop range is dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`while_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#while_loop),
like [`jax.lax.while_loop`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html); PyTorch/XLA would support `while_loop` with
native PyTorch and the XLA backend operator `XLA::While`. Because `while_loop` does not support autograd, it would be used for inference only (a rough sketch of this lowering follows this list).

- For the case where the loop range is not dynamic, [`fori_loop`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#fori_loop) is implemented with [`scan`](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#wipscan),
like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html); PyTorch/XLA would enable `scan` using the `XLA::While` operator,
so this implementation would be very similar to `while_loop`. `scan` supports autograd, so it could be used in both training and inference.
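
As a rough illustration of the first bullet above, a `fori_loop`-style loop over a dynamic range can be built on top of `while_loop` by carrying the loop index as part of the state. This is only a hedged sketch of the idea, not the actual `torch_xla.experimental.fori_loop` implementation; the helper name `fori_loop_via_while` and its exact signature are made up for illustration, and it assumes `body_fun` returns a tensor with the same shape as its input.
```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # registers the XLA backend for while_loop
from torch._higher_order_ops.while_loop import while_loop


def fori_loop_via_while(lower, upper, body_fun, init):
  """Illustrative only: apply `body_fun` (upper - lower) times via while_loop."""

  def cond_fn(i, up, x):
    return i < up

  def body_fn(i, up, x):
    # The unchanged carried value is cloned (mirroring the original example)
    # so outputs do not alias inputs.
    return i + 1, up.clone(), body_fun(x)

  _, _, out = while_loop(cond_fn, body_fn, (lower, upper, init))
  return out


device = xm.xla_device()
lower = torch.tensor(0, device=device)
upper = torch.tensor(10, device=device)
init = torch.tensor(3, device=device)
res = fori_loop_via_while(lower, upper, lambda x: x + 1, init)  # expected: tensor(13)
```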

# while_loop
`while_loop` is a replacement for a pure Python `while` loop. PyTorch supports `while_loop` through
[torch._higher_order_ops.while_loop](https://github.com/pytorch/pytorch/blob/62311257adb902d6a4ea98809c88895af1dbbf2b/torch/_higher_order_ops/while_loop.py#L66),
and PyTorch/XLA provides experimental XLA backend support for `torch._higher_order_ops.while_loop` via `XLA::While`.

### Usage
A user can call `while_loop` like this:
```python
import torch_xla.experimental.fori_loop
from torch._higher_order_ops.while_loop import while_loop
result = while_loop(cond_fn, body_fn, init)
```
The current `while_loop` only supports simple tests like [this one](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py), and users can also try the [simple user guide](https://github.com/pytorch/xla/blob/ManfeiBai-patch-81/docs/fori_loop.md#simple-example-with-while_loop) with `while_loop` on TPU.

- `cond_fn`: User-defined condition function.
- `body_fn`: User-defined loop body function.
- `init`: Initial values (tuple or list).

# [WIP]scan
Like [`jax.lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html), PyTorch/XLA would enable `scan` for both training and inference, since `scan` supports autograd.
`scan` is still work in progress.
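
Since `scan` is still in progress, the snippet below is only a pure-Python reference for the intended `jax.lax.scan`-style semantics (a carried state plus stacked per-step outputs); it is an illustrative sketch, not a PyTorch/XLA API.
```python
import torch


def scan_reference(fn, init, xs):
  """Reference semantics: carry, y = fn(carry, x) for each x along the leading dim."""
  carry = init
  ys = []
  for x in xs:
    carry, y = fn(carry, x)
    ys.append(y)
  return carry, torch.stack(ys)


# Example: a cumulative sum expressed as a scan.
def step(carry, x):
  new_carry = carry + x
  return new_carry, new_carry


final, partials = scan_reference(step, torch.tensor(0.), torch.arange(5.))
# final == tensor(10.), partials == tensor([0., 1., 3., 6., 10.])
```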


# Simple user guide
Users can try the following simple test cases to compare the difference between a pure Python loop and `while_loop`. The test cases share similar logic: cumulatively adding 1 ten times:

### simple example with pure python for loop
```bash
# python
>>> import torch
>>> init = torch.tensor([0], dtype=torch.int32)
>>> one_value = torch.ones(1, dtype=torch.int32)
>>>
>>> for i in range(10):
...   init = init + one_value
...
>>> init
tensor([10], dtype=torch.int32)
```

### simple example with `while_loop`:
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> def cond_fn(iteri, x):
...   return iteri > 0
...
>>> def body_fn(iteri, x):
...   return iteri - 1, torch.add(x, 1)
...
>>> init_val = torch.tensor(3, device=device)
>>> iteri = torch.tensor(10, device=device)
>>> _, res = while_loop(cond_fn, body_fn, (iteri, init_val))
>>> res
FunctionalTensor(lvl=0, value=\
tensor(13, device='xla:0'))
```
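
Beyond scalar state, `while_loop` can also carry tensors through a small module, for example an `nn.Linear` layer. The following is a condensed, illustrative sketch of the pattern exercised in `test/test_while_loop.py` later in this change (inference only, since `while_loop` does not support autograd).
```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # registers the XLA backend for while_loop
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()
torch.set_grad_enabled(False)  # while_loop does not support autograd

linear = torch.nn.Linear(2, 2).to(device)


def cond_fn(iteri, x):
  return iteri > 0


def body_fn(iteri, x):
  # x keeps its (2, 2) shape, so body_fn's outputs match its inputs.
  return iteri - 1, linear(x)


iteri = torch.tensor(10, dtype=torch.int32, device=device)
x = torch.randn(2, 2, device=device)
_, out = while_loop(cond_fn, body_fn, (iteri, x))
```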

<br>

## Control group test case
To better show the difference between a pure Python `while` loop and `while_loop`, there is one control test case written with a pure Python `while` loop and similar logic: cumulatively adding 1:

### Control group example with pure python `while` loop
```bash
# PJRT_DEVICE=TPU python
>>> import torch
>>> import torch_xla
>>> import torch_xla.experimental.fori_loop
>>> from torch_xla.experimental.fori_loop import fori_loop
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>> import torch_xla.core.xla_builder as xb
>>>
>>> device = xm.xla_device()
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
>>>
>>> while iteri > 0:
...   init_val = init_val + 1
...   iteri -= 1
...
>>> init_val
tensor(51, device='xla:0')
```

For more examples and a detailed user guide, please read the test file `test/test_while_loop.py`. PyTorch/XLA would include `while_loop` support in 2.4 with simple test cases; support for `fori_loop` and `scan` would be added after 2.4. For `while_loop`, `body_fn` currently must be defined so that its inputs and its outputs (return values) have the same shapes; a short sketch of this constraint follows.
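
The hedged sketch below illustrates that constraint (it assumes an XLA device such as a TPU is available): every operand passed in must come back out of `body_fn` with the same shape, so a value that does not change shape is still returned rather than dropped or resized.
```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.fori_loop  # registers the XLA backend for while_loop
from torch._higher_order_ops.while_loop import while_loop

device = xm.xla_device()


def cond_fn(iteri, x):
  return iteri > 0


# OK: (iteri, x) in -> (iteri, x) out, with identical shapes on both sides.
def body_fn(iteri, x):
  return iteri - 1, x + 1


# Not OK (illustrative): returning torch.cat([x, x]) instead of x + 1 would
# change the shape of the second operand and violate the constraint.

iteri = torch.tensor(10, device=device)
x = torch.zeros(2, 2, device=device)
_, res = while_loop(cond_fn, body_fn, (iteri, x))  # res: a 2x2 tensor of 10s
```
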
2 changes: 1 addition & 1 deletion test/run_tests.sh
@@ -203,7 +203,7 @@ function run_xla_op_tests1 {
function run_xla_op_tests2 {
run_downcast_bf16 "$CDIR/test_data_type.py"
run_test "$CDIR/pjrt/test_dtypes.py"
run_test "$CDIR/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py"
run_test "$CDIR/test_while_loop.py"
run_test "$CDIR/test_autocast.py" # TODO(yeounoh) this is expensive on GPU
}

4 changes: 2 additions & 2 deletions test/spmd/test_sharding_strategies.py
@@ -67,9 +67,9 @@

num_devices = xr.global_runtime_device_count()

assert np.prod(dcn_parallelism) * np.prod(
    ici_parallelism) == num_devices, f"Number of devices {num_devices} \
does not match the product of the parallelism {np.prod(dcn_parallelism) * np.prod(ici_parallelism)}"

# Use HybridMesh to optimize multislice topology
mesh = xs.HybridMesh(
106 changes: 0 additions & 106 deletions test/test_fori_loop_with_while_loop_simple_add_dispatch_in_torch.py

This file was deleted.

116 changes: 116 additions & 0 deletions test/test_while_loop.py
@@ -0,0 +1,116 @@
import os
import sys
import unittest
from typing import Callable, Dict, List

import torch
import torch_xla
# We need to import the underlying implementation function to register with the dispatcher
import torch_xla.experimental.fori_loop
from torch_xla.experimental.fori_loop import fori_loop
from torch._higher_order_ops.while_loop import while_loop
import torch_xla.core.xla_model as xm
import torch_xla.core.xla_builder as xb
import torch_xla.utils.utils as xu
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


def _fake_while_loop(cond_fn, body_fn, operands):
  # operands need to be more than one here
  while cond_fn(*operands):
    operands = body_fn(*operands)
  return operands


class WhileLoopTest(unittest.TestCase):

  def test_while_loop_addition(self):
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(x, 1)

    init_val = torch.tensor(3, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  def test_while_loop_addition_nested(self):
    device = xm.xla_device()

    def cond_fn(iteri, x):
      return iteri > 0

    def body_fn(iteri, x):
      return iteri - 1, torch.add(torch.add(x, 1), 1)

    init_val = torch.tensor(2, dtype=torch.int32, device=device)
    iteri = torch.tensor(10, device=device)
    _, res_with_loop = while_loop(cond_fn, body_fn, (iteri, init_val))
    _, res_without_loop = _fake_while_loop(cond_fn, body_fn, (iteri, init_val))
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  def test_while_loop_simple_linear_inside_loop(self):
    device = xm.xla_device()
    torch.set_grad_enabled(False)

    class SimpleLinear(torch.nn.Module):

      def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

      def forward(self, iteri, x):

        def cond_fn(iteri, x):
          return iteri > 0

        def body_fn(iteri, x):
          return iteri - 1, self.linear(x)

        return while_loop(cond_fn, body_fn, (iteri, x))

      def forward_without_while_loop_op(self, iteri, x):
        while (iteri > 0):
          x = self.linear(x)
          iteri -= 1
        return iteri, x

    linear_model = SimpleLinear()
    linear_model.to(device)
    l_in_0 = torch.randn(2, 2, dtype=torch.float32, device=device)
    iteri = torch.tensor(10, dtype=torch.int32, device=device)
    _, res_with_loop = linear_model(iteri, l_in_0)
    _, res_without_loop = linear_model.forward_without_while_loop_op(
        iteri, l_in_0)

    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))

  # ====== fori_loop ======
  @unittest.skip("Fori_loop is not supported now due to unstable result.")
  def test_fori_loop_addition(self):
    device = xm.xla_device()

    lower = torch.tensor(0, device=device)
    upper = torch.tensor(50, device=device)
    init_val = torch.tensor(1, dtype=torch.int32, device=device)

    def body_fun(x):
      return torch.add(x, 1)

    _, res_with_loop = fori_loop(lower, upper, body_fun, (init_val))

    # === expected ===
    for i in range(upper - lower):
      init_val = torch.add(init_val, 1)
    res_without_loop = init_val
    self.assertTrue(torch.all(torch.eq(res_with_loop, res_without_loop)))


if __name__ == '__main__':
  test = unittest.main()
  sys.exit(0 if test.result.wasSuccessful() else 1)