@@ -43,32 +43,34 @@ def _assert_correctness_and_metrics(t, xt, metrics):
4343 f'Unexpected value for counter { counter } : expected { value } , got { actual } '
4444
4545
46- def _mp_test (rank , metrics ):
46+ def _mp_test (rank , tmpdir , metrics ):
4747 # In MP, the cache dir must be different for each process to avoid a race
4848 # condition where one process loads the compilation result of another, which
4949 # would break the metrics assertion.
50- os .environ ['XLA_PERSISTENT_CACHE_PATH' ] = \
51- os .path .join (os .environ ['XLA_PERSISTENT_CACHE_PATH' ], str (rank ))
50+ xr .initialize_cache (os .path .join (tmpdir , str (rank )))
5251
5352 t = torch .randn (16 )
5453 xt = t .to (xm .xla_device ())
5554 _assert_correctness_and_metrics (t , xt , metrics )
5655
5756
58- def _single_device_test (metrics ):
57+ def _single_device_test (tmpdir , metrics ):
58+ xr .initialize_cache (tmpdir )
5959 t = torch .randn (16 )
6060 xt = t .to (xm .xla_device ())
6161 _assert_correctness_and_metrics (t , xt , metrics )
6262
6363
64- def _spmd_replicated_test (metrics ):
64+ def _spmd_replicated_test (tmpdir , metrics ):
65+ xr .initialize_cache (tmpdir )
6566 xr .use_spmd ()
6667 t = torch .randn (16 )
6768 xt = t .to (xm .xla_device ())
6869 _assert_correctness_and_metrics (t , xt , metrics )
6970
7071
71- def _spmd_sharded_test (metrics ):
72+ def _spmd_sharded_test (tmpdir , metrics ):
73+ xr .initialize_cache (tmpdir )
7274 xr .use_spmd ()
7375 t = torch .randn (16 )
7476
@@ -90,19 +92,23 @@ class PersistentCacheTest(parameterized.TestCase):
9092
9193 @run_with_tmpdir
9294 def _run_test (self , launch_method , test_fn , tmpdir ):
93- os .environ ['XLA_PERSISTENT_CACHE_PATH' ] = tmpdir
94-
9595 # Run once to warm the cache
96- launch_method (test_fn , ({
97- 'PersistentCacheMiss' : 1 ,
98- 'PersistentCacheHit' : None
99- },))
96+ launch_method (test_fn , (
97+ tmpdir ,
98+ {
99+ 'PersistentCacheMiss' : 1 ,
100+ 'PersistentCacheHit' : None
101+ },
102+ ))
100103
101104 # The second run should hit the cache
102- launch_method (test_fn , ({
103- 'PersistentCacheMiss' : None ,
104- 'PersistentCacheHit' : 1
105- },))
105+ launch_method (test_fn , (
106+ tmpdir ,
107+ {
108+ 'PersistentCacheMiss' : None ,
109+ 'PersistentCacheHit' : 1
110+ },
111+ ))
106112
107113 def test_persistent_cache_mp (self ):
108114 self ._run_test (xmp .spawn , _mp_test )
0 commit comments