44from torchax import interop
55import torchax
66
7- class M1 (torch .nn .Module ):
8-
9- def __init__ (self ):
10- super ().__init__ ()
11- self .x = torch .ones (10 , 10 )
12-
13- class M (torch .nn .Module ):
14-
15- def __init__ (self ):
16- super ().__init__ ()
17- self .a = torch .nn .Linear (100 , 100 )
18- self .b = torch .nn .Parameter (
19- torch .ones (10 , 10 )
20- )
21- c = torch .ones (10 , 10 )
22- self .register_buffer ('c' , c )
23- self .register_buffer ('c2' , c , persistent = False )
24- self .d = torch .ones (10 , 10 )
25- self .m1 = M1 ()
26-
27-
287
298class InteropTest (unittest .TestCase ):
309
@@ -33,7 +12,28 @@ def setUp(self):
3312
3413
3514 def test_mod_attr (self ):
36- m = M ()
15+
16+ class Child (torch .nn .Module ):
17+
18+ def __init__ (self ):
19+ super ().__init__ ()
20+ self .x = torch .ones (10 , 10 )
21+
22+ class ModWithUnregisteredTensor (torch .nn .Module ):
23+
24+ def __init__ (self ):
25+ super ().__init__ ()
26+ self .a = torch .nn .Linear (100 , 100 )
27+ self .b = torch .nn .Parameter (
28+ torch .ones (10 , 10 )
29+ )
30+ c = torch .ones (10 , 10 )
31+ self .register_buffer ('c' , c )
32+ self .register_buffer ('c2' , c , persistent = False )
33+ self .d = torch .ones (10 , 10 )
34+ self .m1 = Child ()
35+
36+ m = ModWithUnregisteredTensor ()
3737 params , buffers = interop .extract_all_buffers (m )
3838 self .assertEqual (
3939 set (params .keys ()), {'a.weight' , 'a.bias' , 'b' }
@@ -49,7 +49,8 @@ def test_mod_attr(self):
4949
5050 def test_module_with_shared_weights (self ):
5151
52- class M2 (torch .nn .Module ):
52+ # arrange
53+ class Module (torch .nn .Module ):
5354
5455 def __init__ (self ):
5556 super ().__init__ ()
@@ -59,23 +60,25 @@ def __init__(self):
5960 def forward (self , x ):
6061 return self .a (self .b (x ))
6162
62- m = M2 ().to ('jax' )
63+ m = Module ().to ('jax' )
6364
6465 m_jitted = interop .JittableModule (m , dedup_parameters = True )
6566
66-
6767 # a's weights and bias and b's weights and bias
6868 self .assertEqual (len (m .state_dict ()), 4 )
6969
7070 # b's weights and bias are deduped
7171 self .assertEqual (len (m_jitted .params ), 2 )
72-
7372 x = torch .randn (10 , 10 ).to ('jax' )
74-
7573 expected = m (x )
74+
75+ # act
76+ actual = m_jitted (x )
7677
77- torch .testing .assert_allclose (m_jitted (x ), expected )
78+ # assert
79+ torch .testing .assert_allclose (actual , expected )
7880
81+ # arrange
7982 # make sure buffer donation works
8083 functional_forward = interop .jax_jit (
8184 functools .partial (m_jitted .functional_call , 'forward' ),
@@ -84,8 +87,10 @@ def forward(self, x):
8487 }
8588 )
8689
87- torch .testing .assert_allclose (
88- functional_forward (m_jitted .params , m_jitted .buffers , x ) , expected )
90+ # act
91+ actual = functional_forward (m_jitted .params , m_jitted .buffers , x )
92+ # assert
93+ torch .testing .assert_allclose (actual , expected )
8994
9095
9196
0 commit comments