Skip to content

Commit 811eeef

Browse files
committed
comments
1 parent bc8319e commit 811eeef

File tree

1 file changed

+35
-30
lines changed

1 file changed

+35
-30
lines changed

torchax/test/test_interop.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,6 @@
44
from torchax import interop
55
import torchax
66

7-
class M1(torch.nn.Module):
8-
9-
def __init__(self):
10-
super().__init__()
11-
self.x = torch.ones(10, 10)
12-
13-
class M(torch.nn.Module):
14-
15-
def __init__(self):
16-
super().__init__()
17-
self.a = torch.nn.Linear(100, 100)
18-
self.b = torch.nn.Parameter(
19-
torch.ones(10, 10)
20-
)
21-
c = torch.ones(10, 10)
22-
self.register_buffer('c', c)
23-
self.register_buffer('c2', c, persistent=False)
24-
self.d = torch.ones(10, 10)
25-
self.m1 = M1()
26-
27-
287

298
class InteropTest(unittest.TestCase):
309

@@ -33,7 +12,28 @@ def setUp(self):
3312

3413

3514
def test_mod_attr(self):
36-
m = M()
15+
16+
class Child(torch.nn.Module):
17+
18+
def __init__(self):
19+
super().__init__()
20+
self.x = torch.ones(10, 10)
21+
22+
class ModWithUnregisteredTensor(torch.nn.Module):
23+
24+
def __init__(self):
25+
super().__init__()
26+
self.a = torch.nn.Linear(100, 100)
27+
self.b = torch.nn.Parameter(
28+
torch.ones(10, 10)
29+
)
30+
c = torch.ones(10, 10)
31+
self.register_buffer('c', c)
32+
self.register_buffer('c2', c, persistent=False)
33+
self.d = torch.ones(10, 10)
34+
self.m1 = Child()
35+
36+
m = ModWithUnregisteredTensor()
3737
params, buffers = interop.extract_all_buffers(m)
3838
self.assertEqual(
3939
set(params.keys()), {'a.weight', 'a.bias', 'b'}
@@ -49,7 +49,8 @@ def test_mod_attr(self):
4949

5050
def test_module_with_shared_weights(self):
5151

52-
class M2(torch.nn.Module):
52+
# arrange
53+
class Module(torch.nn.Module):
5354

5455
def __init__(self):
5556
super().__init__()
@@ -59,23 +60,25 @@ def __init__(self):
5960
def forward(self, x):
6061
return self.a(self.b(x))
6162

62-
m = M2().to('jax')
63+
m = Module().to('jax')
6364

6465
m_jitted = interop.JittableModule(m, dedup_parameters=True)
6566

66-
6767
# a's weights and bias and b's weights and bias
6868
self.assertEqual(len(m.state_dict()), 4)
6969

7070
# b's weights and bias are deduped
7171
self.assertEqual(len(m_jitted.params), 2)
72-
7372
x = torch.randn(10, 10).to('jax')
74-
7573
expected = m(x)
74+
75+
# act
76+
actual = m_jitted(x)
7677

77-
torch.testing.assert_allclose(m_jitted(x), expected)
78+
# assert
79+
torch.testing.assert_allclose(actual, expected)
7880

81+
# arrange
7982
# make sure buffer donation works
8083
functional_forward = interop.jax_jit(
8184
functools.partial(m_jitted.functional_call, 'forward'),
@@ -84,8 +87,10 @@ def forward(self, x):
8487
}
8588
)
8689

87-
torch.testing.assert_allclose(
88-
functional_forward(m_jitted.params, m_jitted.buffers, x) , expected)
90+
# act
91+
actual = functional_forward(m_jitted.params, m_jitted.buffers, x)
92+
# assert
93+
torch.testing.assert_allclose(actual, expected)
8994

9095

9196

0 commit comments

Comments
 (0)