aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-01 21:44:02 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 21:48:45 -0700
commite3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch)
treec4fe57f956fe9334f1e27342a4846250e915fe4a /tensorflow/compiler/xla/service/hlo_instructions.cc
parent3379bae787d73d6db67d66a284bd1a076b2cbdba (diff)
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc87
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index df26a2c744..a571fd574e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2015,4 +2015,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
gather_window_bounds());
}
+HloScatterInstruction::HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers)
+ : HloInstruction(HloOpcode::kScatter, shape) {
+ AppendOperand(operand);
+ AppendOperand(scatter_indices);
+ AppendOperand(updates);
+ AppendComputation(update_computation);
+ scatter_dimension_numbers_ =
+ MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+}
+
+string HloScatterInstruction::ScatterDimensionNumbersToString() const {
+ string update_window_dims =
+ StrCat("update_window_dims={",
+ Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string inserted_window_dims = StrCat(
+ "inserted_window_dims={",
+ Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ string scatter_dims_to_operand_dims = StrCat(
+ "scatter_dims_to_operand_dims={",
+ Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ "}");
+ string index_vector_dim = StrCat(
+ "index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
+
+ return Join<std::initializer_list<string>>(
+ {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
+ index_vector_dim},
+ ", ");
+}
+
+/* static */ ScatterDimensionNumbers
+HloScatterInstruction::MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim) {
+ ScatterDimensionNumbers scatter_dim_numbers;
+ for (int64 update_window_dim : update_window_dims) {
+ scatter_dim_numbers.add_update_window_dims(update_window_dim);
+ }
+ for (int64 inserted_window_dim : inserted_window_dims) {
+ scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
+ }
+ for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
+ scatter_dim_numbers.add_scatter_dims_to_operand_dims(
+ scatter_dim_to_operand_dim);
+ }
+ scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
+ return scatter_dim_numbers;
+}
+
+HloInstructionProto HloScatterInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
+ return proto;
+}
+
+std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {ScatterDimensionNumbersToString()};
+}
+
+bool HloScatterInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
+ return protobuf_util::ProtobufEquals(
+ scatter_dimension_numbers(),
+ casted_other.scatter_dimension_numbers()) &&
+ eq_computations(to_apply(), casted_other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloScatterInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
+ scatter_dimension_numbers());
+}
+
} // namespace xla