7 | 7 | import torch_xla.utils.utils as xu |
8 | 8 | import torch_xla.debug.metrics as met |
9 | 9 | from torch_xla import runtime as xr |
| 10 | +import torch_xla.debug.profiler as xp |
10 | 11 | import torch.optim as optim |
11 | 12 | import torch.nn as nn |
12 | 13 | import torch._dynamo as dynamo |
@@ -61,6 +62,50 @@ def test_random_op_different_result_each_run(self): |
61 | 62 |     self.assertFalse(torch.allclose(dynamo_res_2, dynamo_res_3))
62 | 63 |
63 | 64 |
| 65 | +class DynamoLTCInteractionTest(unittest.TestCase):
| 66 | +
| 67 | +  def index_copy_inplace(self, cache, update_indices, xk):
| 68 | +    cache.index_copy_(0, update_indices, xk)
| 69 | +
| 70 | +  def test_mark_step_after_dynamo(self):
| 71 | +    cache_len = 512
| 72 | +    kv_heads = 8
| 73 | +    head_dim = 128
| 74 | +    running = 16
| 75 | +
| 76 | +    device = xm.xla_device()
| 77 | +    cache = torch.rand((cache_len, kv_heads, head_dim)).to(device)
| 78 | +    update_indices = torch.randint(
| 79 | +        0, cache_len, (running,), dtype=torch.long).to(device)
| 80 | +    xk = torch.rand((running, kv_heads, head_dim)).to(device)
| 81 | +
| 82 | +    dynamo_index_copy_inplace = torch.compile(
| 83 | +        self.index_copy_inplace, backend="openxla", fullgraph=True)
| 84 | +    met.clear_all()
| 85 | +    for i in range(10):
| 86 | +      dynamo_index_copy_inplace(cache, update_indices, xk)
| 87 | +      xm.wait_device_ops()
| 88 | +      current_execute_time = met.metric_data('ExecuteTime')[0]
| 89 | +      # This mark_step should be a no-op and should not trigger additional execution.
| 90 | +      xm.mark_step()
| 91 | +      xm.wait_device_ops()
| 92 | +      self.assertEqual(current_execute_time, met.metric_data('ExecuteTime')[0])
| 93 | +
| 94 | +
| 95 | +class DynamoProfilerTest(unittest.TestCase):
| 96 | +
| 97 | +  def dummy_fn(self, a):
| 98 | +    return torch.sin(a) + a
| 99 | +
| 100 | +  def test_dynamo_with_trace(self):
| 101 | +    dynamo_dummy = torch.compile(
| 102 | +        self.dummy_fn, backend="openxla", fullgraph=True)
| 103 | +    t = torch.randn(2, 3, 4, device=xm.xla_device())
| 104 | +    for i in range(10):
| 105 | +      with xp.Trace('build_graph'):
| 106 | +        t = dynamo_dummy(t)
| 107 | +
| 108 | +
64 | 109 | class DynamoInferenceBasicTest(unittest.TestCase): |
65 | 110 |
66 | 111 |   @classmethod
|
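Note (not part of the diff): the xp.Trace('build_graph') annotation in test_dynamo_with_trace only becomes visible when a profiler capture is running. Below is a minimal sketch of how such a span might be captured with the torch_xla.debug.profiler API; the port 9012, the log directory, and the dummy_fn name are arbitrary illustrative choices, not anything this PR defines.

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

# Start the profiler server in this process; 9012 is an arbitrary free port.
server = xp.start_server(9012)

def dummy_fn(a):
  return torch.sin(a) + a

compiled = torch.compile(dummy_fn, backend="openxla", fullgraph=True)
t = torch.randn(2, 3, 4, device=xm.xla_device())

# From a separate process or thread, request a short capture; the named
# xp.Trace spans then appear in the TensorBoard logdir, e.g.:
# xp.trace('localhost:9012', '/tmp/xla_profile', duration_ms=2000)

for _ in range(10):
  with xp.Trace('build_graph'):  # this span name shows up in the captured trace
    t = compiled(t)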