1616import torch ._inductor
1717from torch ._inductor .fx_passes .post_grad import ConstructorMoverPass
1818
19+ from torch .utils import _pytree as pytree
20+
1921import torch_xla
2022import torch_xla .core .xla_model as xm
2123import torch_xla .debug .metrics as metrics
@@ -33,19 +35,31 @@ class GraphInputMatcher:
3335 Specifically, those graph inputs corresponding to method parameters should be replaced with the
3436 arguments for the current call.
3537
36- tensor_id_to_arg_idx maps the tensor id to the parameter index.
37- graph_input_tensor_ids, graph_input_xla_values list the tensor_id and ivalue for each of the
38- TS/XLA graph inputs.
38+ Args:
39+ tensor_id_to_arg_idx: Dict[int, int] - Maps the tensor id to the HLO parameter index.
40+ graph_input_tensor_ids: List[int] - tensor_id for each TS/XLA graph input.
41+ graph_input_xla_values: List[torch.Tensor] - ivalue for each TS/XLA graph input.
42+ Including both FX graph input tensors and weight tensors.
43+ xla_args_tensor_id: Set[int] - A set of tensor_ids for FX Graph inputs.
3944 """
4045
41- tensor_id_to_arg_idx : Dict [int , int ]
42- graph_input_tensor_ids : List [int ]
43- # there are 2 categories of graph_input_tensors.
44- # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
45- # most likely const tensors and we can get its content from graph_input_tensors
46- # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
47- # the tensor from method arguments
48- graph_input_xla_values : List [Any ]
46+ def __init__ (self , tensor_id_to_arg_idx : Dict [int , int ],
47+ graph_input_tensor_ids : List [int ],
48+ graph_input_xla_values : List [torch .tensor ],
49+ xla_args_tensor_id : Set [int ]):
50+ self .tensor_id_to_arg_idx = tensor_id_to_arg_idx
51+ self .graph_input_tensor_ids = graph_input_tensor_ids
52+ # there are 2 categories of graph_input_tensors.
53+ # Category 1: those whose id are not found in tensor_id_to_arg_idx. These are
54+ # most likely const tensors and we can get its content from graph_input_tensors
55+ # Category 2: those whose id are found in tensor_id_to_arg_idx. We should get
56+ # the tensor from method arguments.
57+ # For category 2, because user inputs will be used for each run, we do not
58+ # cache those tensors in GraphInputMatcher.
59+ self .graph_input_xla_values = [
60+ None if tensor_id in xla_args_tensor_id else xla_value for tensor_id ,
61+ xla_value in zip (graph_input_tensor_ids , graph_input_xla_values )
62+ ]
4963
5064 # get the real graph input tensors
5165 def __call__ (self , args ):
@@ -64,8 +78,10 @@ def __call__(self, args):
6478 xm .set_rng_state (
6579 (1012031 + inp .item () * 7012063 ) % 18446744073709551615 , str_device )
6680 elif arg_idx is None :
81+ assert traced_xla_value is not None , "Traced Tensor cannot be None."
6782 inp = traced_xla_value
6883 else :
84+ assert traced_xla_value is None , "Graph input tensor should not be cached."
6985 inp = args [arg_idx ]
7086 real_input .append (inp )
7187 return real_input
@@ -211,7 +227,13 @@ def is_xla_tensor(tensor: torch.Tensor) -> bool:
211227
212228
213229def extract_graph_helper (xla_model : torch .fx .GraphModule ):
230+ # FX Graph inputs passed from Dynamo. xla_args are XLA Tensors.
214231 xla_args = xla_model .xla_args
232+ xla_args_tensor_ids = set (
233+ pytree .tree_map_only (
234+ torch .Tensor ,
235+ lambda xla_arg : torch_xla ._XLAC ._xla_get_tensor_id (xla_arg ),
236+ xla_args ))
215237 assert all (
216238 map (
217239 is_xla_tensor ,
@@ -304,8 +326,8 @@ def extract_graph_helper(xla_model: torch.fx.GraphModule):
304326 ), f"{ len (graph_input_tensor_ids )} v.s. { len (graph_input_xla_values )} "
305327 graph_input_matcher = GraphInputMatcher (tensor_id_to_arg_idx ,
306328 graph_input_tensor_ids ,
307- graph_input_xla_values )
308-
329+ graph_input_xla_values ,
330+ xla_args_tensor_ids )
309331 # compiles and cache graph rooted at tensors in 'args_and_out'
310332 torch_xla ._XLAC ._xla_warm_up_cache (args_and_out , [])
311333
0 commit comments