diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-12 22:19:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 22:23:34 -0700 |
commit | 1b765165987f5277e294251c118f321166c70932 (patch) | |
tree | a3b879ec7d36389da1aa66a855aed762917140f8 | |
parent | 385bb78761489cfa6d6808f239fe884152e71653 (diff) |
[XLA] Split out HloGatherInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 204421652
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 109 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 29 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction_test.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 89 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 43 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/shape_inference_test.cc | 122 |
8 files changed, 243 insertions, 167 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 85c6c632cd..989bb759e3 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -182,6 +182,7 @@ tf_cc_test( name = "shape_inference_test", srcs = ["shape_inference_test.cc"], deps = [ + ":hlo", ":shape_inference", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 830ebfb125..19bee38790 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -386,6 +386,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( slice_sizes); break; } + case HloOpcode::kGather: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Gather instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_gather_dimension_numbers()) + << "Gather instruction should have GatherDimensionNumbers set."; + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers = + MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); + std::vector<int64> gather_window_bounds; + for (int64 bound : proto.gather_window_bounds()) { + gather_window_bounds.push_back(bound); + } + instruction = + CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_window_bounds); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -427,13 +444,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } - if (proto.has_gather_dimension_numbers()) { - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); - } - for (int64 bound : proto.gather_window_bounds()) { - instruction->gather_window_bounds_.push_back(bound); - } return std::move(instruction); } @@ -1036,34 +1046,8 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds) { - std::unique_ptr<HloInstruction> instruction = - WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(gather_indices); - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); - return instruction; -} - -/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim) { - GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); - } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); - } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); - } - - gather_dim_numbers.set_index_vector_dim(index_vector_dim); - return gather_dim_numbers; + return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices, + gather_dim_numbers, window_bounds); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( @@ -1127,6 +1111,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kSort: + case HloOpcode::kGather: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1228,11 +1213,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); - break; case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); clone = @@ -1539,11 +1519,6 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); - case HloOpcode::kGather: - return protobuf_util::ProtobufEquals(gather_dimension_numbers(), - other.gather_dimension_numbers()) && - gather_window_bounds() == other.gather_window_bounds(); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1590,6 +1565,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1955,11 +1931,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } - if (gather_dimension_numbers_ != nullptr) { - extra.push_back(GatherDimensionNumbersToString()); - extra.push_back( - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2089,14 +2060,6 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - if (opcode() == HloOpcode::kGather) { - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); - } - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2857,26 +2820,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } -string HloInstruction::GatherDimensionNumbersToString() const { - CHECK_NE(gather_dimension_numbers_.get(), nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - - return Join<std::initializer_list<string>>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, - ", "); -} - bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3190,4 +3133,14 @@ int64 HloInstruction::slice_sizes(int64 dimension) const { const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const { return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes(); } + +const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { + return Cast<HloGatherInstruction>(this)->gather_dimension_numbers(); +} + +tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds() + const { + return Cast<HloGatherInstruction>(this)->gather_window_bounds(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index b392d65636..cbd78fa124 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -700,13 +700,6 @@ class HloInstruction { // when we plumb a primordial token from the entry computation. static std::unique_ptr<HloInstruction> CreateToken(); - // Creates an instance of GatherDimensionNumbers. - static GatherDimensionNumbers MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim); - // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -1081,19 +1074,6 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; - const GatherDimensionNumbers& gather_dimension_numbers() const { - CHECK(gather_dimension_numbers_ != nullptr); - return *gather_dimension_numbers_; - } - - tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { - CHECK_EQ(opcode(), HloOpcode::kGather); - return gather_window_bounds_; - } - - // Returns the dump string of the gather dimension numbers. - string GatherDimensionNumbersToString() const; - // Clones the HLO instruction. The clone will have the same opcode, shape, and // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of @@ -1460,6 +1440,12 @@ class HloInstruction { // Delegates to HloDynamicSliceInstruction::dynamic_slice_sizes. const std::vector<int64>& dynamic_slice_sizes() const; + + // Delegates to HloGatherInstruction::gather_dimension_numbers. + const GatherDimensionNumbers& gather_dimension_numbers() const; + // Delegates to HloGatherInstruction::gather_window_bounds. + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const; + // Old methods kept for smooth subclassing transition END. protected: @@ -1603,9 +1589,6 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; - std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; - std::vector<int64> gather_window_bounds_; - // Used to tag kCopy instructions that are eligible for copy elision. bool copy_elision_allowed_ = true; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 87c048930f..b75a2bd34b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" @@ -1369,7 +1370,7 @@ TEST_F(HloInstructionTest, StringifyGather_0) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1405,7 +1406,7 @@ TEST_F(HloInstructionTest, StringifyGather_1) { HloInstruction* gather_instruction = builder.AddInstruction(HloInstruction::CreateGather( gather_result_shape, input, gather_indices, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7ea42caa7b..f333c489ed 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1914,4 +1914,93 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( return MakeUnique<HloDynamicSliceInstruction>( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } + +HloGatherInstruction::HloGatherInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds) + : HloInstruction(HloOpcode::kGather, shape) { + AppendOperand(operand); + AppendOperand(gather_indices); + gather_dimension_numbers_ = + MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); +} + +string HloGatherInstruction::GatherDimensionNumbersToString() const { + CHECK(gather_dimension_numbers_ != nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join<std::initializer_list<string>>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + +HloInstructionProto HloGatherInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + return proto; +} + +std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {GatherDimensionNumbersToString(), + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; +} + +bool HloGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloGatherInstruction&>(other); + return protobuf_util::ProtobufEquals( + gather_dimension_numbers(), + casted_other.gather_dimension_numbers()) && + gather_window_bounds() == casted_other.gather_window_bounds(); +} + +std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloGatherInstruction>( + shape, new_operands[0], new_operands[1], gather_dimension_numbers(), + gather_window_bounds()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e922d94234..65a93cdcf1 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1148,6 +1148,49 @@ class HloDynamicSliceInstruction : public HloInstruction { // ('start' is specified dynamically in the second operand of the operation). std::vector<int64> dynamic_slice_sizes_; }; + +class HloGatherInstruction : public HloInstruction { + public: + explicit HloGatherInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { + return gather_window_bounds_; + } + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; + std::vector<int64> gather_window_bounds_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index f162d52d3c..d387539350 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_domain_metadata.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -1192,11 +1193,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } - GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/*output_window_dims, - /*elided_window_dims=*/*elided_window_dims, - /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, - /*index_vector_dim=*/*index_vector_dim); + GatherDimensionNumbers dim_numbers = + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); instruction = builder->AddInstruction(HloInstruction::CreateGather( shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index bafe14d6f4..9b1ce143c6 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include <string> +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -1543,45 +1544,45 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { }; TEST_F(GatherShapeInferenceTest, TensorFlowGather) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), - /*window_bounds=*/{64, 1})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{1}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/1), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_vector_32_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); EXPECT_TRUE( ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) << ShapeUtil::HumanString(gather_shape); } TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { - TF_ASSERT_OK_AND_ASSIGN( - Shape gather_shape, - ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, - HloInstruction::MakeGatherDimNumbers( - /*output_window_dims=*/{4}, - /*elided_window_dims=*/{0}, - /*gather_dims_to_operand_dims=*/{0}, - /*index_vector_dim=*/4), - /*window_bounds=*/{1, 48})); + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); EXPECT_TRUE(ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) << ShapeUtil::HumanString(gather_shape); @@ -1592,7 +1593,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1609,7 +1610,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1627,7 +1628,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1646,7 +1647,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3, 4}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1664,7 +1665,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_scalar_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{0, 1, 2, 3}, /*elided_window_dims=*/{0}, /*gather_dims_to_operand_dims=*/{0}, @@ -1679,10 +1680,11 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/1), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1693,10 +1695,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1707,10 +1710,11 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, - HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, - /*elided_window_dims=*/{1}, - /*gather_dims_to_operand_dims=*/{1}, - /*index_vector_dim=*/0), + HloGatherInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), /*window_bounds=*/{64, 1}); ASSERT_FALSE(statusor.ok()); EXPECT_THAT(statusor.status().error_message(), @@ -1722,7 +1726,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 8, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1739,7 +1743,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 7}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1756,7 +1760,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 99, 100, 101}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1772,7 +1776,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 9}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1788,7 +1792,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{4}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1806,7 +1810,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 19}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1823,7 +1827,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{0, 1, 2, 3, 3}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1841,7 +1845,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, @@ -1860,7 +1864,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, @@ -1878,7 +1882,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, @@ -1896,7 +1900,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{2, 1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1911,7 +1915,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{2}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1928,7 +1932,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1946,7 +1950,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7}, /*elided_window_dims=*/{1}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, @@ -1962,7 +1966,7 @@ TEST_F(GatherShapeInferenceTest, TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, - HloInstruction::MakeGatherDimNumbers( + HloGatherInstruction::MakeGatherDimNumbers( /*output_window_dims=*/{4, 5, 6, 7, 8}, /*elided_window_dims=*/{}, /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, |