Commit bd95eb1
Define PJRT plugin interface in C++ (#6360)
1 parent 32d24ad commit bd95eb1
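
This change replaces the loose (library_path, create_options, init_coordinator) registration arguments with an abstract runtime::PjRtPlugin interface, exposed to Python through a pybind11 trampoline so that DevicePlugin implementations satisfy it directly.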

5 files changed: +74 −50 lines changed

torch_xla/__init__.py

Lines changed: 2 additions & 3 deletions
```diff
@@ -3,6 +3,8 @@
 import re
 import tempfile
 
+import torch
+import _XLAC
 from ._internal import tpu
 
 logging.basicConfig()
@@ -135,12 +137,9 @@ def _setup_tpu_vm_library_path() -> bool:
 logger.setLevel(logging.INFO)
 
 import atexit
-import torch
 from ._patched_functions import _apply_patches
 from .version import __version__
 
-import _XLAC
-
 _found_libtpu = _setup_tpu_vm_library_path()
 
 # Setup Neuron library for AWS EC2 inf/trn instances.
```

torch_xla/csrc/init_python_bindings.cpp

Lines changed: 35 additions & 8 deletions
```diff
@@ -28,6 +28,7 @@
 #include "pybind11/numpy.h"
 #include "pybind11/pybind11.h"
 #include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
 #include "pybind11/stl_bind.h"
 #include "torch_xla/csrc/XLANativeFunctions.h"
 #include "torch_xla/csrc/aten_autograd_ops.h"
@@ -78,6 +79,28 @@ struct NoGilSection {
   PyThreadState* state = nullptr;
 };
 
+class PyPjRtPlugin : public runtime::PjRtPlugin {
+ public:
+  using runtime::PjRtPlugin::PjRtPlugin;
+
+  std::string library_path() const override {
+    PYBIND11_OVERRIDE_PURE(std::string, runtime::PjRtPlugin, library_path, );
+  }
+
+  // Templates with commas confuse pybind's macros, so use an alias here
+  // See https://github.com/pybind/pybind11/issues/2185#issuecomment-634005168
+  using PjRtCreateOptions = std::unordered_map<std::string, xla::PjRtValueType>;
+  const PjRtCreateOptions client_create_options() const override {
+    PYBIND11_OVERRIDE_PURE(PjRtCreateOptions, runtime::PjRtPlugin,
+                           client_create_options, );
+  }
+
+  bool requires_xla_coordinator() const override {
+    PYBIND11_OVERRIDE_PURE(bool, runtime::PjRtPlugin,
+                           requires_xla_coordinator, );
+  }
+};
+
 c10::optional<torch::lazy::BackendDevice> GetOptionalDevice(
     const std::string& device_str) {
   if (device_str.empty()) {
@@ -2319,14 +2342,18 @@ void InitXlaModuleBindings(py::module m) {
     return retlist;
   });
   // -------------Dynamo Integration API End-------------------------
-  m.def("_register_pjrt_plugin",
-        [](std::string name, std::string library_path,
-           std::unordered_map<std::string, xla::PjRtValueType> create_options,
-           bool init_coordinator) {
-          runtime::RegisterPjRtPlugin(
-              name, library_path,
-              {create_options.begin(), create_options.end()}, init_coordinator);
-        });
+  m.def(
+      "_register_pjrt_plugin",
+      [](std::string name, std::shared_ptr<const runtime::PjRtPlugin> plugin) {
+        runtime::RegisterPjRtPlugin(name, plugin);
+      });
+  py::class_<runtime::PjRtPlugin, PyPjRtPlugin,
+             std::shared_ptr<runtime::PjRtPlugin>>(m, "PjRtPlugin")
+      .def(py::init<>())
+      .def("library_path", &runtime::PjRtPlugin::library_path)
+      .def("client_create_options", &runtime::PjRtPlugin::client_create_options)
+      .def("requires_xla_coordinator",
+           &runtime::PjRtPlugin::requires_xla_coordinator);
 }
 } // namespace
```
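
The PyPjRtPlugin trampoline forwards each pure virtual call back into Python, so a Python class can subclass the bound PjRtPlugin and implement the three methods directly. A minimal sketch of what the binding enables (the class name, device type, and library path below are hypothetical):

```python
import torch_xla


class MyPjRtPlugin(torch_xla._XLAC.PjRtPlugin):
  """Hypothetical plugin; each method overrides a C++ pure virtual."""

  def library_path(self) -> str:
    # Hypothetical path to a PJRT C API shared library.
    return "/opt/my_device/libmy_pjrt_plugin.so"

  def client_create_options(self) -> dict:
    # pybind11/stl.h converts this dict into the C++ option map.
    return {}

  def requires_xla_coordinator(self) -> bool:
    return False


# Instances can be handed straight to the new binding:
torch_xla._XLAC._register_pjrt_plugin("MY_DEVICE", MyPjRtPlugin())
```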

torch_xla/csrc/runtime/pjrt_registry.cc

Lines changed: 18 additions & 22 deletions
```diff
@@ -1,3 +1,5 @@
+#include "torch_xla/csrc/runtime/pjrt_registry.h"
+
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/profiler.h"
@@ -18,13 +20,8 @@ namespace runtime {
 
 namespace {
 
-struct PluginEntry {
-  std::string library_path;
-  absl::flat_hash_map<std::string, xla::PjRtValueType> create_options;
-  bool init_coordinator;
-};
-
-std::unordered_map<std::string, PluginEntry> pjrt_plugins_;
+std::unordered_map<std::string, std::shared_ptr<const PjRtPlugin>>
+    pjrt_plugins_;
 
 xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   auto allocator_config = xla::GpuAllocatorConfig{};
@@ -43,21 +40,18 @@ xla::GpuAllocatorConfig GetGpuAllocatorConfig() {
   return allocator_config;
 }
 
-std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
+std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
+    const std::string& device_type) {
   auto plugin_path = pjrt_plugins_.find(device_type);
-  return plugin_path != pjrt_plugins_.end() ? std::optional(plugin_path->second)
-                                            : std::nullopt;
+  return plugin_path != pjrt_plugins_.end() ? plugin_path->second : nullptr;
 }
 
 } // namespace
 
-void RegisterPjRtPlugin(
-    std::string name, std::string library_path,
-    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options,
-    bool init_coordinator) {
-  TF_VLOG(3) << "Registering PjRt plugin " << name << " at " << library_path;
-  pjrt_plugins_[name] = {std::move(library_path), std::move(create_options),
-                         init_coordinator};
+void RegisterPjRtPlugin(std::string name,
+                        std::shared_ptr<const PjRtPlugin> plugin) {
+  TF_VLOG(3) << "Registering PjRt plugin " << name;
+  pjrt_plugins_[name] = plugin;
 }
 
 std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
@@ -66,12 +60,12 @@ InitializePjRt(const std::string& device_type) {
   std::unique_ptr<XlaCoordinator> coordinator;
 
   if (sys_util::GetEnvBool(env::kEnvPjrtDynamicPlugins, false)) {
-    std::optional<PluginEntry> plugin = GetPjRtPlugin(device_type);
+    std::shared_ptr<const PjRtPlugin> plugin = GetPjRtPlugin(device_type);
     if (plugin) {
       TF_VLOG(1) << "Initializing client for PjRt plugin " << device_type;
 
       std::shared_ptr<xla::KeyValueStoreInterface> kv_store = nullptr;
-      if (plugin->init_coordinator) {
+      if (plugin->requires_xla_coordinator()) {
        int local_process_rank = sys_util::GetEnvInt(
            env::kEnvPjRtLocalRank, sys_util::GetEnvInt("LOCAL_RANK", 0));
        int global_process_rank =
@@ -100,10 +94,12 @@ InitializePjRt(const std::string& device_type) {
            /*key_prefix=*/"pjrt:");
       }
       const PJRT_Api* c_api = *pjrt::LoadPjrtPlugin(
-          absl::AsciiStrToLower(device_type), plugin->library_path);
+          absl::AsciiStrToLower(device_type), plugin->library_path());
       XLA_CHECK_OK(pjrt::InitializePjrtPlugin(device_type));
-      client = xla::GetCApiClient(absl::AsciiStrToUpper(device_type),
-                                  plugin->create_options, kv_store)
+      auto create_options = plugin->client_create_options();
+      client = xla::GetCApiClient(
+                   absl::AsciiStrToUpper(device_type),
+                   {create_options.begin(), create_options.end()}, kv_store)
                    .value();
       profiler::RegisterProfilerForPlugin(c_api);
     }
```
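
InitializePjRt only consults this registry when the flag read through kEnvPjrtDynamicPlugins is set; it then loads the shared library named by library_path(), starts an XlaCoordinator only if requires_xla_coordinator() returns true, and forwards client_create_options() to xla::GetCApiClient. A sketch of that flow from the Python side, reusing the hypothetical MyPjRtPlugin above (the environment variable name is an assumption; check env_vars.h):

```python
import os

import torch_xla
import torch_xla.core.xla_model as xm

# Assumption: this is the variable behind kEnvPjrtDynamicPlugins; verify
# the exact name against env_vars.h.
os.environ["XLA_DYNAMIC_PLUGINS"] = "1"
os.environ["PJRT_DEVICE"] = "MY_DEVICE"  # hypothetical device type

# MyPjRtPlugin is the hypothetical subclass sketched earlier.
torch_xla._XLAC._register_pjrt_plugin("MY_DEVICE", MyPjRtPlugin())

# Device creation now routes through the registered plugin:
# LoadPjrtPlugin(library_path()), then GetCApiClient(client_create_options()).
device = xm.xla_device()
```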

torch_xla/csrc/runtime/pjrt_registry.h

Lines changed: 14 additions & 4 deletions
```diff
@@ -1,15 +1,25 @@
 #ifndef XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
 #define XLA_CLIENT_INITIALIZE_PJRT_CLIENT_H_
 
+#include "torch_xla/csrc/runtime/xla_coordinator.h"
 #include "xla/pjrt/pjrt_client.h"
+#include "xla/pjrt/pjrt_common.h"
 
 namespace torch_xla {
 namespace runtime {
 
-void RegisterPjRtPlugin(
-    std::string name, std::string library_path,
-    absl::flat_hash_map<std::string, xla::PjRtValueType> create_options = {},
-    bool init_coordinator = true);
+class PjRtPlugin {
+ public:
+  virtual std::string library_path() const = 0;
+
+  virtual const std::unordered_map<std::string, xla::PjRtValueType>
+  client_create_options() const = 0;
+
+  virtual bool requires_xla_coordinator() const = 0;
+};
+
+void RegisterPjRtPlugin(std::string name,
+                        std::shared_ptr<const PjRtPlugin> plugin);
 
 std::tuple<std::unique_ptr<xla::PjRtClient>, std::unique_ptr<XlaCoordinator>>
 InitializePjRt(const std::string& device_type);
```

torch_xla/experimental/plugins.py

Lines changed: 5 additions & 13 deletions
```diff
@@ -15,8 +15,8 @@
 import torch_xla.utils.utils as xu
 
 
-class DevicePlugin:
-  """Base class for device plugings.
+class DevicePlugin(torch_xla._XLAC.PjRtPlugin):
+  """Base class for device plugins.
 
   Default implementations assume a single device and local process.
   """
@@ -62,6 +62,7 @@ def requires_xla_coordinator(self) -> bool:
     return False
 
 
+# TODO(wcromar): figure out if we can share this map with the C++ code.
 _plugin_registry = {}
 
 
@@ -84,21 +85,12 @@ def default() -> DevicePlugin:
 
 def register_plugin(name: str, device_plugin: DevicePlugin):
   _plugin_registry[name.upper()] = device_plugin
-  torch_xla._XLAC._register_pjrt_plugin(
-      name, device_plugin.library_path(), device_plugin.client_create_options(),
-      device_plugin.requires_xla_coordinator())
+  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin)
 
 
 def register_installed_plugins():
   pjrt_entry_points = importlib_metadata.entry_points(group='torch_xla.plugins')
   for ep in pjrt_entry_points:
     device_plugin_class = ep.load()
 
-    # HACK: TpuPlugin raises EnvironmentError if libtpu is not installed.
-    # TODO(wcromar): Decide if catching `EnvironmentError` is a permanent
-    # behavior or temporary hack.
-    try:
-      register_plugin(ep.name.upper(), device_plugin_class())
-    except EnvironmentError as e:
-      logging.warning(
-          "Failed to register plugin {}".format(ep.name), exc_info=e)
+    register_plugin(ep.name.upper(), device_plugin_class())
```
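
Since DevicePlugin now inherits from the bound PjRtPlugin, a device package only needs to subclass it and expose an entry point in the torch_xla.plugins group for register_installed_plugins to discover it. A minimal sketch of such a package (the package name, class, and library path are hypothetical):

```python
# my_plugin/__init__.py (hypothetical package)
import os

from torch_xla.experimental import plugins


class MyDevicePlugin(plugins.DevicePlugin):
  """Hypothetical plugin; inherits DevicePlugin's single-device defaults."""

  def library_path(self) -> str:
    # Hypothetical: the PJRT shared library ships inside the package.
    return os.path.join(os.path.dirname(__file__), "libmy_pjrt_plugin.so")


# In the package metadata (pyproject.toml), an entry point in the
# 'torch_xla.plugins' group lets register_installed_plugins() find it:
#
#   [project.entry-points."torch_xla.plugins"]
#   my_device = "my_plugin:MyDevicePlugin"
```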
