diff options
author | 2018-08-01 21:44:02 -0700 | |
---|---|---|
committer | 2018-08-01 21:48:45 -0700 | |
commit | e3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch) | |
tree | c4fe57f956fe9334f1e27342a4846250e915fe4a /tensorflow/compiler/xla/service/hlo_instructions.cc | |
parent | 3379bae787d73d6db67d66a284bd1a076b2cbdba (diff) |
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df26a2c744..a571fd574e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2015,4 +2015,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( gather_window_bounds()); } +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + "scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join<std::initializer_list<string>>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice<int64> update_window_dims, + tensorflow::gtl::ArraySlice<int64> inserted_window_dims, + tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloScatterInstruction&>(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) && + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique<HloScatterInstruction>( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla |