Commit c96c95a

[Pallas] Introduce gmm_backward (#7151)
Summary: This pull request introduces a helper for gmm_backward. I'm still debating whether we need to make gmm an autograd.Function, given that we will do manual back-propagation in Mixtral.

Test Plan: python test/test_gmm.py
1 parent ce1205e commit c96c95a
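As context for the autograd.Function question raised in the summary, here is a minimal sketch of how gmm and gmm_backward could be wired together through torch.autograd.Function; the GmmFunction name, the reliance on default tiling, and the argument handling are illustrative assumptions, not part of this commit:

import torch
from torch_xla.experimental.custom_kernel import gmm, gmm_backward


class GmmFunction(torch.autograd.Function):
  # Hypothetical wrapper (not part of this commit); it assumes the
  # gmm/gmm_backward signatures shown in the diffs below.

  @staticmethod
  def forward(ctx, lhs, rhs, group_sizes):
    ctx.save_for_backward(lhs, rhs, group_sizes)
    return gmm(lhs, rhs, group_sizes)

  @staticmethod
  def backward(ctx, grad_output):
    lhs, rhs, group_sizes = ctx.saved_tensors
    grad_lhs, grad_rhs = gmm_backward(grad_output, lhs, rhs, group_sizes)
    # group_sizes is an integer tensor and receives no gradient.
    return grad_lhs, grad_rhs, None

A caller would then write out = GmmFunction.apply(lhs, rhs, group_sizes) and rely on autograd, instead of the manual back-propagation mentioned for Mixtral.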

2 files changed (+40, -5 lines)

test/test_gmm.py

Lines changed: 31 additions & 1 deletion
@@ -7,7 +7,7 @@
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm
+from torch_xla.experimental.custom_kernel import gmm, _make_group_metadata, _histogram, tgmm, gmm_backward
 from torch_xla import runtime as xr
 from torch_xla._internal import tpu

@@ -344,6 +344,36 @@ def test_tgmm_bf16(self):
     # Make sure tgmm doesn't fallback.
     self.assertNotIn("aten::", met.short_metrics_report())

+  @unittest.skipIf(xr.device_type() != 'TPU', "This test only works on TPU.")
+  def test_gmm_backward(self):
+    self._init_test_cases()
+    for test_case in self.tests_cases:
+      num_groups = test_case['num_groups']
+      k = test_case['k']
+      m = test_case['m']
+      n = test_case['n']
+      lhs_dtype = rhs_dtype = torch.bfloat16
+
+      lhs = torch.rand(m, k, dtype=lhs_dtype, requires_grad=True)
+      rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype, requires_grad=True)
+      group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
+      lhs.retain_grad()
+      rhs.retain_grad()
+
+      ref_out = self._reference_gmm(lhs, rhs, group_sizes)
+      ref_out.sum().backward()
+
+      ref_out_backward = torch.ones_like(ref_out)
+      grad_lhs, grad_rhs = gmm_backward(
+          ref_out_backward.to("xla"), lhs.to("xla"), rhs.to("xla"),
+          group_sizes.to("xla"))
+
+      self.assertTrue(torch.allclose(lhs.grad, grad_lhs.cpu()))
+      self.assertTrue(torch.allclose(rhs.grad, grad_rhs.cpu()))
+
+      # Make sure gmm doesn't fallback.
+      self.assertNotIn("aten::", met.short_metrics_report())
+

 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)

torch_xla/experimental/custom_kernel.py

Lines changed: 9 additions & 4 deletions
@@ -771,7 +771,7 @@ def tgmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: tuple[int, int, int] = (128, 128, 128)
+    tiling: tuple[int, int, int] = (512, 512, 512)
 ) -> torch.Tensor:
   """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :].

@@ -813,13 +813,18 @@ def tgmm(
   )
   group_offset_torch = torch.tensor([0], dtype=torch.int32).to(lhs.device)

-  lhs = lhs.swapaxes(0, 1)
   return torch_xla._XLAC._xla_tpu_custom_call([
-      num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch, lhs,
-      rhs
+      num_tiles, group_offsets, group_ids, m_tile_ids, group_offset_torch,
+      lhs.t(), rhs
   ], payload, [torch.Size([num_groups, k, n])], [preferred_element_type])[0]


+def gmm_backward(grad, lhs, rhs, group_sizes, tiling=(512, 512, 512)):
+  grad_lhs = gmm(grad, rhs.transpose(-1, -2), group_sizes, tiling)
+  grad_rhs = tgmm(lhs.t(), grad, group_sizes, tiling)
+  return grad_lhs, grad_rhs
+
+
 def non_xla_attetion(q, k, v, attention_type):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.

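For readers checking the math: gmm_backward above composes the standard per-group matmul gradients, where for each group g the rows of grad_lhs belonging to that group are grad @ rhs[g].T and grad_rhs[g] is lhs[g].T @ grad. A small CPU reference in the same spirit as the test's _reference_gmm; the helper name below is illustrative and not code from this commit:

import torch


def reference_gmm_backward(grad, lhs, rhs, group_sizes):
  # lhs: (m, k), rhs: (num_groups, k, n), grad: (m, n), matching gmm's output.
  grad_lhs = torch.empty_like(lhs)
  grad_rhs = torch.empty_like(rhs)
  start = 0
  for g, size in enumerate(group_sizes.tolist()):
    end = start + size
    # d/d(lhs[g]) of lhs[g] @ rhs[g] is grad[g] @ rhs[g].T
    grad_lhs[start:end] = grad[start:end] @ rhs[g].transpose(-1, -2)
    # d/d(rhs[g]) of lhs[g] @ rhs[g] is lhs[g].T @ grad[g]
    grad_rhs[g] = lhs[start:end].t() @ grad[start:end]
    start = end
  return grad_lhs, grad_rhs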