Commit af14415

Siyuan Liu (lsy323) authored and committed
Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path
1 parent 495f844 commit af14415

8 files changed: 85 additions, 92 deletions

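This commit moves the unbounded-dynamism switch from build time to run time: the -DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM compiler define that setup.py used to inject is dropped, and the C++ sources instead read the EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM environment variable once, at static-initialization time, through runtime::sys_util::GetEnvBool. Below is a minimal standalone sketch of that pattern, assuming only the C standard library; the helper GetEnvBoolSketch and its truthiness rules are illustrative stand-ins, not the tree's actual implementation.

#include <cstdlib>
#include <cstring>

// Illustrative stand-in for runtime::sys_util::GetEnvBool: an unset
// variable yields the default; "0" and "false" read as false.
static bool GetEnvBoolSketch(const char* name, bool defval) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return defval;
  return std::strcmp(raw, "0") != 0 && std::strcmp(raw, "false") != 0;
}

// Evaluated once when the library loads, like the flags added below.
static const bool experimental_unbounded_dynamism =
    GetEnvBoolSketch("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

Because the flag is now an ordinary runtime value, both code paths compile in every build; the diffs below turn each #if/#ifndef region into an if statement and add XLA_CHECK guards to helpers that are only meaningful with the flag on.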

setup.py

Lines changed: 0 additions & 5 deletions
@@ -212,13 +212,8 @@ def run(self):
 extra_compile_args = []
 cxx_abi = os.getenv(
     'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
-experimental_dynamism = os.getenv(
-    'EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None)
 if cxx_abi is not None:
   extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
-if experimental_dynamism is not None:
-  extra_compile_args.append(
-      f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}')
 
 
 class BazelExtension(Extension):

torch_xla/csrc/elementwise.cpp

Lines changed: 12 additions & 8 deletions
@@ -5,6 +5,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/random.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/xla_lower_util.h"
@@ -14,6 +15,9 @@
 namespace torch_xla {
 namespace {
 
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
 xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
                    const at::Scalar& max_val) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -66,16 +70,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,
 
 xla::XlaOp BuildRelu(xla::XlaOp input) {
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
-#ifndef EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-  return xla::Max(input, XlaHelpers::ScalarValue<float>(
-                             0, input_shape.element_type(), input.builder()));
-#else
   xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
       0, input_shape.element_type(), input.builder());
-  auto promoted = XlaHelpers::Promote(input, scalar);
-
-  return xla::Max(promoted.first, promoted.second);
-#endif
+  if (experimental_unbounded_dynamism) {
+    // xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
+    // TODO(lsy323): Remove this branch once the support is added in XLA.
+    auto promoted = XlaHelpers::Promote(input, scalar);
+    return xla::Max(promoted.first, promoted.second);
+  } else {
+    return xla::Max(input, scalar);
+  }
 }
 
 xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
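On the experimental path, BuildRelu now promotes the input and the zero scalar to a common shape before calling xla::Max, since xla::Max does not yet broadcast implicitly across unbounded dimensions. Here is a toy sketch of that promote-then-apply pattern over plain vectors; Promote, MaxElementwise, and Relu are hypothetical stand-ins for XlaHelpers::Promote and xla::Max.

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Broadcast a scalar to the operand's size so a strictly shape-matched
// elementwise max can be applied, mimicking Promote + xla::Max.
std::pair<std::vector<float>, std::vector<float>> Promote(
    const std::vector<float>& a, float scalar) {
  return {a, std::vector<float>(a.size(), scalar)};
}

std::vector<float> MaxElementwise(const std::vector<float>& a,
                                  const std::vector<float>& b) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) out[i] = std::max(a[i], b[i]);
  return out;
}

std::vector<float> Relu(const std::vector<float>& x) {
  auto promoted = Promote(x, 0.0f);
  return MaxElementwise(promoted.first, promoted.second);
}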

torch_xla/csrc/helpers.cpp

Lines changed: 21 additions & 40 deletions
@@ -21,9 +21,11 @@
 namespace torch_xla {
 namespace {
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
+// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
-#endif
 
 xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
                                  xla::XlaOp result) {
@@ -67,9 +69,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
   std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
   for (size_t i = 0; i < dimensions.size(); ++i) {
     bcast_sizes.at(dimensions[i]) = sizes[i];
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-    XLA_CHECK(sizes[i] != kUnboundedSize);
-#endif
+    if (experimental_unbounded_dynamism) {
+      XLA_CHECK(sizes[i] != kUnboundedSize);
+    }
   }
   return xla::BroadcastInDim(input, bcast_sizes,
                              GetAllDimensions(bcast_sizes.size()));
@@ -329,9 +331,9 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
                          : xla::Reshape(input, shape.dimensions());
 }
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-
 bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const absl::Span<const int64_t> dims = shape.dimensions();
   return std::any_of(dims.begin(), dims.end(),
                      [](int64_t size) { return size == kUnboundedSize; });
@@ -340,6 +342,8 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
 xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
     xla::XlaOp input, xla::XlaOp aux_input,
     absl::Span<const int64_t> output_sizes) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
   XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
       << "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
@@ -381,13 +385,17 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
 xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
     xla::XlaOp input, xla::XlaOp aux_input,
     absl::Span<const int64_t> aux_input_dimensions) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
   const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
   bool all_static = true;
   std::vector<int64_t> output_dimensions;
   std::vector<bool> output_dynamic;
   for (auto dim : aux_input_dimensions) {
-    if (aux_input_shape.dimensions(dim) == kUnboundedSize) all_static = false;
+    if (aux_input_shape.dimensions(dim) == kUnboundedSize) {
+      all_static = false;
+    }
     output_dimensions.push_back(aux_input_shape.dimensions(dim));
     output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim));
   }
@@ -432,13 +440,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
                                        output_dynamic));
 }
 
-void XlaHelpers::PrintXlaOp(xla::XlaOp op, const std::string& msg) {
-  std::cout << "Handle: " << msg << ": " << op << "\n";
-  const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op);
-  std::cout << xla::ShapeUtil::HumanString(shape);
-}
-#endif
-
 bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
                                       const xla::Shape& shape2) {
   return shape1.is_static() && shape2.is_static() &&
@@ -602,11 +603,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
         runtime::util::ToVector<int64_t>(shape1.dimensions()),
         runtime::util::ToVector<int64_t>(shape2.dimensions())));
   }
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-  XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
-            !XlaHelpers::IsUnboundedDynamic(shape2))
-      << "Unreachable for unbounded dynamic code\n";
-#endif
+  if (experimental_unbounded_dynamism) {
+    XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
+              !XlaHelpers::IsUnboundedDynamic(shape2))
+        << "Unreachable for unbounded dynamic code\n";
+  }
   return GetPromotedDynamicShape(shape1, shape2);
 }
 
@@ -700,7 +701,6 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecond(xla::XlaOp op1,
   return PromoteShapes(vops.first, vops.second);
 }
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
 xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
     xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op,
     const xla::Shape& shape) {
@@ -721,7 +721,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
   std::vector<xla::XlaOp> reshaped_ops;
 
   if (size_delta > 0) {
-    std::cout << "\t size_delta > 0\n";
     std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
                                          shape_dims.begin() + size_delta);
     for (int i = 0; i < size_delta; i++) {
@@ -730,20 +729,15 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
             XlaHelpers::ScalarValue<int32_t>(broadcast_sizes[i], op.builder()));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: " << broadcast_sizes[i] << "\n";
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << i << "\n";
       }
     }
   }
 
   if (size_delta == 0) {
-    std::cout << "\t size_delta == 0\n";
     int sz = op_shape_dims.size() - aux_shape_dims.size();
     std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
                                          shape_dims.begin() + sz);
@@ -753,14 +747,10 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
             XlaHelpers::ScalarValue<int32_t>(broadcast_sizes[i], op.builder()));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: " << broadcast_sizes[i] << "\n";
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(op, i));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << i << "\n";
       }
     }
   }
@@ -789,31 +779,23 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
       get_dim_ops.push_back(ScalarValue<int32_t>(shape_dim, op.builder()));
 
       auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-      std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                << " for size: " << shape_dim << "\n";
 
     } else if (op_shape_dim == 1 || aux_op_shape_dim == 1) {
       if (op_shape_dim == 1) {
         get_dim_ops.push_back(
             xla::GetDimensionSize(aux_op, aux_op_shape_index));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << aux_op_shape_index << "\n";
 
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));
 
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << op_shape_index << "\n";
       }
     } else {
       get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));
 
       auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-      std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                << " for size: ? of index: " << op_shape_index << "\n";
     }
   }
 
@@ -829,7 +811,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
 
   return new_op;
 }
-#endif
 
 xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op,
                                          const xla::Shape& op_shape,
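Since IsUnboundedDynamic, DynamicUnboundedReshape, and DynamicUnboundedBroadcast are now compiled unconditionally, each begins with a guard that fails fast when the flag is off, replacing the old behavior of the symbols simply not existing. A reduced sketch of that guard shape follows; the names are hypothetical and assert stands in for XLA_CHECK.

#include <cassert>
#include <cstdint>
#include <limits>

// Hypothetical flag and sentinel mirroring the ones in helpers.cpp; in the
// real code the flag comes from runtime::sys_util::GetEnvBool.
static const bool experimental_unbounded_dynamism = false;
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

bool IsUnboundedDim(int64_t size) {
  // Fail fast if the experimental helper is reached with the flag off.
  assert(experimental_unbounded_dynamism &&
         "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.");
  return size == kUnboundedSize;
}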

torch_xla/csrc/helpers.h

Lines changed: 0 additions & 4 deletions
@@ -158,7 +158,6 @@ class XlaHelpers {
   static xla::XlaOp DynamicReshape(xla::XlaOp input,
                                    absl::Span<const int64_t> output_sizes);
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
   static bool IsUnboundedDynamic(const xla::Shape& shape);
 
   static xla::XlaOp DynamicUnboundedReshape(
@@ -169,9 +168,6 @@ class XlaHelpers {
       xla::XlaOp input, xla::XlaOp aux_input,
       absl::Span<const int64_t> output_sizes);
 
-  static void PrintXlaOp(xla::XlaOp op, const std::string& msg);
-#endif
-
   static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);
 
   static bool SameStaticDimensions(const xla::Shape& shape1,

torch_xla/csrc/ir.h

Lines changed: 6 additions & 5 deletions
@@ -9,9 +9,9 @@
 #include <functional>
 #include <iostream>
 #include <memory>
-#include <set>
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -138,13 +138,14 @@ class XlaNode : public torch::lazy::Node {
 
   std::string ToString() const override;
 
-  void MarkDynamicDimension(uint32_t dim) {
-    dynamic_dims_.push_back(dim);
+  void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); }
+
+  const std::unordered_set<uint32_t>& dynamic_dims() const {
+    return dynamic_dims_;
   }
-  const std::vector<uint32_t>& dynamic_dims() const { return dynamic_dims_; }
 
  protected:
-  std::vector<uint32_t> dynamic_dims_;
+  std::unordered_set<uint32_t> dynamic_dims_;
 
  private:
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
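Switching dynamic_dims_ from std::vector<uint32_t> to std::unordered_set<uint32_t> makes MarkDynamicDimension idempotent: marking the same dimension twice leaves one entry rather than a duplicate for LowerNode to re-apply. A toy illustration, where the NodeSketch type is hypothetical:

#include <cassert>
#include <cstdint>
#include <unordered_set>

// With a set, repeated marks of the same dimension collapse to one entry;
// the old std::vector would have accumulated duplicates.
struct NodeSketch {
  std::unordered_set<uint32_t> dynamic_dims_;
  void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); }
};

int main() {
  NodeSketch node;
  node.MarkDynamicDimension(0);
  node.MarkDynamicDimension(0);  // no-op on the second call
  assert(node.dynamic_dims_.size() == 1);
  return 0;
}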

torch_xla/csrc/lowering_context.cpp

Lines changed: 22 additions & 15 deletions
@@ -93,18 +93,19 @@ LoweringContext::LoweringContext(
   }
 }
 
-// import xla::Shape.h to inlcude the following defintion.
+// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
+
 xla::XlaOp LoweringContext::GetParameter(
     const std::shared_ptr<torch::lazy::BackendData>& data,
-    const std::vector<uint32_t>& dynamic_dims) {
+    const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
   torch::lazy::BackendData::Handle handle = data->GetHandle();
   auto it = parameters_map_.find(handle);
   if (it == parameters_map_.end()) {
     xla::Shape shape =
         std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
             ->shape();
-    for (const int dim : dynamic_dims) {
+    for (const int dim : unbounded_dynamic_dims) {
       shape.set_dynamic_dimension(dim, true);
       shape.set_dimensions(dim, kUnboundedSize);
     }
@@ -113,6 +114,10 @@ xla::XlaOp LoweringContext::GetParameter(
     it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
              .first;
     parameters_.push_back(data);
+  } else {
+    XLA_CHECK(unbounded_dynamic_dims.empty())
+        << "The unbounded dynamic dims can only be set when Parameter is "
+           "created.";
   }
   parameter_sequence_.push_back(it->second.index);
   return it->second.param;
@@ -177,19 +182,21 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
 
     const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
     result_ops = casted->Lower(this);
-    xla::internal::XlaBuilderFriend builder_friend;
-    auto* inst = builder_friend.GetInstruction(result_ops[0]);
-    auto* mutable_dynamic =
-        inst->mutable_shape()->mutable_is_dynamic_dimension();
-    if (mutable_dynamic->empty()) {
-      for (int i = 0; i < inst->dimensions_size(); i++) {
-        mutable_dynamic->Add(false);
+    if (!casted->dynamic_dims().empty()) {
+      xla::internal::XlaBuilderFriend builder_friend;
+      auto* inst = builder_friend.GetInstruction(result_ops[0]);
+      auto* mutable_dynamic =
+          inst->mutable_shape()->mutable_is_dynamic_dimension();
+      if (mutable_dynamic->empty()) {
+        for (int i = 0; i < inst->dimensions_size(); i++) {
+          mutable_dynamic->Add(false);
+        }
+      }
+      auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
+      for (const auto dim : casted->dynamic_dims()) {
+        mutable_dynamic->Set(dim, true);
+        mutable_dims->Set(dim, kUnboundedSize);
       }
-    }
-    auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
-    for (const auto dim : casted->dynamic_dims()) {
-      mutable_dynamic->Set(dim, true);
-      mutable_dims->Set(dim, kUnboundedSize);
     }
   } catch (const std::exception& ex) {
     ReportBuilderError(node, ex.what());
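The new else branch in GetParameter enforces that unbounded dynamic dims may only accompany the call that first creates a parameter; on a cache hit the caller must pass none, since the cached shape is not re-annotated. A hypothetical reduction of that invariant, with ParameterCacheSketch as an illustrative type and assert in place of XLA_CHECK:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <unordered_map>
#include <unordered_set>

struct ParameterCacheSketch {
  std::unordered_map<int64_t, size_t> index_by_handle_;

  size_t Get(int64_t handle,
             const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
    auto it = index_by_handle_.find(handle);
    if (it == index_by_handle_.end()) {
      // First sighting: the dims would be applied to the new parameter here.
      it = index_by_handle_.emplace(handle, index_by_handle_.size()).first;
    } else {
      // Cache hit: re-annotation is not allowed.
      assert(unbounded_dynamic_dims.empty() &&
             "unbounded dynamic dims can only be set at creation");
    }
    return it->second;
  }
};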

torch_xla/csrc/lowering_context.h

Lines changed: 3 additions & 2 deletions
@@ -36,8 +36,9 @@ class LoweringContext : public torch::lazy::LoweringContext {
   // If a parameter associated with data has already been declared, it will be
   // returned. Otherwise a new one will be created, associated with the tensor
   // held in data.
-  xla::XlaOp GetParameter(const std::shared_ptr<torch::lazy::BackendData>& data,
-                          const std::vector<uint32_t>& dynamic_dims = {});
+  xla::XlaOp GetParameter(
+      const std::shared_ptr<torch::lazy::BackendData>& data,
+      const std::unordered_set<uint32_t>& dynamic_dims = {});
 
   // Retrieves the vector holding all the tensors associated with the parameter
   // instructions which have been created.
