
Commit 22988c6

Add documentation and python API for persistent cache (#6046)
1 parent 45742d2 commit 22988c6

6 files changed, +78 -18 lines

API_GUIDE.md

Lines changed: 24 additions & 0 deletions

````diff
@@ -314,6 +314,30 @@ tensors are always loaded back to the device they were saved from, and if
 that device is unavailable the load will fail. PyTorch/XLA, like all of PyTorch,
 is under active development and this behavior may change in the future.
 
+## Compilation Caching
+
+The XLA compiler converts the traced HLO into an executable which runs on
+the devices. Compilation can be time consuming, and in cases where the HLO
+doesn't change across executions, the compilation result can be persisted to
+disk for reuse, significantly reducing development iteration time.
+
+Note that if the HLO changes between executions, a recompilation will still
+occur.
+
+This is currently an experimental opt-in API, which must be activated before
+any computations are executed. Initialization is done through the
+`initialize_cache` API:
+
+```python
+import torch_xla.runtime as xr
+xr.initialize_cache('YOUR_CACHE_PATH', readonly=False)
+```
+
+This will initialize a persistent compilation cache at the specified path. The
+`readonly` parameter can be used to control whether the worker will be able to
+write to the cache, which can be useful when a shared cache mount is used for
+an SPMD workload.
+
 ## Further Reading
 
 Additional documentation is available at the
````

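For the shared-cache SPMD scenario described above, one way to use the `readonly` flag is to let a single worker populate the cache while the others only read from it. The sketch below is illustrative rather than part of this commit: the mount path is hypothetical, and it assumes `xr.process_index()` is available to identify the current worker.

```python
import torch_xla.runtime as xr

# Hypothetical cache location on a mount shared by all SPMD workers.
CACHE_PATH = '/mnt/xla_cache'

# Only the first process writes compilation results; everyone else reads.
# Must run before any computation is traced or executed.
xr.initialize_cache(CACHE_PATH, readonly=xr.process_index() != 0)
```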
test/test_persistent_cache.py

Lines changed: 22 additions & 16 deletions

```diff
@@ -43,32 +43,34 @@ def _assert_correctness_and_metrics(t, xt, metrics):
         f'Unexpected value for counter {counter}: expected {value}, got {actual}'
 
 
-def _mp_test(rank, metrics):
+def _mp_test(rank, tmpdir, metrics):
   # In MP, the cache dir must be different for each process to avoid a race
   # condition where one process loads the compilation result of another, which
   # would break the metrics assertion.
-  os.environ['XLA_PERSISTENT_CACHE_PATH'] = \
-      os.path.join(os.environ['XLA_PERSISTENT_CACHE_PATH'], str(rank))
+  xr.initialize_cache(os.path.join(tmpdir, str(rank)))
 
   t = torch.randn(16)
   xt = t.to(xm.xla_device())
   _assert_correctness_and_metrics(t, xt, metrics)
 
 
-def _single_device_test(metrics):
+def _single_device_test(tmpdir, metrics):
+  xr.initialize_cache(tmpdir)
   t = torch.randn(16)
   xt = t.to(xm.xla_device())
   _assert_correctness_and_metrics(t, xt, metrics)
 
 
-def _spmd_replicated_test(metrics):
+def _spmd_replicated_test(tmpdir, metrics):
+  xr.initialize_cache(tmpdir)
   xr.use_spmd()
   t = torch.randn(16)
   xt = t.to(xm.xla_device())
   _assert_correctness_and_metrics(t, xt, metrics)
 
 
-def _spmd_sharded_test(metrics):
+def _spmd_sharded_test(tmpdir, metrics):
+  xr.initialize_cache(tmpdir)
   xr.use_spmd()
   t = torch.randn(16)
 
@@ -90,19 +92,23 @@ class PersistentCacheTest(parameterized.TestCase):
 
   @run_with_tmpdir
   def _run_test(self, launch_method, test_fn, tmpdir):
-    os.environ['XLA_PERSISTENT_CACHE_PATH'] = tmpdir
-
     # Run once to warm the cache
-    launch_method(test_fn, ({
-        'PersistentCacheMiss': 1,
-        'PersistentCacheHit': None
-    },))
+    launch_method(test_fn, (
+        tmpdir,
+        {
+            'PersistentCacheMiss': 1,
+            'PersistentCacheHit': None
+        },
+    ))
 
     # The second run should hit the cache
-    launch_method(test_fn, ({
-        'PersistentCacheMiss': None,
-        'PersistentCacheHit': 1
-    },))
+    launch_method(test_fn, (
+        tmpdir,
+        {
+            'PersistentCacheMiss': None,
+            'PersistentCacheHit': 1
+        },
+    ))
 
   def test_persistent_cache_mp(self):
     self._run_test(xmp.spawn, _mp_test)
```

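The test distinguishes a cold compilation from a cache load through the `PersistentCacheMiss` and `PersistentCacheHit` counters. A minimal standalone check along the same lines, with an illustrative cache path, might look like the following; on the first run the miss counter should be set, and rerunning the same script should report a hit instead.

```python
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.runtime as xr

xr.initialize_cache('/tmp/xla_cache')  # illustrative path

t = torch.randn(16).to(xm.xla_device())
xm.mark_step()  # flush the pending graph so compilation actually happens

print('PersistentCacheMiss:', met.counter_value('PersistentCacheMiss'))
print('PersistentCacheHit:', met.counter_value('PersistentCacheHit'))
```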
torch_xla/csrc/init_python_bindings.cpp

Lines changed: 3 additions & 0 deletions

```diff
@@ -977,6 +977,9 @@ void InitXlaModuleBindings(py::module m) {
   m.def("_xla_runtime_is_initialized", []() {
     return runtime::GetComputationClientIfInitialized() != nullptr;
   });
+  m.def("_xla_computation_cache_is_initialized", []() {
+    return XLAGraphExecutor::Get()->IsComputationCacheInitialized();
+  });
   m.def("_get_git_revs", []() { return GetRevisions(); });
   m.def("_get_xla_tensor_dimension_size",
         [](const at::Tensor& tensor, int dim) {
```

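This binding is what lets the Python layer detect whether the cache has already been created. A quick illustrative check from Python:

```python
import torch_xla

# False before the first compilation/execution; initialize_cache() must be
# called while this is still False.
print(torch_xla._XLAC._xla_computation_cache_is_initialized())
```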
torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 8 additions & 2 deletions

```diff
@@ -506,9 +506,15 @@ void XLAGraphExecutor::MaybeDumpGraph(std::string name,
   }
 }
 
+bool XLAGraphExecutor::IsComputationCacheInitialized() {
+  return computation_cache_ != nullptr;
+}
+
 XLAGraphExecutor::ComputationCache* XLAGraphExecutor::GetComputationCache() {
-  static ComputationCache* cache = CreateComputationCache();
-  return cache;
+  if (computation_cache_ == nullptr) {
+    computation_cache_ = CreateComputationCache();
+  }
+  return computation_cache_;
 }
 
 void XLAGraphExecutor::ClearPendingIrs(
```

torch_xla/csrc/xla_graph_executor.h

Lines changed: 3 additions & 0 deletions

```diff
@@ -173,6 +173,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
                              torch::lazy::HashReducer>;
 
   ComputationCache* GetComputationCache();
+  bool IsComputationCacheInitialized();
 
   std::vector<torch::lazy::BackendDataPtr> ExecuteComputationWithBarrier(
       torch::lazy::hash_t hash, const std::vector<at::IValue>& graph_inputs,
@@ -344,6 +345,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
   std::shared_ptr<Async> SyncTensorsGraphInternal(
       std::vector<XLATensorPtr>* tensors, absl::Span<const std::string> devices,
       const SyncTensorsConfig& config, bool warm_up_cache_only = false);
+
+  ComputationCache* computation_cache_;
 };
 
 }  // namespace torch_xla
```

torch_xla/runtime.py

Lines changed: 18 additions & 0 deletions

```diff
@@ -265,3 +265,21 @@ def get_master_ip() -> str:
   if device_type() == 'TPU':
     return tpu.discover_master_worker_ip()
   raise RuntimeError(f'IP discovery not supported for device: {device_type()}')
+
+
+@requires_pjrt
+def initialize_cache(path: str, readonly: bool = False):
+  """Initializes the persistent compilation cache. This API must be called
+  before any computations have been performed.
+
+  Args:
+    path: The path at which to store the persistent cache.
+    readonly: Whether or not this worker should have write access to the cache.
+  """
+  assert not torch_xla._XLAC._xla_computation_cache_is_initialized(
+  ), "Computation cache has already been initialized"
+
+  # TODO(jonbolin): Consider moving away from environment variables to control
+  # the cache.
+  os.environ['XLA_PERSISTENT_CACHE_PATH'] = path
+  os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1' if readonly else '0'
```

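Per the TODO above, the implementation currently just forwards to environment variables, so presumably the same effect can be achieved by setting them directly before any XLA work starts (bypassing the initialization guard that `initialize_cache` adds):

```python
import os

# Assumed equivalent to xr.initialize_cache('/tmp/xla_cache', readonly=True),
# provided it runs before any computation is traced or executed.
os.environ['XLA_PERSISTENT_CACHE_PATH'] = '/tmp/xla_cache'
os.environ['XLA_PERSISTENT_CACHE_READ_ONLY'] = '1'
```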