1+ import functools
12import torch
23import unittest
34from torchax import interop
5+ import torchax
46
57class M1 (torch .nn .Module ):
68
@@ -23,8 +25,12 @@ def __init__(self):
2325 self .m1 = M1 ()
2426
2527
28+
2629class InteropTest (unittest .TestCase ):
2730
  def setUp(self):
    # Runs before every test method. Activates torchax's global mode —
    # presumably this makes torch ops/tensors route through the jax backend
    # (e.g. enabling the 'jax' device used below); confirm against torchax docs.
    torchax.enable_globally()
2834
2935 def test_mod_attr (self ):
3036 m = M ()
@@ -41,6 +47,49 @@ def test_mod_attr(self):
4147 self .assertEqual (m .a .weight .item (), 0 )
4248 self .assertEqual (m .m1 .x .item (), 0 )
4349
50+ def test_module_with_shared_weights (self ):
51+
52+ class M2 (torch .nn .Module ):
53+
54+ def __init__ (self ):
55+ super ().__init__ ()
56+ self .a = torch .nn .Linear (10 , 10 )
57+ self .b = self .a
58+
59+ def forward (self , x ):
60+ return self .a (self .b (x ))
61+
62+ m = M2 ().to ('jax' )
63+
64+ m_jitted = interop .JittableModule (m , dedup_parameters = True )
65+
66+
67+ # a's weights and bias and b's weights and bias
68+ self .assertEqual (len (m .state_dict ()), 4 )
69+
70+ # b's weights and bias are deduped
71+ self .assertEqual (len (m_jitted .params ), 2 )
72+
73+ x = torch .randn (10 , 10 ).to ('jax' )
74+
75+ expected = m (x )
76+
77+ torch .testing .assert_allclose (m_jitted (x ), expected )
78+
79+ # make sure buffer donation works
80+ functional_forward = interop .jax_jit (
81+ functools .partial (m_jitted .functional_call , 'forward' ),
82+ kwargs_for_jax_jit = {
83+ 'donate_argnums' : (0 , )
84+ }
85+ )
86+
87+ torch .testing .assert_allclose (
88+ functional_forward (m_jitted .params , m_jitted .buffers , x ) , expected )
89+
90+
91+
92+
4493
4594
4695if __name__ == '__main__' :
0 commit comments