1 change: 1 addition & 0 deletions third_party/xla_client/env_vars.cc
@@ -23,6 +23,7 @@ const char* const kEnvPjRtTpuMaxInflightComputations =
 const char* const kEnvPjrtAsyncCpuClient = "PJRT_CPU_ASYNC_CLIENT";
 const char* const kEnvPjrtAsyncGpuClient = "PJRT_GPU_ASYNC_CLIENT";
 const char* const kEnvTpuLibraryPath = "TPU_LIBRARY_PATH";
+const char* const kEnvXpuLibraryPath = "XPU_LIBRARY_PATH";
 const char* const kEnvPjrtDistServiceAddr = "PJRT_DIST_SERVICE_ADDR";
 const char* const kEnvPjRtLocalRank = "PJRT_LOCAL_PROCESS_RANK";

1 change: 1 addition & 0 deletions third_party/xla_client/env_vars.h
@@ -23,6 +23,7 @@ extern const char* const kEnvPjRtTpuMaxInflightComputations;
 extern const char* const kEnvPjrtAsyncCpuClient;
 extern const char* const kEnvPjrtAsyncGpuClient;
 extern const char* const kEnvTpuLibraryPath;
+extern const char* const kEnvXpuLibraryPath;
 extern const char* const kEnvPjrtDistServiceAddr;
 extern const char* const kEnvPjRtLocalRank;

7 changes: 7 additions & 0 deletions third_party/xla_client/pjrt_computation_client.cc
@@ -104,6 +104,13 @@ PjRtComputationClient::PjRtComputationClient() {
         /*distributed_client=*/distributed_client,
         /*node_id=*/local_rank, allowed_devices = allowed_devices)
         .value());
+  } else if (device_type == "XPU") {
+    TF_VLOG(1) << "Initializing PjRt XPU client...";
+    XLA_CHECK_OK(pjrt::LoadPjrtPlugin(
+        "xpu", sys_util::GetEnvString(env::kEnvXpuLibraryPath, "libxpu.so")));
+    supports_logical_on_device_shape_ = false;
+    client_ = std::move(xla::GetCApiClient("XPU").value());
+
   } else {
     XLA_ERROR() << absl::StrFormat("Unknown %s '%s'", env::kEnvPjRtDevice,
                                    device_type);
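Taken together, the env_vars and pjrt_computation_client changes wire up a new PJRT backend that is selected entirely through environment variables: env::kEnvPjRtDevice chooses the "XPU" branch, and XPU_LIBRARY_PATH points pjrt::LoadPjrtPlugin at the plugin shared library (defaulting to "libxpu.so"). The Python sketch below shows how a user would exercise this path; it assumes an XPU PJRT plugin is actually installed, that env::kEnvPjRtDevice maps to the usual PJRT_DEVICE variable, and the library path shown is only a placeholder.

import os

# Hypothetical setup: select the new XPU branch in PjRtComputationClient and
# point the plugin loader at the shared library (placeholder path).
os.environ["PJRT_DEVICE"] = "XPU"
os.environ["XPU_LIBRARY_PATH"] = "/opt/intel/libxpu.so"

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()              # first XPU-backed XLA device
t = torch.ones(2, 2, device=device)   # tensor created through the XPU PJRT client
print(device, t.device)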
12 changes: 6 additions & 6 deletions torch_xla/core/xla_model.py
@@ -117,7 +117,7 @@ def is_xla_tensor(tensor):


 def parse_xla_device(device):
-  m = re.match(r'(CPU|TPU|GPU):(\d+)$', device)
+  m = re.match(r'(CPU|TPU|GPU|XPU):(\d+)$', device)
   if m:
     return (m.group(1), int(m.group(2)))

@@ -126,7 +126,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
   """Returns a list of supported devices of a given kind.

   Args:
-    devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`
+    devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU` or `CPU`
       (the 'GPU' XLA device is currently not implemented).
     max_devices (int, optional): The maximum number of devices to be returned of
       that kind.
@@ -135,7 +135,7 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
     The list of device strings.
   """
   xla_devices = _DEVICES.value
-  devkind = [devkind] if devkind else ['TPU', 'GPU', 'CPU']
+  devkind = [devkind] if devkind else ['TPU', 'GPU', 'XPU', 'CPU']
   for kind in devkind:
     kind_devices = []
     for i, device in enumerate(xla_devices):
@@ -231,7 +231,7 @@ def xla_device(n=None, devkind=None):
     n (int, optional): The specific instance (ordinal) to be returned. If
       specified, the specific XLA device instance will be returned. Otherwise
       the first device of `devkind` will be returned.
-    devkind (string..., optional): If specified, one of `TPU`, `GPU` or `CPU`.
+    devkind (string..., optional): If specified, one of `TPU`, `GPU`, `XPU` or `CPU`.

   Returns:
     A `torch.device` with the requested instance.
@@ -279,7 +279,7 @@ def xla_device_hw(device):
       real device.

   Returns:
-    A string representation of the hardware type (`CPU`, `TPU`, `GPU`) of the
+    A string representation of the hardware type (`CPU`, `TPU`, `XPU`, `GPU`) of the
     given device.
   """
   real_device = _xla_real_device(device)
@@ -677,7 +677,7 @@ def all_gather(value, dim=0, groups=None, output=None, pin_layout=True):
     participating replicas.
   """
   if pin_layout and xla_device_hw(
-      value.device) in ('TPU', 'GPU') and output == None:
+      value.device) in ('TPU', 'GPU', 'XPU') and output == None:
     # There is not an easy way to pin the all_gather layout on TPU and GPU, use
     # all_reduce based all_gather for this purpose.
     return _all_gather_using_all_reduce(
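The xla_model.py changes surface the new device kind through the public helpers touched above. A short sketch of what that looks like from user code, assuming the XPU client from the previous section is active (the exact device strings returned are illustrative):

import torch_xla.core.xla_model as xm

print(xm.parse_xla_device('XPU:0'))          # ('XPU', 0): the regex now accepts XPU
print(xm.get_xla_supported_devices('XPU'))   # e.g. ['xla:0'] when XPU devices are present
dev = xm.xla_device(devkind='XPU')           # first XPU device
print(xm.xla_device_hw(dev))                 # 'XPU'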
5 changes: 5 additions & 0 deletions torch_xla/csrc/device.cpp
@@ -19,6 +19,8 @@ std::string XlaDeviceTypeToString(XlaDeviceType hw_type) {
       return "GPU";
     case XlaDeviceType::TPU:
       return "TPU";
+    case XlaDeviceType::XPU:
+      return "XPU";
     case XlaDeviceType::SPMD:
       return "SPMD";
   }
@@ -64,6 +66,9 @@ torch::lazy::BackendDevice ParseDeviceString(const std::string& device_spec) {
   } else if (device_spec_parts[0] == "GPU") {
     device_type->type =
         static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::GPU);
+  } else if (device_spec_parts[0] == "XPU") {
+    device_type->type =
+        static_cast<std::underlying_type_t<XlaDeviceType>>(XlaDeviceType::XPU);
   } else {
     XLA_ERROR() << "Invalid device specification: " << device_spec;
   }
2 changes: 1 addition & 1 deletion torch_xla/csrc/device.h
@@ -13,7 +13,7 @@ namespace torch_xla {
 // TODO(yeounoh) `SPMD` is a virtual device that defers data `TransferToServer`
 // until after the paritioning pass. This avoids transfering the full input
 // tensor to the device.
-enum class XlaDeviceType { CPU, GPU, TPU, SPMD };
+enum class XlaDeviceType { CPU, GPU, TPU, XPU, SPMD };

 struct DeviceType : public torch::lazy::BackendDeviceType {
   DeviceType() { type = static_cast<int>(XlaDeviceType::CPU); }