Commit 6aeab30

Do not cache input args in dynamo bridge (#6553)
1 parent 19362a9 commit 6aeab30
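
In brief: GraphInputMatcher previously kept a reference to every traced XLA value, including the FX graph inputs that Dynamo passes in fresh on every call. After this change, entries whose tensor id belongs to an FX graph input are stored as None, so only constants and weights stay cached. A minimal standalone sketch of the masking idea (hypothetical ids and plain Python stand-ins for XLA tensors, not the real torch_xla API):

# Minimal sketch of the masking idea, with made-up tensor ids/values.
graph_input_tensor_ids = [101, 102, 103]  # e.g. weight, bias, model input
graph_input_xla_values = ['weight', 'bias', 'input']
xla_args_tensor_ids = {103}  # tensor ids of the FX graph (user) inputs

# Entries for user inputs become None so they are never served from the cache.
cached_values = [
    None if tensor_id in xla_args_tensor_ids else value
    for tensor_id, value in zip(graph_input_tensor_ids, graph_input_xla_values)
]
assert cached_values == ['weight', 'bias', None]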

File tree

3 files changed: +101 -15 lines changed
Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch import nn
+from torch.utils._pytree import tree_map_only
+from torch_xla.core.dynamo_bridge import GraphInputMatcher
+
+
+class M(nn.Module):
+
+  def __init__(self):
+    super().__init__()
+    self.linear = nn.Linear(5, 3)
+
+  def forward(self, x):
+    return self.linear(x)
+
+  def get_example_inputs(self):
+    return (torch.rand(10, 5),)
+
+
+class TestGraphInputMatcher(unittest.TestCase):
+
+  def test_no_cache_fx_graph_inputs(self):
+    xla_dev = xm.xla_device()
+    model = M().to(device=xla_dev)
+    inputs = tree_map_only(torch.Tensor, lambda x: x.to(device=xla_dev),
+                           model.get_example_inputs())
+
+    xm.mark_step()
+    args_tensor_ids = [
+        torch_xla._XLAC._xla_get_tensor_id(xla_arg) for xla_arg in inputs
+    ]
+    tensor_id_to_arg_idx = {
+        tensor_id: i for i, tensor_id in enumerate(args_tensor_ids)
+    }
+    output = model(*inputs)
+    xla_graph_hash = torch_xla._XLAC._get_graph_hash([output])
+    (
+        graph_input_tensor_ids,
+        graph_input_xla_values,
+    ) = torch_xla._XLAC._get_tensors_xla_device_data_node([output])
+    xla_args_tensor_ids = set(
+        tree_map_only(torch.Tensor,
+                      lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
+                      inputs))
+    graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
+                                            graph_input_tensor_ids,
+                                            graph_input_xla_values,
+                                            xla_args_tensor_ids)
+    # The weight and bias are cached in GraphInputMatcher;
+    # the model input will not be cached.
+    self.assertEqual(graph_input_matcher.graph_input_xla_values.count(None), 1)
+
+
+if __name__ == '__main__':
+  test = unittest.main()
+  sys.exit(0 if test.result.wasSuccessful() else 1)
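
For the Linear(5, 3) module above, the traced graph has three device-data inputs: the weight, the bias, and the activation x. Only x's tensor id appears in xla_args_tensor_ids, so exactly one cached entry is masked to None, which is what the final assertion checks.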

test/dynamo/test_num_output.py

Lines changed: 6 additions & 2 deletions
@@ -78,10 +78,14 @@ def do_test(self, model_class, expected_num_output):
         graph_input_tensor_ids,
         graph_input_xla_values,
     ) = torch_xla._XLAC._get_tensors_xla_device_data_node(outputs)
-
+    xla_args_tensor_ids = set(
+        tree_map_only(torch.Tensor,
+                      lambda input: torch_xla._XLAC._xla_get_tensor_id(input),
+                      inputs))
     graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                             graph_input_tensor_ids,
-                                            graph_input_xla_values)
+                                            graph_input_xla_values,
+                                            xla_args_tensor_ids)
     torch_xla._XLAC._xla_warm_up_cache(outputs, [])

     def run_cached_graph(*inputs):
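
This existing test only needed to adapt to GraphInputMatcher's new fourth constructor argument: it derives xla_args_tensor_ids from its inputs with the same tree_map_only pattern used elsewhere in this commit.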

torch_xla/core/dynamo_bridge.py

Lines changed: 35 additions & 13 deletions
@@ -16,6 +16,8 @@
 import torch._inductor
 from torch._inductor.fx_passes.post_grad import ConstructorMoverPass

+from torch.utils import _pytree as pytree
+
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as metrics
@@ -33,19 +35,31 @@ class GraphInputMatcher:
   Specifically, those graph inputs corresponding to method parameters should be replaced with the
   arguments for the current call.

-  tensor_id_to_arg_idx maps the tensor id to the parameter index.
-  graph_input_tensor_ids, graph_input_xla_values list the tensor_id and ivalue for each of the
-  TS/XLA graph inputs.
+  Args:
+    tensor_id_to_arg_idx: Dict[int, int] - Maps the tensor id to the HLO parameter index.
+    graph_input_tensor_ids: List[int] - tensor_id for each TS/XLA graph input.
+    graph_input_xla_values: List[torch.Tensor] - ivalue for each TS/XLA graph input,
+      including both FX graph input tensors and weight tensors.
+    xla_args_tensor_id: Set[int] - A set of tensor_ids for FX graph inputs.
   """

-  tensor_id_to_arg_idx: Dict[int, int]
-  graph_input_tensor_ids: List[int]
-  # there are 2 categories of graph_input_tensors.
-  # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
-  # most likely const tensors and we can get its content from graph_input_tensors
-  # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
-  # the tensor from method arguments
-  graph_input_xla_values: List[Any]
+  def __init__(self, tensor_id_to_arg_idx: Dict[int, int],
+               graph_input_tensor_ids: List[int],
+               graph_input_xla_values: List[torch.Tensor],
+               xla_args_tensor_id: Set[int]):
+    self.tensor_id_to_arg_idx = tensor_id_to_arg_idx
+    self.graph_input_tensor_ids = graph_input_tensor_ids
+    # There are 2 categories of graph_input_tensors.
+    # Category 1: those whose ids are not found in tensor_id_to_arg_idx. These are
+    # most likely const tensors, and we can get their content from graph_input_tensors.
+    # Category 2: those whose ids are found in tensor_id_to_arg_idx. We should get
+    # the tensor from the method arguments.
+    # For category 2, because user inputs will be passed in on each run, we do not
+    # cache those tensors in GraphInputMatcher.
+    self.graph_input_xla_values = [
+        None if tensor_id in xla_args_tensor_id else xla_value for tensor_id,
+        xla_value in zip(graph_input_tensor_ids, graph_input_xla_values)
+    ]

   # get the real graph input tensors
   def __call__(self, args):
@@ -64,8 +78,10 @@ def __call__(self, args):
         xm.set_rng_state(
             (1012031 + inp.item() * 7012063) % 18446744073709551615, str_device)
       elif arg_idx is None:
+        assert traced_xla_value is not None, "Traced Tensor cannot be None."
         inp = traced_xla_value
       else:
+        assert traced_xla_value is None, "Graph input tensor should not be cached."
         inp = args[arg_idx]
       real_input.append(inp)
     return real_input
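
Taken together, __call__ now enforces the invariant set up in __init__. A simplified sketch of the rebuild logic, written as a free function (it omits the RNG-seed special case shown above and paraphrases the method body, so it is not the verbatim source):

# Simplified sketch of GraphInputMatcher.__call__'s dispatch (no RNG case).
def rebuild_real_inputs(matcher, args):
  real_input = []
  for tensor_id, traced_xla_value in zip(matcher.graph_input_tensor_ids,
                                         matcher.graph_input_xla_values):
    arg_idx = matcher.tensor_id_to_arg_idx.get(tensor_id, None)
    if arg_idx is None:
      # Category 1: constant/weight tensor, served from the cache.
      inp = traced_xla_value
    else:
      # Category 2: user input, always taken fresh from the current call.
      inp = args[arg_idx]
    real_input.append(inp)
  return real_input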
@@ -211,7 +227,13 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:


 def extract_graph_helper(xla_model: torch.fx.GraphModule):
+  # FX Graph inputs passed from Dynamo. xla_args are XLA Tensors.
   xla_args = xla_model.xla_args
+  xla_args_tensor_ids = set(
+      pytree.tree_map_only(
+          torch.Tensor,
+          lambda xla_arg: torch_xla._XLAC._xla_get_tensor_id(xla_arg),
+          xla_args))
   assert all(
       map(
           is_xla_tensor,
@@ -304,8 +326,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
   ), f"{len(graph_input_tensor_ids)} v.s. {len(graph_input_xla_values)}"
   graph_input_matcher = GraphInputMatcher(tensor_id_to_arg_idx,
                                           graph_input_tensor_ids,
-                                          graph_input_xla_values)
-
+                                          graph_input_xla_values,
+                                          xla_args_tensor_ids)
   # compiles and caches graph rooted at tensors in 'args_and_out'
   torch_xla._XLAC._xla_warm_up_cache(args_and_out, [])

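The net effect: extract_graph_helper computes the tensor ids of the Dynamo-supplied xla_args once at trace time and hands them to GraphInputMatcher, which then refuses to cache those entries. This keeps the bridge from holding user input tensors alive across calls and from replaying a stale input value, and the two new assertions in __call__ turn any violation of that invariant into an immediate failure.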