2121namespace torch_xla {
2222namespace {
2323
24- #if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
24+ static const bool experimental_unbounded_dynamism =
25+ runtime::sys_util::GetEnvBool (" EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM" , false );
26+
27+ // TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
2528static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t >::min();
26- #endif
2729
2830xla::XlaOp ConvertBinaryOpResult (xla::XlaOp op1, xla::XlaOp op2,
2931 xla::XlaOp result) {
@@ -67,9 +69,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
6769 std::vector<int64_t > bcast_sizes = SizesOfXlaOp (input);
6870 for (size_t i = 0 ; i < dimensions.size (); ++i) {
6971 bcast_sizes.at (dimensions[i]) = sizes[i];
70- # if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
71- XLA_CHECK (sizes[i] != kUnboundedSize );
72- # endif
72+ if (experimental_unbounded_dynamism) {
73+ XLA_CHECK (sizes[i] != kUnboundedSize );
74+ }
7375 }
7476 return xla::BroadcastInDim (input, bcast_sizes,
7577 GetAllDimensions (bcast_sizes.size ()));
@@ -329,9 +331,9 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
329331 : xla::Reshape (input, shape.dimensions ());
330332}
331333
332- #if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
333-
334334bool XlaHelpers::IsUnboundedDynamic (const xla::Shape& shape) {
335+ XLA_CHECK (experimental_unbounded_dynamism)
336+ << " EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on." ;
335337 const absl::Span<const int64_t > dims = shape.dimensions ();
336338 return std::any_of (dims.begin (), dims.end (),
337339 [](int64_t size) { return size == kUnboundedSize ; });
@@ -340,6 +342,8 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
340342xla::XlaOp XlaHelpers::DynamicUnboundedReshape (
341343 xla::XlaOp input, xla::XlaOp aux_input,
342344 absl::Span<const int64_t > output_sizes) {
345+ XLA_CHECK (experimental_unbounded_dynamism)
346+ << " EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on." ;
343347 const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp (aux_input);
344348 XLA_CHECK (output_sizes.size () == aux_input_shape.rank ())
345349 << " XlaHelpers::DynamicUnboundedReshape constrainled failed!" ;
@@ -381,13 +385,17 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
381385xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast (
382386 xla::XlaOp input, xla::XlaOp aux_input,
383387 absl::Span<const int64_t > aux_input_dimensions) {
388+ XLA_CHECK (experimental_unbounded_dynamism)
389+ << " EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on." ;
384390 const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp (input);
385391 const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp (aux_input);
386392 bool all_static = true ;
387393 std::vector<int64_t > output_dimensions;
388394 std::vector<bool > output_dynamic;
389395 for (auto dim : aux_input_dimensions) {
390- if (aux_input_shape.dimensions (dim) == kUnboundedSize ) all_static = false ;
396+ if (aux_input_shape.dimensions (dim) == kUnboundedSize ) {
397+ all_static = false ;
398+ }
391399 output_dimensions.push_back (aux_input_shape.dimensions (dim));
392400 output_dynamic.push_back (aux_input_shape.is_dynamic_dimension (dim));
393401 }
@@ -432,13 +440,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
432440 output_dynamic));
433441}
434442
435- void XlaHelpers::PrintXlaOp (xla::XlaOp op, const std::string& msg) {
436- std::cout << " Handle: " << msg << " : " << op << " \n " ;
437- const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp (op);
438- std::cout << xla::ShapeUtil::HumanString (shape);
439- }
440- #endif
441-
442443bool XlaHelpers::SameStaticDimensions (const xla::Shape& shape1,
443444 const xla::Shape& shape2) {
444445 return shape1.is_static () && shape2.is_static () &&
@@ -602,11 +603,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
602603 runtime::util::ToVector<int64_t >(shape1.dimensions ()),
603604 runtime::util::ToVector<int64_t >(shape2.dimensions ())));
604605 }
605- # if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
606- XLA_CHECK (!XlaHelpers::IsUnboundedDynamic (shape1) &&
607- !XlaHelpers::IsUnboundedDynamic (shape2))
608- << " Unreachable for unbounded dynamic code\n " ;
609- # endif
606+ if (experimental_unbounded_dynamism) {
607+ XLA_CHECK (!XlaHelpers::IsUnboundedDynamic (shape1) &&
608+ !XlaHelpers::IsUnboundedDynamic (shape2))
609+ << " Unreachable for unbounded dynamic code\n " ;
610+ }
610611 return GetPromotedDynamicShape (shape1, shape2);
611612}
612613
@@ -700,7 +701,6 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecond(xla::XlaOp op1,
700701 return PromoteShapes (vops.first , vops.second );
701702}
702703
703- #if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
704704xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes (
705705 xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op,
706706 const xla::Shape& shape) {
@@ -721,7 +721,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
721721 std::vector<xla::XlaOp> reshaped_ops;
722722
723723 if (size_delta > 0 ) {
724- std::cout << " \t size_delta > 0\n " ;
725724 std::vector<int64_t > broadcast_sizes (shape_dims.begin (),
726725 shape_dims.begin () + size_delta);
727726 for (int i = 0 ; i < size_delta; i++) {
@@ -730,20 +729,15 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
730729 XlaHelpers::ScalarValue<int32_t >(broadcast_sizes[i], op.builder ()));
731730
732731 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
733- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
734- << " for size: " << broadcast_sizes[i] << " \n " ;
735732 } else {
736733 get_dim_ops.push_back (xla::GetDimensionSize (aux_op, i));
737734
738735 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
739- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
740- << " for size: ? of index: " << i << " \n " ;
741736 }
742737 }
743738 }
744739
745740 if (size_delta == 0 ) {
746- std::cout << " \t size_delta == 0\n " ;
747741 int sz = op_shape_dims.size () - aux_shape_dims.size ();
748742 std::vector<int64_t > broadcast_sizes (shape_dims.begin (),
749743 shape_dims.begin () + sz);
@@ -753,14 +747,10 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
753747 XlaHelpers::ScalarValue<int32_t >(broadcast_sizes[i], op.builder ()));
754748
755749 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
756- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
757- << " for size: " << broadcast_sizes[i] << " \n " ;
758750 } else {
759751 get_dim_ops.push_back (xla::GetDimensionSize (op, i));
760752
761753 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
762- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
763- << " for size: ? of index: " << i << " \n " ;
764754 }
765755 }
766756 }
@@ -789,31 +779,23 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
789779 get_dim_ops.push_back (ScalarValue<int32_t >(shape_dim, op.builder ()));
790780
791781 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
792- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
793- << " for size: " << shape_dim << " \n " ;
794782
795783 } else if (op_shape_dim == 1 || aux_op_shape_dim == 1 ) {
796784 if (op_shape_dim == 1 ) {
797785 get_dim_ops.push_back (
798786 xla::GetDimensionSize (aux_op, aux_op_shape_index));
799787
800788 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
801- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
802- << " for size: ? of index: " << aux_op_shape_index << " \n " ;
803789
804790 } else {
805791 get_dim_ops.push_back (xla::GetDimensionSize (op, op_shape_index));
806792
807793 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
808- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
809- << " for size: ? of index: " << op_shape_index << " \n " ;
810794 }
811795 } else {
812796 get_dim_ops.push_back (xla::GetDimensionSize (op, op_shape_index));
813797
814798 auto s = ShapeHelper::ShapeOfXlaOp (get_dim_ops.back ());
815- std::cout << " implicitB shape: " << xla::ShapeUtil::HumanString (s)
816- << " for size: ? of index: " << op_shape_index << " \n " ;
817799 }
818800 }
819801
@@ -829,7 +811,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
829811
830812 return new_op;
831813}
832- #endif
833814
834815xla::XlaOp XlaHelpers::ImplicitBroadcast (xla::XlaOp op,
835816 const xla::Shape& op_shape,
0 commit comments