@@ -101,33 +101,43 @@ def __init__(self, device=None):
         super().__init__()
         self.self_tensor = torch.zeros((5, 3), device=device)

-      def forward(self, index, copy_tensor, input_tensor):
+      def copy_(self, index, copy_tensor):
         self.self_tensor.index_copy_(0, index, copy_tensor)
+
+      def add_(self, index, other_tensor):
+        self.self_tensor.add_(other_tensor)
+
+      def abs_(self, index, other_tensor):
+        self.self_tensor.abs_()
+
+      def forward(self, index, copy_tensor, input_tensor, op_name):
+        getattr(self, op_name)(index, copy_tensor)
         output = input_tensor + self.self_tensor
         return output

     torch._dynamo.reset()
     met.clear_counters()
     met.clear_all()
     device = xm.xla_device()
-    input_tensor = torch.ones(3)
-    copy_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-                               dtype=torch.float)
-    index = torch.tensor([0, 4, 2])
-    xla_input_tensor = input_tensor.to(device)
-    xla_copy_tensor = copy_tensor.to(device)
-    xla_index = index.to(device)

     cpu_model = TestModel()
-    res_cpu = cpu_model.forward(index, copy_tensor, input_tensor)
-
     xla_model = TestModel(device).to(device)
     compiled_model = torch.compile(xla_model, backend='openxla')
-    res_xla_dynamo = compiled_model.forward(xla_index, xla_copy_tensor,
-                                            xla_input_tensor)

-    self.assertIn('xla::index_copy', met.counter_names())
-    self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))
+    input_tensor = torch.ones(3)
+    copy_tensor = torch.rand(5, 3)
+    index = torch.tensor([0, 4, 2, 1, 3])
+    xla_input_tensor = input_tensor.to(device)
+    xla_copy_tensor = copy_tensor.to(device)
+    xla_index = index.to(device)
+
+    in_place_ops = ['copy_', 'add_', 'abs_']
+    for in_place_op in in_place_ops:
+      res_cpu = cpu_model.forward(
+          index, copy_tensor, input_tensor, op_name=in_place_op)
+      res_xla_dynamo = compiled_model.forward(
+          xla_index, xla_copy_tensor, xla_input_tensor, op_name=in_place_op)
+      self.assertTrue(torch.allclose(res_cpu, res_xla_dynamo.cpu()))

   def test_simple_model_with_different_input_shape(self):
     met.clear_counters()
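This hunk generalizes the in-place test: instead of hard-coding `index_copy_`, the model dispatches through `op_name` to one of several in-place ops (`copy_`, `add_`, `abs_`, all sharing the same signature so `getattr` dispatch needs no branching), and the loop checks CPU vs. dynamo/XLA parity for each. For readers outside the test harness, a minimal standalone sketch of the pattern being exercised; `Accumulator` is a hypothetical name for illustration, and it assumes a torch_xla build where the `openxla` dynamo backend is registered:

import torch
import torch_xla.core.xla_model as xm


# Hypothetical module (illustrative, not from this PR): same shape of
# problem as the test's TestModel -- an in-place mutation of a module
# attribute inside a compiled forward.
class Accumulator(torch.nn.Module):

  def __init__(self, device=None):
    super().__init__()
    self.state = torch.zeros((5, 3), device=device)

  def forward(self, x):
    self.state.add_(x)  # in-place update that dynamo must handle
    return self.state.sum()


device = xm.xla_device()
model = Accumulator(device).to(device)
compiled = torch.compile(model, backend='openxla')
print(compiled(torch.ones(5, 3, device=device)).cpu())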
@@ -245,22 +255,22 @@ def fn_fallback(t):
     cpu_res = fn_fallback(t)
     xla_dynamo_res = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 8)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)

     # Second tracing
     met.clear_counters()
     xla_dynamo_res_2 = dynamo_fn(t_xla)
     self.assertTrue(torch.allclose(cpu_res, xla_dynamo_res_2.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 4)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 10)
+    self.assertEqual(met.metric_data('CompileTime')[0], 3)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)

     # Verify that dynamo can handle different inputs
     xla_dynamo_res_3 = dynamo_fn(t_xla * 3)
     cpu_res_3 = fn_fallback(t * 3)
     self.assertTrue(torch.allclose(cpu_res_3, xla_dynamo_res_3.cpu()))
-    self.assertEqual(met.metric_data('CompileTime')[0], 5)
-    self.assertEqual(met.metric_data('ExecuteTime')[0], 12)
+    self.assertEqual(met.metric_data('CompileTime')[0], 4)
+    self.assertEqual(met.metric_data('ExecuteTime')[0], 15)


 class DynamoTrainingBasicTest(unittest.TestCase):
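This second hunk re-baselines the expected `CompileTime`/`ExecuteTime` sample counts in the fallback test to match the behavior after the in-place-op change. For anyone re-deriving such baselines locally, a minimal sketch of how these counters are read; it assumes the `torch_xla.debug.metrics` module this test file aliases as `met`, where `metric_data(name)[0]` is the number of recorded samples (i.e. how many times a compile or execute happened):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_all()
t = torch.randn(5, device=xm.xla_device())
print((t + 1).cpu())  # moving to CPU forces a graph compile + execute

# metric_data(name) returns a tuple whose first element is the sample
# count; the test compares it against expected compile/execute counts.
print(met.metric_data('CompileTime')[0])
print(met.metric_data('ExecuteTime')[0])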