15 changes: 15 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -746,6 +746,21 @@ def test_sharded_tensor_aliasing(self):
xm.mark_step()
self.assertEqual(met.metric_data("InputOutputAliasCount")[0], 1)

def test_mark_sharding_ir_with_multiple_output(self):
partition_spec = (0,)
xt1 = torch.randn(8, 8).to(xm.xla_device())
# torch.max returns 2 tensors, `value` and `indices`. They are outputs
# of the same IR node `MaxInDim`.
(xt_val, xt_index) = torch.max(xt1, 1)
xst_val = xs.mark_sharding(xt_val, self._get_mesh((self.n_devices)),
partition_spec)
# `xt_val` should have a sharding spec now, but `xt_index` should not
self.assertNotEqual(torch_xla._XLAC._get_xla_sharding_spec(xt_val), '')
self.assertEqual(torch_xla._XLAC._get_xla_sharding_spec(xt_index), '')
# `xt_index`'s HLO should not have any sharding
self.assertNotIn('convert(s32[8]{0} %get-tuple-element.25), sharding',
torch_xla._XLAC._get_xla_tensors_hlo([xt_index]))


if __name__ == '__main__':
test = unittest.main()
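For readers outside the test harness, a minimal standalone sketch of the behavior the new test asserts. This is an illustration, not part of the PR: the import paths, the `Mesh` constructor, and the device-count helper below are assumptions about this torch_xla version's experimental SPMD API, and SPMD mode is assumed to be enabled (e.g. `XLA_USE_SPMD=1`) with a device count that divides 8.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs  # assumed module path

# Build a 1-D mesh over all local devices (assumed device-count helper).
num_devices = len(xm.get_xla_supported_devices())
mesh = xs.Mesh(np.array(range(num_devices)), (num_devices,))

t = torch.randn(8, 8).to(xm.xla_device())
# torch.max(t, 1) produces two outputs of a single `MaxInDim` IR node.
values, indices = torch.max(t, 1)

# Annotate only the `values` output; `indices` shares the node but stays unannotated.
xs.mark_sharding(values, mesh, (0,))

assert torch_xla._XLAC._get_xla_sharding_spec(values) != ''
assert torch_xla._XLAC._get_xla_sharding_spec(indices) == ''
```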
83 changes: 47 additions & 36 deletions torch_xla/csrc/ir.cpp
@@ -151,9 +151,14 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op,
return torch::lazy::HashCombine(h, hash_seed);
}

void XlaNode::SetSharding(const xla::OpSharding& sharding) {
output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
sharding_hash_ = CreateShardingHash(output_sharding_, node_hash_);
void XlaNode::SetSharding(const xla::OpSharding& sharding, size_t index) {
if (output_shardings_.size() == 0) {
output_shardings_ =
std::vector<std::shared_ptr<xla::OpSharding>>(num_outputs());
}
output_shardings_[index] = std::make_shared<xla::OpSharding>(sharding);
// TODO(JackCaoG): fix this hashing
UpdateShardingHash();
}

xla::Shape XlaNode::GetOpShape(
@@ -179,40 +184,46 @@ const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {

// The sharding hash is only based on relevant fields from the xla::OpSharding
// object. We skip the field that's irrelevant, which is the layout.
torch::lazy::hash_t XlaNode::CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed) {
torch::lazy::hash_t sharding_hash = hash_seed;
for (const auto& tile_assignment_dimension :
sharding->tile_assignment_dimensions()) {
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)tile_assignment_dimension);
}
for (const auto& tile_assignment_device :
sharding->tile_assignment_devices()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash,
(uint32_t)tile_assignment_device);
}
for (const auto& last_tile_dim : sharding->last_tile_dims()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)last_tile_dim);
}
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)sharding->type());
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)sharding->replicate_on_last_tile_dim());

xla::ShapeProto shape_proto = sharding->tile_shape();
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)shape_proto.element_type());
for (const auto& dim : shape_proto.dimensions()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash, (uint32_t)dim);
}
for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)is_dyn_dim);
void XlaNode::UpdateShardingHash() {
sharding_hash_ = node_hash_;
for (size_t i = 0; i < output_shardings_.size(); i++) {
// keep the index as part of the hash
sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)i);
std::shared_ptr<xla::OpSharding> sharding = output_shardings_[i];
// skip the hash compute for empty sharding
if (!sharding) {
continue;
}
for (const auto& tile_assignment_dimension :
sharding->tile_assignment_dimensions()) {
sharding_hash_ = torch::lazy::HashCombine(
sharding_hash_, (uint32_t)tile_assignment_dimension);
}
for (const auto& tile_assignment_device :
sharding->tile_assignment_devices()) {
sharding_hash_ = torch::lazy::HashCombine(
sharding_hash_, (uint32_t)tile_assignment_device);
}
for (const auto& last_tile_dim : sharding->last_tile_dims()) {
sharding_hash_ =
torch::lazy::HashCombine(sharding_hash_, (uint32_t)last_tile_dim);
}
sharding_hash_ =
torch::lazy::HashCombine(sharding_hash_, (uint32_t)sharding->type());
sharding_hash_ = torch::lazy::HashCombine(
sharding_hash_, (uint32_t)sharding->replicate_on_last_tile_dim());

xla::ShapeProto shape_proto = sharding->tile_shape();
sharding_hash_ = torch::lazy::HashCombine(
sharding_hash_, (uint32_t)shape_proto.element_type());
for (const auto& dim : shape_proto.dimensions()) {
sharding_hash_ = torch::lazy::HashCombine(sharding_hash_, (uint32_t)dim);
}
for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
sharding_hash_ =
torch::lazy::HashCombine(sharding_hash_, (uint32_t)is_dyn_dim);
}
}

return sharding_hash;
}

} // namespace torch_xla
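Why the output index is folded into the hash: a node with only output 0 sharded must hash differently from the same node with only output 1 sharded, even when the `OpSharding` protos are identical. Below is a rough Python analogue of `UpdateShardingHash` for intuition only; Python's built-in `hash` and a plain dict stand in for `torch::lazy` hashing and the `xla::OpSharding` proto.

```python
def update_sharding_hash(node_hash, output_shardings):
  """Rough analogue of XlaNode::UpdateShardingHash (illustration only)."""
  h = node_hash
  for i, sharding in enumerate(output_shardings):
    h = hash((h, i))  # the output index is always part of the hash
    if sharding is None:  # skip the hash compute for empty shardings
      continue
    h = hash((h,
              tuple(sharding['tile_assignment_dimensions']),
              tuple(sharding['tile_assignment_devices']),
              sharding['type'],
              sharding['replicate_on_last_tile_dim']))
  return h


sharding = {
    'tile_assignment_dimensions': (2,),
    'tile_assignment_devices': (0, 1),
    'type': 'OTHER',
    'replicate_on_last_tile_dim': False,
}
# Sharding only output 0 vs. only output 1 now yields different hashes.
assert (update_sharding_hash(0x1234, [sharding, None]) !=
        update_sharding_hash(0x1234, [None, sharding]))
```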
20 changes: 10 additions & 10 deletions torch_xla/csrc/ir.h
@@ -122,15 +122,17 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t shardingHash() const { return sharding_hash_; }

// The node's outputs get assigned the same HLO sharding
// TODO: test multi-output example.
const std::shared_ptr<xla::OpSharding> GetSharding() const {
return output_sharding_;
const std::shared_ptr<xla::OpSharding> GetSharding(size_t index) const {
if (output_shardings_.size() == 0) {
return nullptr;
}
return output_shardings_[index];
}

void SetSharding(const xla::OpSharding& sharding);
void SetSharding(const xla::OpSharding& sharding, size_t index);

void ClearSharding() {
output_sharding_ = nullptr;
output_shardings_.clear();
sharding_hash_ = 0;
}

@@ -145,17 +147,15 @@

static std::vector<torch::lazy::SourceLocation> GetFrameInfo();

static torch::lazy::hash_t CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed);
void UpdateShardingHash();

xla::Shape xla_shape_;
torch::lazy::hash_t node_hash_ = 0;
torch::lazy::hash_t dag_hash_;
torch::lazy::hash_t sharding_hash_ = 0;

// Experimental sharding annotation attached to the IR node.
// TODO(yeounoh): make sure that view update doesn't reset this.
std::shared_ptr<xla::OpSharding> output_sharding_ = nullptr;
// Experimental sharding annotations attached to the IR node.
std::vector<std::shared_ptr<xla::OpSharding>> output_shardings_;
};

inline std::ostream& operator<<(std::ostream& stream, const XlaNode& node) {
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/device_data.cpp
@@ -17,7 +17,8 @@ DeviceData::DeviceData(std::shared_ptr<torch::lazy::BackendData> data)
torch_xla::runtime::GetComputationClient()->GetDataSharding(
UnwrapXlaData(data_));
if (op_sharding.has_value()) {
SetSharding(op_sharding.value());
// DeviceData Node only has 1 output.
SetSharding(op_sharding.value(), 0);
}
}
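`DeviceData` is the single-output case, so its annotation always lands at output index 0. Continuing the sketch above (same assumptions about the experimental SPMD API), a freshly transferred tensor is backed by a `DeviceData` IR node:

```python
# `u` is backed by a DeviceData IR node, which has exactly one output,
# so its sharding annotation is stored at index 0.
u = torch.randn(8).to(xm.xla_device())
xs.mark_sharding(u, mesh, (0,))
print(torch_xla._XLAC._get_xla_sharding_spec(u))  # expected: a non-empty sharding spec
```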

13 changes: 7 additions & 6 deletions torch_xla/csrc/tensor.cpp
@@ -112,9 +112,10 @@ XLATensor::XLATensor(torch::lazy::Value ir_value,
// Preserve sharding if a new tensor is created from a sharded IR node.
if (CurrentIrValue()) {
auto* xla_node = dynamic_cast<XlaNode*>(CurrentIrValue().node.get());
if (xla_node->GetSharding()) {
if (xla_node->GetSharding(CurrentIrValue().index)) {
ShardingSpec sharding =
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()};
ShardingSpec{*xla_node->GetSharding(CurrentIrValue().index),
xla_node->xla_shape()};
SetShardingSpec(sharding);
}
}
@@ -239,7 +240,7 @@ void XLATensor::SetShardingSpec(const ShardingSpec& sharding) {
<< sharding.sharding.DebugString();
}
dynamic_cast<XlaNode*>(GetIrValue().node.get())
->SetSharding(sharding_spec()->sharding);
->SetSharding(sharding_spec()->sharding, GetIrValue().index);
}
void XLATensor::ClearShardingSpec() {
data()->sharding = nullptr;
@@ -256,10 +257,10 @@ XLATensor::ShardingSpecPtr XLATensor::sharding_spec() const {
if (sharding && ir_value) {
// The copy of sharding annotation on the IR node should be the same.
auto* xla_node = dynamic_cast<XlaNode*>(ir_value.node.get());
if (xla_node->GetSharding()) {
if (xla_node->GetSharding(ir_value.index)) {
XLA_CHECK(ShardingUtil::EqualShardingSpecs(
*sharding,
ShardingSpec{*xla_node->GetSharding(), xla_node->xla_shape()}));
*sharding, ShardingSpec{*xla_node->GetSharding(ir_value.index),
xla_node->xla_shape()}));
}
}
return sharding;
2 changes: 1 addition & 1 deletion torch_xla/csrc/xla_graph_executor.cpp
@@ -545,7 +545,7 @@ XLAGraphExecutor::SyncTensorCollection XLAGraphExecutor::CollectSyncTensors(
XLATensor::ShardingSpecPtr sharding = tensors[i]->sharding_spec();
if (sharding) {
dynamic_cast<XlaNode*>(ir_value.node.get())
->SetSharding(sharding->sharding);
->SetSharding(sharding->sharding, ir_value.index);
}
}
} else if (config.force_ltc_data) {
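The executor change re-attaches each tensor's sharding spec to its IR value at the correct output index while collecting tensors for a sync, so the annotation is expected to survive graph materialization. A quick hedged check, continuing the variables from the first sketch:

```python
xm.mark_step()  # materialize the pending graph
# The annotated output is expected to keep its sharding spec across the sync.
print(torch_xla._XLAC._get_xla_sharding_spec(values))  # expected: non-empty
```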
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_sharding_util.cpp
@@ -148,8 +148,9 @@ bool ShardingUtil::SetHloSharding(LoweringContext* lowering_ctx) {
const torch::lazy::Node* node = elem.first.node;
const XlaNode* xla_node = dynamic_cast<const XlaNode*>(node);
auto instruction = XlaBuilderFriend::GetInstruction(elem.second);
if (xla_node->GetSharding() != nullptr) {
*instruction->mutable_sharding() = *xla_node->GetSharding();
if (xla_node->GetSharding(elem.first.index) != nullptr) {
*instruction->mutable_sharding() =
*xla_node->GetSharding(elem.first.index);
is_sharded = true;
}
}