author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-08-01 21:44:02 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>     2018-08-01 21:48:45 -0700
commit    e3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch)
tree      c4fe57f956fe9334f1e27342a4846250e915fe4a
parent    3379bae787d73d6db67d66a284bd1a076b2cbdba (diff)
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
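
Scatter is the counterpart of the Gather HLO: each slice of the updates operand is written into the operand array at a position selected by scatter_indices, with old and new values combined by update_computation, which must have the shape of a reduction function (see the VerifyReducerShape call in shape inference below). This patch adds the client builder API, HLO plumbing, shape inference, and documentation in operation_semantics.md; the CPU and GPU IR emitters and the HLO parser still report Scatter as unimplemented.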
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc                   33
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h                    14
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc                4
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h                 1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h                1
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h   3
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc                4
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.h                 1
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto                        2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc             5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h              1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc              2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc              41
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc         49
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc             87
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h              39
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h                     1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                    4
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc                  9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h                   1
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc            1
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc             188
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h                8
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc        629
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto                          14
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md    132
27 files changed, 1255 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 53be5a79c2..bea3fa9a96 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1635,6 +1635,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices,
});
}
+XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates,
+ const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers) {
+ return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+
+ TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
+ TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape,
+ GetShape(scatter_indices));
+ TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
+ TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
+ update_computation.GetProgramShape());
+ TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
+ ShapeInference::InferScatterShape(
+ input_shape, scatter_indices_shape, updates_shape,
+ to_apply_shape, dimension_numbers));
+
+ *instr.mutable_scatter_dimension_numbers() = dimension_numbers;
+
+ AddCalledComputation(update_computation, &instr);
+ return AddInstruction(std::move(instr), HloOpcode::kScatter,
+ {input, scatter_indices, updates});
+ });
+}
+
XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
const XlaComputation& true_computation,
const XlaOp& false_operand,
@@ -2803,6 +2829,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
window_bounds);
}
+XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers) {
+ return input.builder()->Scatter(input, scatter_indices, updates,
+ update_computation, dimension_numbers);
+}
+
void Send(const XlaOp& operand, const ChannelHandle& handle) {
return operand.builder()->Send(operand, handle);
}
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index ae331407d6..8726cc6f93 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -857,6 +857,11 @@ class XlaBuilder {
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ // Enqueues a Scatter node onto the computation.
+ XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
+
// Enqueues a Send node onto the computation for device-to-device
// communication, to send the given operand to a Recv instruction that shares
// the same channel handle.
@@ -1296,6 +1301,10 @@ class XlaBuilder {
friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates,
+ const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
friend void Send(const XlaOp& operand, const ChannelHandle& handle);
friend XlaOp Recv(XlaBuilder* builder, const Shape& shape,
const ChannelHandle& handle);
@@ -1977,6 +1986,11 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices,
const GatherDimensionNumbers& dimension_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+// Enqueues a Scatter node onto the computation.
+XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices,
+ const XlaOp& updates, const XlaComputation& update_computation,
+ const ScatterDimensionNumbers& dimension_numbers);
+
// Enqueues a Send node onto the computation for device-to-device
// communication. This operation sends the given operand to
// a Recv instruction in a different computation that shares the same channel
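
As a minimal usage sketch of the new client API (not part of this change; variable names and the add_f32 reduction computation are hypothetical, and the shapes mirror the TfScatterWithFullUpdatesV2 shape inference test further below):

    XlaBuilder b("scatter_example");
    // operand: f32[64,48]; one full row of 48 elements is written per index.
    XlaOp operand = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {64, 48}), "operand");
    XlaOp indices = Parameter(&b, 1, ShapeUtil::MakeShape(S64, {32}), "indices");
    XlaOp updates = Parameter(&b, 2, ShapeUtil::MakeShape(F32, {32, 48}), "updates");
    ScatterDimensionNumbers dnums;
    dnums.add_update_window_dims(1);            // updates dim 1 is the window.
    dnums.add_inserted_window_dims(0);          // operand dim 0 has no window dim.
    dnums.add_scatter_dims_to_operand_dims(0);  // each index selects an operand row.
    dnums.set_index_vector_dim(1);              // == rank(indices): scalar indices.
    // add_f32 is a hypothetical (F32, F32) -> F32 XlaComputation that adds the
    // existing value and the update.
    Scatter(operand, indices, updates, add_f32, dnums);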
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 60f9cd1121..ca645d3f1d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -1793,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) {
return Unimplemented("Send-done is not implemented on CPU.");
}
+Status IrEmitter::HandleScatter(HloInstruction*) {
+ return Unimplemented("Scatter is not implemented on CPUs.");
+}
+
Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 372017441f..c9a1dab62d 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -150,6 +150,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override;
Status HandleConcatenate(HloInstruction* concatenate) override;
Status HandleConditional(HloInstruction* conditional) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* gen_token) override;
Status HandleIota(HloInstruction* iota) override;
Status HandleRng(HloInstruction* rng) override;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 097fa23027..9f86749125 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -233,6 +233,7 @@ class DfsHloVisitorBase {
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
virtual Status HandleGather(HloInstructionPtr hlo) = 0;
+ virtual Status HandleScatter(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index f4316e0fb7..ae8a066d62 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -194,6 +194,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleGather(HloInstructionPtr gather) override {
return DefaultAction(gather);
}
+ Status HandleScatter(HloInstructionPtr scatter) override {
+ return DefaultAction(scatter);
+ }
Status HandleAfterAll(HloInstructionPtr token) override {
return DefaultAction(token);
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 1295e83c0c..290e2f73dc 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -125,6 +125,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) {
return Unimplemented("Recv-done is not implemented on GPU");
}
+Status IrEmitter::HandleScatter(HloInstruction*) {
+ return Unimplemented("Scatter is not implemented on GPUs.");
+}
+
Status IrEmitter::HandleTuple(HloInstruction* tuple) {
std::vector<llvm::Value*> base_ptrs;
for (const HloInstruction* operand : tuple->operands()) {
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index 80e2a203ac..561c683879 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -86,6 +86,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
Status HandleParameter(HloInstruction* parameter) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleTuple(HloInstruction* tuple) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleFusion(HloInstruction* fusion) override;
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 63a8a813cd..0b93d97c11 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -160,6 +160,8 @@ message HloInstructionProto {
// present for Send and Recv instructions and their SendDone and RecvDone
// partners.
bool is_host_transfer = 47;
+
+ xla.ScatterDimensionNumbers scatter_dimension_numbers = 48;
}
// Serialization of HloComputation.
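
(The xla.ScatterDimensionNumbers message itself is added to xla_data.proto; that hunk is not reproduced in this excerpt. Consistent with the accessors used throughout this patch, it carries repeated int64 fields update_window_dims, inserted_window_dims, and scatter_dims_to_operand_dims, plus a scalar int64 index_vector_dim.)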
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 1f672502f7..7806d432ce 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -648,6 +648,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
return Status::OK();
}
+Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) {
+ // TODO(b/32945756): Compute the properties of the sub-computation.
+ return Status::OK();
+}
+
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 82d650dc7b..d93759a48a 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -104,6 +104,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
Status HandleGather(const HloInstruction* gather) override;
+ Status HandleScatter(const HloInstruction* scatter) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index fd5085bed2..7e5866a356 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1019,6 +1019,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
return kWhite;
}
return kGreen;
+ case HloOpcode::kScatter:
+ // Do not de-emphasize Scatter, since it involves significant work.
case HloOpcode::kCopy:
// Emphasize copy nodes, which are either physical transposes (and thus
// significant), or copies of read-only buffers (and thus dead weight).
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 8b9bdd2f46..402b725bda 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -404,6 +404,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
*gather_dimension_numbers, gather_window_bounds);
break;
}
+ case HloOpcode::kScatter: {
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "Scatter instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_scatter_dimension_numbers())
+ << "Scatter instruction should have ScatterDimensionNumbers set.";
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Scatter instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
+ auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>(
+ proto.scatter_dimension_numbers());
+ instruction =
+ CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
+ computations(0), *scatter_dimension_numbers);
+ break;
+ }
default: {
instruction = WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
@@ -1062,6 +1078,16 @@ bool HloInstruction::HasSideEffect() const {
gather_dim_numbers, window_bounds);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers) {
+ return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices,
+ updates, update_computation,
+ scatter_dim_numbers);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
@@ -1124,6 +1150,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
case HloOpcode::kGather:
+ case HloOpcode::kScatter:
case HloOpcode::kIota:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
@@ -1587,6 +1614,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
+ case HloOpcode::kScatter:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -1693,6 +1721,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -1711,6 +1740,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) {
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@@ -1977,7 +2007,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
- opcode() == HloOpcode::kCrossReplicaSum) {
+ opcode() == HloOpcode::kCrossReplicaSum ||
+ opcode() == HloOpcode::kScatter) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
@@ -2013,6 +2044,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
case HloOpcode::kCrossReplicaSum:
+ case HloOpcode::kScatter:
extra.push_back(
StrCat("to_apply=\n", to_apply()->ToString(new_options)));
break;
@@ -2311,6 +2343,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSendDone(this);
case HloOpcode::kGather:
return visitor->HandleGather(this);
+ case HloOpcode::kScatter:
+ return visitor->HandleScatter(this);
case HloOpcode::kDomain:
return visitor->HandleDomain(this);
case HloOpcode::kAfterAll:
@@ -3171,4 +3205,9 @@ tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds()
return Cast<HloGatherInstruction>(this)->gather_window_bounds();
}
+const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers()
+ const {
+ return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers();
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 0e3130a05c..d2dce5aecb 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -644,6 +644,12 @@ class HloInstruction {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ static std::unique_ptr<HloInstruction> CreateScatter(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+
// Creates a kDomain instruction which delimits an HLO domain which have
// the provided user and operand side metadata.
static std::unique_ptr<HloInstruction> CreateDomain(
@@ -1452,6 +1458,9 @@ class HloInstruction {
// Delegates to HloGatherInstruction::gather_window_bounds.
tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const;
+ // Delegates to HloScatterInstruction::scatter_dimension_numbers().
+ const ScatterDimensionNumbers& scatter_dimension_numbers() const;
+
// Old methods kept for smooth subclassing transition END.
protected:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index b75a2bd34b..8a694dde80 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1425,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) {
"index_vector_dim=2, window_bounds={30,29,28,27,26}");
}
+TEST_F(HloInstructionTest, StringifyScatter) {
+ Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ Shape scatter_indices_tensor_shape =
+ ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
+ Shape scatter_updates_shape =
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26});
+
+ HloComputation::Builder builder("Scatter");
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+ HloInstruction* scatter_indices =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, scatter_indices_tensor_shape, "scatter_indices"));
+ HloInstruction* scatter_updates =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 2, scatter_updates_shape, "scatter_updates"));
+
+ HloComputation::Builder update_builder("Scatter.update");
+ update_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1"));
+ update_builder.AddInstruction(
+ HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2"));
+
+ auto module = CreateNewModule();
+ auto* update_computation =
+ module->AddEmbeddedComputation(update_builder.Build());
+
+ HloInstruction* scatter_instruction =
+ builder.AddInstruction(HloInstruction::CreateScatter(
+ input_tensor_shape, input, scatter_indices, scatter_updates,
+ update_computation,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2)));
+ module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(
+ scatter_instruction->ToString(),
+ "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} "
+ "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+ "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, "
+ "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), "
+ "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, "
+ "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, "
+ "to_apply=%Scatter.update");
+}
+
TEST_F(HloInstructionTest, CanonnicalStringificationFusion) {
// Tests stringification of a simple op, fusion, while, and conditional.
const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10});
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index df26a2c744..a571fd574e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -2015,4 +2015,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl(
gather_window_bounds());
}
+HloScatterInstruction::HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers)
+ : HloInstruction(HloOpcode::kScatter, shape) {
+ AppendOperand(operand);
+ AppendOperand(scatter_indices);
+ AppendOperand(updates);
+ AppendComputation(update_computation);
+ scatter_dimension_numbers_ =
+ MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers);
+}
+
+string HloScatterInstruction::ScatterDimensionNumbersToString() const {
+ string update_window_dims =
+ StrCat("update_window_dims={",
+ Join(scatter_dimension_numbers().update_window_dims(), ","), "}");
+ string inserted_window_dims = StrCat(
+ "inserted_window_dims={",
+ Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}");
+ string scatter_dims_to_operand_dims = StrCat(
+ "scatter_dims_to_operand_dims={",
+ Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","),
+ "}");
+ string index_vector_dim = StrCat(
+ "index_vector_dim=", scatter_dimension_numbers().index_vector_dim());
+
+ return Join<std::initializer_list<string>>(
+ {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims,
+ index_vector_dim},
+ ", ");
+}
+
+/* static */ ScatterDimensionNumbers
+HloScatterInstruction::MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim) {
+ ScatterDimensionNumbers scatter_dim_numbers;
+ for (int64 update_window_dim : update_window_dims) {
+ scatter_dim_numbers.add_update_window_dims(update_window_dim);
+ }
+ for (int64 inserted_window_dim : inserted_window_dims) {
+ scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim);
+ }
+ for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) {
+ scatter_dim_numbers.add_scatter_dims_to_operand_dims(
+ scatter_dim_to_operand_dim);
+ }
+ scatter_dim_numbers.set_index_vector_dim(index_vector_dim);
+ return scatter_dim_numbers;
+}
+
+HloInstructionProto HloScatterInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers();
+ return proto;
+}
+
+std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const {
+ return {ScatterDimensionNumbersToString()};
+}
+
+bool HloScatterInstruction::IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const {
+ const auto& casted_other = static_cast<const HloScatterInstruction&>(other);
+ return protobuf_util::ProtobufEquals(
+ scatter_dimension_numbers(),
+ casted_other.scatter_dimension_numbers()) &&
+ eq_computations(to_apply(), casted_other.to_apply());
+}
+
+std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const {
+ CHECK_EQ(new_operands.size(), 3);
+ return MakeUnique<HloScatterInstruction>(
+ shape, new_operands[0], new_operands[1], new_operands[2], to_apply(),
+ scatter_dimension_numbers());
+}
+
} // namespace xla
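
MakeScatterDimNumbers is a convenience wrapper over the repeated-field setters above; a sketch of equivalent usage, with the same (hypothetical) values as the builder example earlier:

    ScatterDimensionNumbers dnums = HloScatterInstruction::MakeScatterDimNumbers(
        /*update_window_dims=*/{1},
        /*inserted_window_dims=*/{0},
        /*scatter_dims_to_operand_dims=*/{0},
        /*index_vector_dim=*/1);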
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 132e767420..3797bef600 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1198,6 +1198,45 @@ class HloGatherInstruction : public HloInstruction {
std::vector<int64> gather_window_bounds_;
};
+class HloScatterInstruction : public HloInstruction {
+ public:
+ explicit HloScatterInstruction(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* scatter_indices, HloInstruction* updates,
+ HloComputation* update_computation,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+ const ScatterDimensionNumbers& scatter_dimension_numbers() const {
+ CHECK(scatter_dimension_numbers_ != nullptr);
+ return *scatter_dimension_numbers_;
+ }
+ // Returns the dump string of the scatter dimension numbers.
+ string ScatterDimensionNumbersToString() const;
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
+ // Creates an instance of ScatterDimensionNumbers.
+ static ScatterDimensionNumbers MakeScatterDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> update_window_dims,
+ tensorflow::gtl::ArraySlice<int64> inserted_window_dims,
+ tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims,
+ int64 index_vector_dim);
+
+ private:
+ std::vector<string> ExtraAttributesToStringImpl(
+ const HloPrintOptions& options) const override;
+ bool IdenticalSlowPath(
+ const HloInstruction& other,
+ const std::function<bool(const HloComputation*, const HloComputation*)>&
+ eq_computations) const override;
+ // Implementation for non-common logic of CloneWithNewOperands.
+ std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+ const Shape& shape,
+ tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
+ HloCloneContext* context) const override;
+
+ std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_;
+};
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 59e9a5a94a..88531b6f20 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -118,6 +118,7 @@ namespace xla {
V(kReverse, "reverse") \
V(kRng, "rng") \
V(kRoundNearestAfz, "round-nearest-afz") \
+ V(kScatter, "scatter") \
V(kSelect, "select") \
V(kSelectAndScatter, "select-and-scatter") \
V(kSend, "send") \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 6ed6f74c55..3efa264259 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1242,6 +1242,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
dim_numbers, *window_bounds));
break;
}
+ case HloOpcode::kScatter: {
+ // TODO(b/32945756): Implement HLO parsing for Scatter.
+ return TokenError("HLO parsing is not implemented for Scatter.");
+ }
case HloOpcode::kDomain: {
DomainData domain;
attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain};
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 25fa319faf..e4a5cd3af1 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -510,6 +510,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
gather->gather_dimension_numbers(), gather->gather_window_bounds()));
}
+Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
+ return CheckShape(
+ scatter, ShapeInference::InferScatterShape(
+ scatter->operand(0)->shape(), scatter->operand(1)->shape(),
+ scatter->operand(2)->shape(),
+ scatter->to_apply()->ComputeProgramShape(),
+ scatter->scatter_dimension_numbers()));
+}
+
Status ShapeVerifier::HandleAfterAll(HloInstruction* token) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : token->operands()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 79f7aa9f4c..7feddaeabf 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -83,6 +83,7 @@ class ShapeVerifier : public DfsHloVisitor {
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
Status HandleGather(HloInstruction* gather) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleAfterAll(HloInstruction* token) override;
Status FinishVisit(HloInstruction*) override { return Status::OK(); }
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index af07370135..e2191aedb7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -141,6 +141,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kReduceWindow:
case HloOpcode::kRemainder:
case HloOpcode::kRng:
+ case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 35df792b07..20314ca482 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2568,4 +2568,192 @@ static Status ValidateGatherDimensionNumbers(
return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
}
+namespace {
+
+Status ValidateScatterDimensionNumbers(
+ const Shape& operand_shape,
+ tensorflow::gtl::ArraySlice<int64> scatter_indices_shape,
+ const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) {
+ // Validate update_window_dims in ScatterDimensionNumbers.
+ if (!c_is_sorted(dim_numbers.update_window_dims())) {
+ return InvalidArgument(
+ "update_window_dims in scatter op must be sorted; got: %s.",
+ Join(dim_numbers.update_window_dims(), ", ").c_str());
+ }
+ if (c_adjacent_find(dim_numbers.update_window_dims()) !=
+ dim_numbers.update_window_dims().end()) {
+ return InvalidArgument(
+ "update_window_dims in scatter op must not repeat; got: %s.",
+ Join(dim_numbers.update_window_dims(), ", ").c_str());
+ }
+ const int64 updates_rank = ShapeUtil::Rank(updates_shape);
+ for (int64 window_dim : dim_numbers.update_window_dims()) {
+ if (window_dim < 0 || window_dim >= updates_rank) {
+ return InvalidArgument(
+ "Invalid update_window_dims set in scatter op; valid range is [0, "
+ "%lld). got: %lld.",
+ updates_rank, window_dim);
+ }
+ }
+
+ // Validate inserted_window_dims in ScatterDimensionNumbers.
+ if (!c_is_sorted(dim_numbers.inserted_window_dims())) {
+ return InvalidArgument(
+ "inserted_window_dims in scatter op must be sorted; got: %s.",
+ Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ }
+ if (c_adjacent_find(dim_numbers.inserted_window_dims()) !=
+ dim_numbers.inserted_window_dims().end()) {
+ return InvalidArgument(
+ "inserted_window_dims in scatter op must not repeat; got: %s.",
+ Join(dim_numbers.inserted_window_dims(), ", ").c_str());
+ }
+ for (int64 inserted_dim : dim_numbers.inserted_window_dims()) {
+ if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid inserted_window_dims set in scatter op; valid range is [0, "
+ "%d), got: %lld.",
+ operand_shape.dimensions_size(), inserted_dim);
+ }
+ }
+
+ // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers.
+ if (dim_numbers.scatter_dims_to_operand_dims_size() !=
+ scatter_indices_shape[dim_numbers.index_vector_dim()]) {
+ return InvalidArgument(
+ "Scatter op has %d elements in scatter_dims_to_operand_dims and the "
+ "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. "
+ "These two numbers must be equal.",
+ dim_numbers.scatter_dims_to_operand_dims_size(),
+ dim_numbers.index_vector_dim(),
+ scatter_indices_shape[dim_numbers.index_vector_dim()]);
+ }
+ for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) {
+ int64 scatter_dim_to_operand_dim =
+ dim_numbers.scatter_dims_to_operand_dims(i);
+ if (scatter_dim_to_operand_dim < 0 ||
+ scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), "
+ "got: %d->%lld.",
+ operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim);
+ }
+ }
+ std::vector<int64> sorted_scatter_dims_to_operand_dims(
+ dim_numbers.scatter_dims_to_operand_dims().begin(),
+ dim_numbers.scatter_dims_to_operand_dims().end());
+ c_sort(sorted_scatter_dims_to_operand_dims);
+ if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) !=
+ sorted_scatter_dims_to_operand_dims.end()) {
+ return InvalidArgument(
+ "Repeated dimensions not allowed in scatter_dims_to_operand_dims; "
+ "got: %s.",
+ Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str());
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape(
+ const Shape& operand_shape, const Shape& scatter_indices_shape,
+ const Shape& updates_shape, const ProgramShape& to_apply_shape,
+ const ScatterDimensionNumbers& scatter_dim_numbers) {
+ TF_RETURN_IF_ERROR(
+ ExpectArray(operand_shape, "operand tensor of scatter op"));
+ TF_RETURN_IF_ERROR(
+ ExpectArray(scatter_indices_shape, "scatter indices of scatter op"));
+ TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op"));
+
+ if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) {
+ return InvalidArgument(
+ "Scatter indices parameter must be an integral tensor; got %s.",
+ ShapeUtil::HumanString(scatter_indices_shape).c_str());
+ }
+
+ if (scatter_indices_shape.dimensions_size() <
+ scatter_dim_numbers.index_vector_dim() ||
+ scatter_dim_numbers.index_vector_dim() < 0) {
+ return InvalidArgument(
+ "Scatter index leaf dimension must be within [0, rank(scatter_indices)"
+ " + 1). rank(scatter_indices) is %d and scatter index leaf dimension "
+ "is %lld.",
+ scatter_indices_shape.dimensions_size(),
+ scatter_dim_numbers.index_vector_dim());
+ }
+
+ // Check if the update computation has a proper shape as a reduction.
+ TF_RETURN_IF_ERROR(VerifyReducerShape(
+ to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}),
+ updates_shape.element_type()));
+
+ std::vector<int64> expanded_scatter_indices_shape =
+ ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions()));
+ if (expanded_scatter_indices_shape.size() ==
+ scatter_dim_numbers.index_vector_dim()) {
+ expanded_scatter_indices_shape.push_back(1);
+ }
+
+ int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 +
+ scatter_dim_numbers.update_window_dims_size();
+ if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) {
+ return InvalidArgument("Updates tensor must be of rank %lld; got %lld.",
+ expected_updates_rank,
+ ShapeUtil::Rank(updates_shape));
+ }
+
+ TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers(
+ operand_shape, expanded_scatter_indices_shape, updates_shape,
+ scatter_dim_numbers));
+
+ int64 inserted_dims_seen = 0;
+ std::vector<int64> max_update_window_bounds;
+ for (int i = 0; i < operand_shape.dimensions_size(); ++i) {
+ if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() &&
+ scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) {
+ ++inserted_dims_seen;
+ } else {
+ max_update_window_bounds.push_back(operand_shape.dimensions(i));
+ }
+ }
+ for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) {
+ auto update_window_dim = scatter_dim_numbers.update_window_dims(i);
+ if (updates_shape.dimensions(update_window_dim) >
+ max_update_window_bounds[i]) {
+ return InvalidArgument(
+ "Bounds of the window dimensions of updates must not exceed the "
+ "bounds of the corresponding dimensions of operand. For dimension "
+ "%lld, updates bound is %lld, operand bound is %lld.",
+ update_window_dim, updates_shape.dimensions(update_window_dim),
+ max_update_window_bounds[i]);
+ }
+ }
+
+ int64 scatter_dims_seen = 0;
+ for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) {
+ bool is_update_window_dim =
+ c_binary_search(scatter_dim_numbers.update_window_dims(), i);
+ if (is_update_window_dim) {
+ continue;
+ }
+ if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) {
+ ++scatter_dims_seen;
+ }
+ if (updates_shape.dimensions(i) !=
+ expanded_scatter_indices_shape[scatter_dims_seen]) {
+ return InvalidArgument(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices. For "
+ "scatter dimension %lld, updates bound is %lld, scatter_indices "
+ "bound is %lld.",
+ i, updates_shape.dimensions(i),
+ expanded_scatter_indices_shape[scatter_dims_seen]);
+ }
+ ++scatter_dims_seen;
+ }
+
+ return operand_shape;
+}
+
} // namespace xla
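
A worked pass through these checks, using the shapes from the TfScatterWithFullUpdates test below: operand is f32[64,48] and scatter_indices is s64[32] with index_vector_dim = 1 = rank(scatter_indices), so the indices are expanded to shape [32,1] with an implicit trailing index vector of length 1. The expected updates rank is (2 - 1) + |update_window_dims| = 1 + 1 = 2, matching updates = f32[64,32]. With update_window_dims = {0} and inserted_window_dims = {1}, the window bound updates[0] = 64 may not exceed the operand bound 64, and the scatter dimension updates[1] = 32 must equal the expanded indices bound 32; both hold, and the inferred result shape is the operand shape f32[64,48].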
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 1a5684e3c3..6adea7bc1f 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -268,6 +268,14 @@ class ShapeInference {
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
+ // Helper that validates the given input shape, scatter indices shape, updates
+ // shape, and scatter dimension numbers that constitute a scatter operation,
+ // and returns the result shape of the scatter operation.
+ static StatusOr<Shape> InferScatterShape(
+ const Shape& operand_shape, const Shape& scatter_indices_shape,
+ const Shape& updates_shape, const ProgramShape& to_apply_shape,
+ const ScatterDimensionNumbers& scatter_dim_numbers);
+
private:
// Helper that infers the shape produced by performing an element-wise binary
// operation with the given LHS and RHS shapes.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 6046d50c6d..511d2c22f8 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1536,7 +1536,7 @@ TEST_F(ShapeInferenceTest, BadSort) {
<< statusor.status();
}
-class GatherShapeInferenceTest : public ShapeInferenceTest {
+class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
@@ -1553,9 +1553,13 @@ class GatherShapeInferenceTest : public ShapeInferenceTest {
ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
{s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
+ const ProgramShape to_apply_ =
+ ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};
-TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
+// Shape inference tests for Gather.
+
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
@@ -1570,7 +1574,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_vector_32_,
@@ -1585,7 +1589,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
ShapeInference::InferGatherShape(
matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
@@ -1600,7 +1604,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
+TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1617,7 +1621,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1635,7 +1639,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
ShapeInference::InferGatherShape(
@@ -1653,7 +1657,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
+TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
// This is equivalent to a dynamic slice.
TF_ASSERT_OK_AND_ASSIGN(
Shape gather_shape,
@@ -1671,7 +1675,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
+TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
// The gather indices "tensor" is a scalar S here that's used to slice out
// [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
@@ -1689,7 +1693,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) {
<< ShapeUtil::HumanString(gather_shape);
}
-TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
+TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
tuple_shape_, s64_vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1704,7 +1708,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
+TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, tuple_shape_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1719,7 +1723,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
+TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
s64_vector_32_, vector_32_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1734,7 +1738,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1751,7 +1755,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowIndices) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1768,7 +1772,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1784,7 +1788,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1800,7 +1804,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1818,7 +1822,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1835,7 +1839,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1853,7 +1857,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1872,7 +1876,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1890,7 +1894,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1908,7 +1912,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1924,7 +1928,8 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowBoundsTooLarge) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1940,7 +1945,7 @@ TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1958,7 +1963,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest,
+TEST_F(ScatterGatherShapeInferenceTest,
InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
@@ -1975,7 +1980,7 @@ TEST_F(GatherShapeInferenceTest,
<< statusor.status();
}
-TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
+TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
HloGatherInstruction::MakeGatherDimNumbers(
@@ -1992,5 +1997,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
<< statusor.status();
}
+// Shape inference tests for Scatter.
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {64, 32}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {32, 48}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {10, 32}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_,
+ ShapeUtil::MakeShape(F32, {32, 8}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/1)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterWithUpdatesNotMatchingIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterWithUpdatesNotMatchingIndicesV2) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{1},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Bounds of the window dimensions of updates must not exceed "
+ "the bounds of the corresponding dimensions of operand."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ TfScatterNdWithUpdatesNotMatchingIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Bounds of the scatter dimensions of updates must be same as the "
+ "bounds of the corresponding dimensions of scatter indices."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4)));
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/2)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
+ ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
+ to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 8},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
+ // This is equivalent to a dynamic update slice.
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0, 1, 2, 3, 4},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
+  // The "tensor" of scatter indices is a scalar S here, used to update a
+  // [30,29,28,27]-shaped window within the operand at offset S along
+  // dimension 0.
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape scatter_shape,
+ ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
+ ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0, 1, 2, 3},
+ /*inserted_window_dims=*/{0},
+ /*scatter_dims_to_operand_dims=*/{0},
+ /*index_vector_dim=*/0)));
+
+ EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
+ << ShapeUtil::HumanString(scatter_shape);
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/1));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for operand"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ ScatterWithTupleShapedScatterIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for scatter indices"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected array argument for updates"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ s64_vector_32_, vector_32_, s64_vector_32_, to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{0},
+ /*inserted_window_dims=*/{1},
+ /*scatter_dims_to_operand_dims=*/{1},
+ /*index_vector_dim=*/0));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Scatter indices parameter must be an integral tensor"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/10));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Scatter index leaf dimension must be within [0, "
+ "rank(scatter_indices) + 1)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Updates tensor must be of rank 7; got 8."))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
+ const ProgramShape invalid_update_computation =
+ ShapeUtil::MakeProgramShape({f32_}, f32_);
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}),
+ invalid_update_computation,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Reduction function must take 2 parameters, but takes 1"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 8, 7},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("update_window_dims in scatter op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 7},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("update_window_dims in scatter op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6, 7, 9},
+ /*inserted_window_dims=*/{},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid update_window_dims set in scatter op; valid "
+ "range is [0, 9)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{2, 1},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("inserted_window_dims in scatter op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 1},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("inserted_window_dims in scatter op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 5},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
+ "range is [0, 5)"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
+ "the bound of dimension index_vector_dim=4 of scatter_indices "
+ "is 5. These two numbers must be equal"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
+ "is [0, 5), got: 4->10"))
+ << statusor.status();
+}
+
+TEST_F(ScatterGatherShapeInferenceTest,
+ InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
+ HloScatterInstruction::MakeScatterDimNumbers(
+ /*update_window_dims=*/{4, 5, 6},
+ /*inserted_window_dims=*/{1, 2},
+ /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
+ /*index_vector_dim=*/4));
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 0b300dc7b2..fd784e909c 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -447,6 +447,20 @@ message GatherDimensionNumbers {
int64 index_vector_dim = 4;
}
+// Describes the dimension numbers for a scatter operation.
+//
+// All the fields are similar to the corresponding fields in
+// GatherDimensionNumbers. Differences are noted below.
+message ScatterDimensionNumbers {
+ // The set of dimensions in the updates shape that are window dimensions.
+ repeated int64 update_window_dims = 1;
+ // The set of window dimensions that must be inserted into the updates shape.
+ repeated int64 inserted_window_dims = 2;
+
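+  // A map from dimensions in scatter_indices to the operand index space. It
+  // has to be one-to-one and total.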
+ repeated int64 scatter_dims_to_operand_dims = 3;
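+  // The dimension in scatter_indices that contains the starting indices.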
+ int64 index_vector_dim = 4;
+}
+
message ConvolutionDimensionNumbers {
// The number of the dimension that represents batch in the input.
int64 input_batch_dimension = 7;
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5f7482f90f..3981aaaf75 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1801,6 +1801,138 @@ is implementation-defined.
: : : limit of interval :
| `shape` | `Shape` | Output shape of type T |
+## Scatter
+
+The XLA scatter operation generates a result that is the value of the input
+tensor `operand`, with several slices (at indices specified by
+`scatter_indices`) updated with the values in `updates` using
+`update_computation`.
+
+See also
+[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h).
+
+<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
+
+|Arguments | Type | Semantics |
+|------------------|------------------------|----------------------------------|
+|`operand` | `XlaOp` | Tensor to be scattered into. |
+|`scatter_indices` | `XlaOp` | Tensor containing the starting |
+: : : indices of the slices that must :
+: : : be scattered to. :
+|`updates` | `XlaOp` | Tensor containing the values that|
+: : : must be used for scattering. :
+|`update_computation`| `XlaComputation` | Computation to be used for |
+: : : combining the existing values in :
+: : : the input tensor and the updates :
+: : : during scatter. This computation :
+: : : should be of type `T, T -> T`. :
+|`index_vector_dim`| `int64` | The dimension in |
+: : : `scatter_indices` that contains :
+: : : the starting indices. :
+|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
+: : : `updates` shape that are _window :
+: : : dimensions_. :
+|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
+: : : that must be inserted into :
+: : : `updates` shape. :
+|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A map from dimensions |
+: : : in the scatter indices to the :
+: : : operand index space. This array :
+: : : is interpreted as mapping `i` to :
+: : : `scatter_dims_to_operand_dims[i]`. :
+: : : It has to be one-to-one and :
+: : : total. :
+
+If `index_vector_dim` is equal to `scatter_indices.rank`, we implicitly
+consider `scatter_indices` to have a trailing `1` dimension.
+
+We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
+dimensions in `updates` shape that are not in `update_window_dims`, in ascending
+order.
+
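+For example, with `updates.rank` = `9` and `update_window_dims` = {`4`, `5`,
+`6`, `7`, `8`}, `update_scatter_dims` is {`0`, `1`, `2`, `3`}. A minimal sketch
+of this computation, assuming plain Python lists (illustrative only, not the
+actual implementation):
+
+```
+def update_scatter_dims(updates_rank, update_window_dims):
+  # Dimensions of `updates` that are not window dimensions, ascending.
+  window = set(update_window_dims)
+  return [d for d in range(updates_rank) if d not in window]
+
+assert update_scatter_dims(9, [4, 5, 6, 7, 8]) == [0, 1, 2, 3]
+```
+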
+The arguments of scatter must satisfy the following constraints:
+
+ - `updates` tensor must be of rank `update_window_dims.size +
+   scatter_indices.rank - 1` (see the sketch after this list).
+
+ - Bounds of dimension `i` in `updates` must conform to the following:
+ - If `i` is present in `update_window_dims` (i.e. equal to
+ `update_window_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must not exceed the corresponding bound of `operand`
+ after accounting for the `inserted_window_dims` (i.e.
+ `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
+ the bounds of `operand` with the bounds at indices
+ `inserted_window_dims` removed).
+ - If `i` is present in `update_scatter_dims` (i.e. equal to
+ `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must be equal to the corresponding bound of
+ `scatter_indices`, skipping `index_vector_dim` (i.e.
+ `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
+ `scatter_indices.shape.dims`[`k+1`] otherwise).
+
+ - `update_window_dims` must be in ascending order, not have any repeating
+ dimension numbers, and be in the range `[0, updates.rank)`.
+
+ - `inserted_window_dims` must be in ascending order, not have any
+ repeating dimension numbers, and be in the range `[0, operand.rank)`.
+
+ - `scatter_dims_to_operand_dims.size` must be equal to the bound of dimension
+   `index_vector_dim` of `scatter_indices` (i.e.
+   `scatter_indices.shape.dims`[`index_vector_dim`]), and its values must be
+   in the range `[0, operand.rank)` and must not repeat.
+
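+For example, in the `TfBatchDynamicUpdateSlice` case, `operand` has rank `5`,
+`scatter_indices` has shape `[10,9,8,7,5]` with `index_vector_dim` = `4`, and
+`update_window_dims.size` = `5`, so `updates` must have rank `5 + 5 - 1 = 9`.
+A hedged sketch of this rank check (plain Python, not the actual
+implementation):
+
+```
+def check_updates_rank(update_window_dims, scatter_indices_rank,
+                       index_vector_dim, updates_rank):
+  # Account for the implicit trailing 1 dimension described above.
+  if index_vector_dim == scatter_indices_rank:
+    scatter_indices_rank += 1
+  expected = len(update_window_dims) + scatter_indices_rank - 1
+  if updates_rank != expected:
+    raise ValueError("Updates tensor must be of rank %d; got %d."
+                     % (expected, updates_rank))
+
+check_updates_rank([4, 5, 6, 7, 8], 5, 4, 9)  # OK; rank 8 would be rejected.
+```
+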
+For a given index `U` in the `updates` tensor, the corresponding index `I` in
+the `operand` tensor into which this update has to be applied is computed as
+follows:
+
+ 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
+ an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
+    `scatter_indices`[Combine(`G`, `i`)] where Combine(`A`, `b`) inserts `b` at
+    position `index_vector_dim` into `A`.
+ 2. Create an index `S`<sub>`in`</sub> into `operand` by scattering `S` using
+    the `scatter_dims_to_operand_dims` map. More formally:
+ 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
+ `k` < `scatter_dims_to_operand_dims.size`.
+ 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
+ at `update_window_dims` in `U` according to `inserted_window_dims`.
+ More formally:
+ 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
+ `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
+ is the monotonic function with domain [`0`, `update_window_dims.size`)
+ and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
+ example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
+ and `inserted_window_dims` is {`0`, `2`} then
+ `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
+ `3`→`5`}).
+ 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+    addition. (These four steps are mirrored in the sketch below.)
+
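+The following sketch mirrors these four steps for a single index `U` into
+`updates`, using plain Python. It assumes `scatter_indices` is a NumPy-style
+array supporting tuple indexing; all helper names are illustrative, not part
+of the XLA API:
+
+```
+def operand_index(U, scatter_indices, operand_rank, update_window_dims,
+                  inserted_window_dims, scatter_dims_to_operand_dims,
+                  index_vector_dim, update_scatter_dims):
+  # Step 1: gather G from U and use it to look up the index vector S.
+  G = [U[k] for k in update_scatter_dims]
+  S = [scatter_indices[tuple(G[:index_vector_dim] + [i] + G[index_vector_dim:])]
+       for i in range(len(scatter_dims_to_operand_dims))]
+  # Step 2: scatter S into the operand index space.
+  S_in = [0] * operand_rank
+  for k, operand_dim in enumerate(scatter_dims_to_operand_dims):
+    S_in[operand_dim] = S[k]
+  # Step 3: scatter the window coordinates of U into the operand index space.
+  # window_dims_to_operand_dims enumerates the operand dimensions not in
+  # inserted_window_dims, in ascending order.
+  window_dims_to_operand_dims = [d for d in range(operand_rank)
+                                 if d not in inserted_window_dims]
+  W_in = [0] * operand_rank
+  for k, updates_dim in enumerate(update_window_dims):
+    W_in[window_dims_to_operand_dims[k]] = U[updates_dim]
+  # Step 4: element-wise addition.
+  return [w + s for w, s in zip(W_in, S_in)]
+```
+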
+In summary, the scatter operation can be defined as follows; an executable
+sketch follows the list.
+
+ - Initialize `output` with `operand`, i.e. for all indices `O` in the
+ `operand` tensor:\
+ `output`[`O`] = `operand`[`O`]
+ - For every index `U` in the `updates` tensor and the corresponding index `O`
+ in the `operand` tensor:\
+ `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
+
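+A minimal end-to-end sketch of these semantics, reusing the `operand_index`
+helper above (illustrative NumPy code under the stated assumptions, not the
+production implementation):
+
+```
+import numpy as np
+
+def reference_scatter(operand, scatter_indices, updates, update_computation,
+                      update_window_dims, inserted_window_dims,
+                      scatter_dims_to_operand_dims, index_vector_dim):
+  if index_vector_dim == scatter_indices.ndim:  # implicit trailing 1 dimension
+    scatter_indices = scatter_indices[..., np.newaxis]
+  update_scatter_dims = [d for d in range(updates.ndim)
+                         if d not in update_window_dims]
+  output = np.copy(operand)
+  for U in np.ndindex(*updates.shape):
+    I = operand_index(list(U), scatter_indices, operand.ndim,
+                      update_window_dims, inserted_window_dims,
+                      scatter_dims_to_operand_dims, index_vector_dim,
+                      update_scatter_dims)
+    output[tuple(I)] = update_computation(output[tuple(I)], updates[U])
+  return output
+
+# TF-style scatter-add of whole rows into a [64, 48] matrix:
+out = reference_scatter(np.zeros([64, 48]), np.arange(32), np.ones([32, 48]),
+                        lambda old, update: old + update,
+                        update_window_dims=[1], inserted_window_dims=[0],
+                        scatter_dims_to_operand_dims=[0], index_vector_dim=1)
+```
+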
+The order in which updates are applied is non-deterministic. So, when multiple
+indices in `updates` refer to the same index in `operand`, the corresponding
+value in `output` will be non-deterministic.
+
+Note that the first parameter passed to the `update_computation` is always the
+current value from the `output` tensor and the second parameter is always the
+value from the `updates` tensor. This is important specifically for cases when
+the `update_computation` is _not commutative_, for example subtraction.
+
+Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
+the scatter op updates the elements in the input that are extracted by the
+corresponding gather op.
+
+For a detailed informal description and examples, refer to the
+"Informal Description" section under `Gather`.
+
## Select
See also