diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-12 22:19:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-12 22:23:34 -0700 |
commit | 1b765165987f5277e294251c118f321166c70932 (patch) | |
tree | a3b879ec7d36389da1aa66a855aed762917140f8 /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 385bb78761489cfa6d6808f239fe884152e71653 (diff) |
[XLA] Split out HloGatherInstruction as subclass from HloInstruction.
PiperOrigin-RevId: 204421652
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 89 |
1 files changed, 89 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 7ea42caa7b..f333c489ed 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1914,4 +1914,93 @@ HloDynamicSliceInstruction::CloneWithNewOperandsImpl( return MakeUnique<HloDynamicSliceInstruction>( shape, new_operands[0], new_operands[1], dynamic_slice_sizes_); } + +HloGatherInstruction::HloGatherInstruction( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds) + : HloInstruction(HloOpcode::kGather, shape) { + AppendOperand(operand); + AppendOperand(gather_indices); + gather_dimension_numbers_ = + MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(gather_window_bounds_)); +} + +string HloGatherInstruction::GatherDimensionNumbersToString() const { + CHECK(gather_dimension_numbers_ != nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join<std::initializer_list<string>>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ GatherDimensionNumbers HloGatherInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + +HloInstructionProto HloGatherInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_gather_dimension_numbers() = gather_dimension_numbers(); + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + return proto; +} + +std::vector<string> HloGatherInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {GatherDimensionNumbersToString(), + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")}; +} + +bool HloGatherInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloGatherInstruction&>(other); + return protobuf_util::ProtobufEquals( + gather_dimension_numbers(), + casted_other.gather_dimension_numbers()) && + gather_window_bounds() == casted_other.gather_window_bounds(); +} + +std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 2); + return MakeUnique<HloGatherInstruction>( + shape, new_operands[0], new_operands[1], gather_dimension_numbers(), + gather_window_bounds()); +} + } // namespace xla |