From e3bcc0aa6e52867d9a12d9efded921325ecc5966 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Aug 2018 21:44:02 -0700 Subject: [XLA] Add Scatter HLO. PiperOrigin-RevId: 207045468 --- .../compiler/xla/service/hlo_instructions.cc | 87 ++++++++++++++++++++++ 1 file changed, 87 insertions(+) (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc') diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df26a2c744..a571fd574e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2015,4 +2015,91 @@ std::unique_ptr HloGatherInstruction::CloneWithNewOperandsImpl( gather_window_bounds()); } +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice update_window_dims, + tensorflow::gtl::ArraySlice inserted_window_dims, + tensorflow::gtl::ArraySlice scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla -- cgit v1.2.3