@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu
 
@@ -344,6 +344,36 @@ def test_tgmm_bf16(self):
     # Make sure tgmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())
 
+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      ref_out_backward = torch.ones_like(ref_out)
+      grad_lhs, grad_rhs = gmm_backward(
+          ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
+          group_sizes.to("xla"))
+
+      self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
+
+    # Make sure gmm doesn't fallback.
+    self.assertNotIn("aten::", met.short_metrics_report())
+
 
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
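
For reference, the gradients this test checks reduce to per-group matmuls: with `lhs` of shape `(m, k)`, `rhs` of shape `(num_groups, k, n)`, and `group_sizes` partitioning the rows of `lhs`, the gradient for a group's rows is `grad_out @ rhs[g].T` and the gradient for `rhs[g]` is `lhs_g.T @ grad_out_g`. Below is a minimal CPU-only sketch of that check; the helper name `reference_gmm_backward` is illustrative, not a torch_xla API.

```python
import torch

def reference_gmm_backward(grad_out, lhs, rhs, group_sizes):
  # Per-group matmul gradients: rows of lhs are partitioned by group_sizes,
  # and group g contributes lhs[rows_g] @ rhs[g] to the output rows.
  grad_lhs = torch.zeros_like(lhs)
  grad_rhs = torch.zeros_like(rhs)
  start = 0
  for g, size in enumerate(group_sizes.tolist()):
    end = start + size
    grad_lhs[start:end] = grad_out[start:end] @ rhs[g].transpose(0, 1)
    grad_rhs[g] = lhs[start:end].transpose(0, 1) @ grad_out[start:end]
    start = end
  return grad_lhs, grad_rhs

# Tiny example: 2 groups, lhs is (m, k), rhs is (num_groups, k, n).
m, k, n = 8, 4, 5
lhs = torch.randn(m, k, requires_grad=True)
rhs = torch.randn(2, k, n, requires_grad=True)
group_sizes = torch.tensor([3, 5])

# Forward pass of the grouped matmul, written out per group so autograd
# produces the expected gradients (sum() makes the incoming grad all ones).
out = torch.cat([lhs[0:3] @ rhs[0], lhs[3:8] @ rhs[1]])
out.sum().backward()

grad_lhs, grad_rhs = reference_gmm_backward(
    torch.ones_like(out), lhs.detach(), rhs.detach(), group_sizes)
assert torch.allclose(lhs.grad, grad_lhs)
assert torch.allclose(rhs.grad, grad_rhs)
```

This mirrors the test above, where autograd through `_reference_gmm` supplies `lhs.grad`/`rhs.grad` and `gmm_backward` on the XLA device is expected to match them.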