Skip to content

Commit 2f8de95

Browse files
committed
Add 2 ops
1 parent 290426c commit 2f8de95

File tree

3 files changed

+53
-5
lines changed

3 files changed

+53
-5
lines changed

torchax/test/test_interop.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import functools
12
import torch
23
import unittest
34
from torchax import interop
5+
import torchax
46

57
class M1(torch.nn.Module):
68

@@ -23,8 +25,12 @@ def __init__(self):
2325
self.m1 = M1()
2426

2527

28+
2629
class InteropTest(unittest.TestCase):
2730

31+
def setUp(self):
32+
torchax.enable_globally()
33+
2834

2935
def test_mod_attr(self):
3036
m = M()
@@ -41,6 +47,49 @@ def test_mod_attr(self):
4147
self.assertEqual(m.a.weight.item(), 0)
4248
self.assertEqual(m.m1.x.item(), 0)
4349

50+
def test_module_with_shared_weights(self):
51+
52+
class M2(torch.nn.Module):
53+
54+
def __init__(self):
55+
super().__init__()
56+
self.a = torch.nn.Linear(10, 10)
57+
self.b = self.a
58+
59+
def forward(self, x):
60+
return self.a(self.b(x))
61+
62+
m = M2().to('jax')
63+
64+
m_jitted = interop.JittableModule(m, dedup_parameters=True)
65+
66+
67+
# a's weights and bias and b's weights and bias
68+
self.assertEqual(len(m.state_dict()), 4)
69+
70+
# b's weights and bias are deduped
71+
self.assertEqual(len(m_jitted.params), 2)
72+
73+
x = torch.randn(10, 10).to('jax')
74+
75+
expected = m(x)
76+
77+
torch.testing.assert_allclose(m_jitted(x), expected)
78+
79+
# make sure buffer donation works
80+
functional_forward = interop.jax_jit(
81+
functools.partial(m_jitted.functional_call, 'forward'),
82+
kwargs_for_jax_jit={
83+
'donate_argnums': (0, )
84+
}
85+
)
86+
87+
torch.testing.assert_allclose(
88+
functional_forward(m_jitted.params, m_jitted.buffers, x), expected)
89+
90+
91+
92+
4493

4594

4695
if __name__ == '__main__':

torchax/torchax/interop.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def set_one(module, prefix):
5050

5151
class JittableModule(torch.nn.Module):
5252

53-
def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_paramters=True):
53+
def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_parameters=True):
5454
super().__init__()
5555
self.params, self.buffers = extract_all_buffers(m)
5656
self._model = m
@@ -60,7 +60,7 @@ def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_paramters=True):
6060

6161
self._extra_dumped_weights = {}
6262

63-
if dedup_paramters:
63+
if dedup_parameters:
6464
temp = collections.defaultdict(list)
6565
for k, v in self.params.items():
6666
temp[id(v)].append(k)
@@ -72,9 +72,6 @@ def __init__(self, m: torch.nn.Module, extra_jit_args={}, dedup_paramters=True):
7272
for extra_keys in v[1:]:
7373
del self.params[extra_keys]
7474

75-
76-
77-
7875
def __call__(self, *args, **kwargs):
7976
return self.forward(*args, **kwargs)
8077

torchax/torchax/ops/jaten.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1590,12 +1590,14 @@ def _aten_bitwise_not(self):
15901590

15911591

15921592
# aten.bitwise_left_shift
1593+
@op(torch.ops.aten.__lshift__)
15931594
@op(torch.ops.aten.bitwise_left_shift)
15941595
def _aten_bitwise_left_shift(input, other):
15951596
return jnp.left_shift(input, other)
15961597

15971598

15981599
# aten.bitwise_right_shift
1600+
@op(torch.ops.aten.__rshift__)
15991601
@op(torch.ops.aten.bitwise_right_shift)
16001602
def _aten_bitwise_right_shift(input, other):
16011603
return jnp.right_shift(input, other)

0 commit comments

Comments
 (0)