commit e3bcc0aa6e52867d9a12d9efded921325ecc5966 (patch)
Author date:  2018-08-01 21:44:02 -0700
Commit date:  2018-08-01 21:48:45 -0700
tree:   c4fe57f956fe9334f1e27342a4846250e915fe4a
parent: 3379bae787d73d6db67d66a284bd1a076b2cbdba (diff)
[XLA] Add Scatter HLO.
PiperOrigin-RevId: 207045468
27 files changed, 1255 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 53be5a79c2..bea3fa9a96 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1635,6 +1635,32 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& gather_indices, }); } +XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + + TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input)); + TF_ASSIGN_OR_RETURN(const Shape& scatter_indices_shape, + GetShape(scatter_indices)); + TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates)); + TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape, + update_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN(*instr.mutable_shape(), + ShapeInference::InferScatterShape( + input_shape, scatter_indices_shape, updates_shape, + to_apply_shape, dimension_numbers)); + + *instr.mutable_scatter_dimension_numbers() = dimension_numbers; + + AddCalledComputation(update_computation, &instr); + return AddInstruction(std::move(instr), HloOpcode::kScatter, + {input, scatter_indices, updates}); + }); +} + XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand, const XlaComputation& true_computation, const XlaOp& false_operand, @@ -2803,6 +2829,13 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, window_bounds); } +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers) { + return input.builder()->Scatter(input, scatter_indices, updates, + update_computation, dimension_numbers); +} + void Send(const XlaOp& operand, const ChannelHandle& handle) { return 
operand.builder()->Send(operand, handle); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index ae331407d6..8726cc6f93 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -857,6 +857,11 @@ class XlaBuilder { const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); + // Enqueues a Scatter node onto the computation. + XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + // Enqueues a Send node onto the computation for device-to-device // communication, to send the given operand to a Recv instruction that shares // the same channel handle. @@ -1296,6 +1301,10 @@ class XlaBuilder { friend XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); + friend XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, + const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); friend void Send(const XlaOp& operand, const ChannelHandle& handle); friend XlaOp Recv(XlaBuilder* builder, const Shape& shape, const ChannelHandle& handle); @@ -1977,6 +1986,11 @@ XlaOp Gather(const XlaOp& input, const XlaOp& gather_indices, const GatherDimensionNumbers& dimension_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); +// Enqueues a Scatter node onto the computation. +XlaOp Scatter(const XlaOp& input, const XlaOp& scatter_indices, + const XlaOp& updates, const XlaComputation& update_computation, + const ScatterDimensionNumbers& dimension_numbers); + // Enqueues a Send node onto the computation for device-to-device // communication. 
This operation sends the given operand to // a Recv instruction in a different computation that shares the same channel diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 60f9cd1121..ca645d3f1d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1793,6 +1793,10 @@ Status IrEmitter::HandleSendDone(HloInstruction* send_done) { return Unimplemented("Send-done is not implemented on CPU."); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on CPUs."); +} + Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 372017441f..c9a1dab62d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -150,6 +150,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleConditional(HloInstruction* conditional) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* gen_token) override; Status HandleIota(HloInstruction* iota) override; Status HandleRng(HloInstruction* rng) override; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 097fa23027..9f86749125 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -233,6 +233,7 @@ class DfsHloVisitorBase { virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; virtual Status HandleGather(HloInstructionPtr hlo) = 
0; + virtual Status HandleScatter(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index f4316e0fb7..ae8a066d62 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -194,6 +194,9 @@ class DfsHloVisitorWithDefaultBase Status HandleGather(HloInstructionPtr gather) override { return DefaultAction(gather); } + Status HandleScatter(HloInstructionPtr scatter) override { + return DefaultAction(scatter); + } Status HandleAfterAll(HloInstructionPtr token) override { return DefaultAction(token); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 1295e83c0c..290e2f73dc 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -125,6 +125,10 @@ Status IrEmitter::HandleRecvDone(HloInstruction*) { return Unimplemented("Recv-done is not implemented on GPU"); } +Status IrEmitter::HandleScatter(HloInstruction*) { + return Unimplemented("Scatter is not implemented on GPUs."); +} + Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector<llvm::Value*> base_ptrs; for (const HloInstruction* operand : tuple->operands()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 80e2a203ac..561c683879 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -86,6 +86,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleParameter(HloInstruction* parameter) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleTuple(HloInstruction* tuple) override; + Status HandleScatter(HloInstruction* scatter) override; Status 
HandleSelect(HloInstruction* select) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleFusion(HloInstruction* fusion) override; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 63a8a813cd..0b93d97c11 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -160,6 +160,8 @@ message HloInstructionProto { // present for Send and Recv instructions and their SendDone and RecvDone // partners. bool is_host_transfer = 47; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 1f672502f7..7806d432ce 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -648,6 +648,11 @@ Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { return Status::OK(); } +Status HloCostAnalysis::HandleScatter(const HloInstruction* scatter) { + // TODO(b/32945756): Compute the properties of the sub-computation. 
+ return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 82d650dc7b..d93759a48a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -104,6 +104,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; Status HandleGather(const HloInstruction* gather) override; + Status HandleScatter(const HloInstruction* scatter) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index fd5085bed2..7e5866a356 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1019,6 +1019,8 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kWhite; } return kGreen; + case HloOpcode::kScatter: + // Do not de-emphasize Scatter, since it involves significant work. case HloOpcode::kCopy: // Emphasize copy nodes, which are either physical transposes (and thus // significant), or copies of read-only buffers (and thus dead weight). 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 8b9bdd2f46..402b725bda 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -404,6 +404,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( *gather_dimension_numbers, gather_window_bounds); break; } + case HloOpcode::kScatter: { + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "Scatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_scatter_dimension_numbers()) + << "Scatter instruction should have ScatterDimensionNumbers set."; + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Scatter instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); + auto scatter_dimension_numbers = MakeUnique<ScatterDimensionNumbers>( + proto.scatter_dimension_numbers()); + instruction = + CreateScatter(proto.shape(), operands(0), operands(1), operands(2), + computations(0), *scatter_dimension_numbers); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -1062,6 +1078,16 @@ bool HloInstruction::HasSideEffect() const { gather_dim_numbers, window_bounds); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) { + return MakeUnique<HloScatterInstruction>(shape, operand, scatter_indices, + updates, update_computation, + scatter_dim_numbers); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( const Shape& shape, HloInstruction* operand, std::unique_ptr<DomainMetadata> operand_side_metadata, @@ -1124,6 +1150,7 @@ 
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kDynamicSlice: case HloOpcode::kSort: case HloOpcode::kGather: + case HloOpcode::kScatter: case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; @@ -1587,6 +1614,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kPad: case HloOpcode::kDynamicSlice: case HloOpcode::kGather: + case HloOpcode::kScatter: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -1693,6 +1721,7 @@ HloComputation* HloInstruction::to_apply() const { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); return called_computations_[0]; default: @@ -1711,6 +1740,7 @@ void HloInstruction::set_to_apply(HloComputation* computation) { case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: CHECK_EQ(called_computations_.size(), 1); called_computations_[0] = computation; break; @@ -1977,7 +2007,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString( } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || opcode() == HloOpcode::kReduceWindow || opcode() == HloOpcode::kReduce || - opcode() == HloOpcode::kCrossReplicaSum) { + opcode() == HloOpcode::kCrossReplicaSum || + opcode() == HloOpcode::kScatter) { extra.push_back( StrCat("to_apply=", PrintName(to_apply()->name(), options))); } else if (!called_computations().empty()) { @@ -2013,6 +2044,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString( case HloOpcode::kReduceWindow: case HloOpcode::kReduce: case HloOpcode::kCrossReplicaSum: + case HloOpcode::kScatter: extra.push_back( StrCat("to_apply=\n", to_apply()->ToString(new_options))); break; @@ -2311,6 +2343,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleSendDone(this); case 
HloOpcode::kGather: return visitor->HandleGather(this); + case HloOpcode::kScatter: + return visitor->HandleScatter(this); case HloOpcode::kDomain: return visitor->HandleDomain(this); case HloOpcode::kAfterAll: @@ -3171,4 +3205,9 @@ tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds() return Cast<HloGatherInstruction>(this)->gather_window_bounds(); } +const ScatterDimensionNumbers& HloInstruction::scatter_dimension_numbers() + const { + return Cast<HloScatterInstruction>(this)->scatter_dimension_numbers(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 0e3130a05c..d2dce5aecb 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -644,6 +644,12 @@ class HloInstruction { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); + static std::unique_ptr<HloInstruction> CreateScatter( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + // Creates a kDomain instruction which delimits an HLO domain which have // the provided user and operand side metadata. static std::unique_ptr<HloInstruction> CreateDomain( @@ -1452,6 +1458,9 @@ class HloInstruction { // Delegates to HloGatherInstruction::gather_window_bounds. tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const; + // Delegates to HloScatterInstruction::scatter_dimension_numbers(). + const ScatterDimensionNumbers& scatter_dimension_numbers() const; + // Old methods kept for smooth subclassing transition END. 
protected: diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index b75a2bd34b..8a694dde80 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1425,6 +1425,55 @@ TEST_F(HloInstructionTest, StringifyGather_1) { "index_vector_dim=2, window_bounds={30,29,28,27,26}"); } +TEST_F(HloInstructionTest, StringifyScatter) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape scatter_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape scatter_updates_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Scatter"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* scatter_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, scatter_indices_tensor_shape, "scatter_indices")); + HloInstruction* scatter_updates = + builder.AddInstruction(HloInstruction::CreateParameter( + 2, scatter_updates_shape, "scatter_updates")); + + HloComputation::Builder update_builder("Scatter.update"); + update_builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p1")); + update_builder.AddInstruction( + HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {}), "p2")); + + auto module = CreateNewModule(); + auto* update_computation = + module->AddEmbeddedComputation(update_builder.Build()); + + HloInstruction* scatter_instruction = + builder.AddInstruction(HloInstruction::CreateScatter( + input_tensor_shape, input, scatter_indices, scatter_updates, + update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + 
module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + scatter_instruction->ToString(), + "%scatter = f32[50,49,48,47,46]{4,3,2,1,0} " + "scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %scatter_indices, " + "f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %scatter_updates), " + "update_window_dims={4,5,6,7,8}, inserted_window_dims={}, " + "scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=2, " + "to_apply=%Scatter.update"); +} + TEST_F(HloInstructionTest, CanonnicalStringificationFusion) { // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index df26a2c744..a571fd574e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -2015,4 +2015,91 @@ std::unique_ptr<HloInstruction> HloGatherInstruction::CloneWithNewOperandsImpl( gather_window_bounds()); } +HloScatterInstruction::HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers) + : HloInstruction(HloOpcode::kScatter, shape) { + AppendOperand(operand); + AppendOperand(scatter_indices); + AppendOperand(updates); + AppendComputation(update_computation); + scatter_dimension_numbers_ = + MakeUnique<ScatterDimensionNumbers>(scatter_dim_numbers); +} + +string HloScatterInstruction::ScatterDimensionNumbersToString() const { + string update_window_dims = + StrCat("update_window_dims={", + Join(scatter_dimension_numbers().update_window_dims(), ","), "}"); + string inserted_window_dims = StrCat( + "inserted_window_dims={", + Join(scatter_dimension_numbers().inserted_window_dims(), ","), "}"); + string scatter_dims_to_operand_dims = StrCat( + 
"scatter_dims_to_operand_dims={", + Join(scatter_dimension_numbers().scatter_dims_to_operand_dims(), ","), + "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", scatter_dimension_numbers().index_vector_dim()); + + return Join<std::initializer_list<string>>( + {update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + +/* static */ ScatterDimensionNumbers +HloScatterInstruction::MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice<int64> update_window_dims, + tensorflow::gtl::ArraySlice<int64> inserted_window_dims, + tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims, + int64 index_vector_dim) { + ScatterDimensionNumbers scatter_dim_numbers; + for (int64 update_window_dim : update_window_dims) { + scatter_dim_numbers.add_update_window_dims(update_window_dim); + } + for (int64 inserted_window_dim : inserted_window_dims) { + scatter_dim_numbers.add_inserted_window_dims(inserted_window_dim); + } + for (int64 scatter_dim_to_operand_dim : scatter_dims_to_operand_dims) { + scatter_dim_numbers.add_scatter_dims_to_operand_dims( + scatter_dim_to_operand_dim); + } + scatter_dim_numbers.set_index_vector_dim(index_vector_dim); + return scatter_dim_numbers; +} + +HloInstructionProto HloScatterInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_scatter_dimension_numbers() = scatter_dimension_numbers(); + return proto; +} + +std::vector<string> HloScatterInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return {ScatterDimensionNumbersToString()}; +} + +bool HloScatterInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const { + const auto& casted_other = static_cast<const HloScatterInstruction&>(other); + return protobuf_util::ProtobufEquals( + scatter_dimension_numbers(), + casted_other.scatter_dimension_numbers()) 
&& + eq_computations(to_apply(), casted_other.to_apply()); +} + +std::unique_ptr<HloInstruction> HloScatterInstruction::CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 3); + return MakeUnique<HloScatterInstruction>( + shape, new_operands[0], new_operands[1], new_operands[2], to_apply(), + scatter_dimension_numbers()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 132e767420..3797bef600 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1198,6 +1198,45 @@ class HloGatherInstruction : public HloInstruction { std::vector<int64> gather_window_bounds_; }; +class HloScatterInstruction : public HloInstruction { + public: + explicit HloScatterInstruction( + const Shape& shape, HloInstruction* operand, + HloInstruction* scatter_indices, HloInstruction* updates, + HloComputation* update_computation, + const ScatterDimensionNumbers& scatter_dim_numbers); + const ScatterDimensionNumbers& scatter_dimension_numbers() const { + CHECK(scatter_dimension_numbers_ != nullptr); + return *scatter_dimension_numbers_; + } + // Returns the dump string of the scatter dimension numbers. + string ScatterDimensionNumbersToString() const; + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + // Creates an instance of ScatterDimensionNumbers. 
+ static ScatterDimensionNumbers MakeScatterDimNumbers( + tensorflow::gtl::ArraySlice<int64> update_window_dims, + tensorflow::gtl::ArraySlice<int64> inserted_window_dims, + tensorflow::gtl::ArraySlice<int64> scatter_dims_to_operand_dims, + int64 index_vector_dim); + + private: + std::vector<string> ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function<bool(const HloComputation*, const HloComputation*)>& + eq_computations) const override; + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( + const Shape& shape, + tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, + HloCloneContext* context) const override; + + std::unique_ptr<ScatterDimensionNumbers> scatter_dimension_numbers_; +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTIONS_H_ diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 59e9a5a94a..88531b6f20 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -118,6 +118,7 @@ namespace xla { V(kReverse, "reverse") \ V(kRng, "rng") \ V(kRoundNearestAfz, "round-nearest-afz") \ + V(kScatter, "scatter") \ V(kSelect, "select") \ V(kSelectAndScatter, "select-and-scatter") \ V(kSend, "send") \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 6ed6f74c55..3efa264259 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1242,6 +1242,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, dim_numbers, *window_bounds)); break; } + case HloOpcode::kScatter: { + // TODO(b/32945756): Implement HLO parsing for Scatter. 
+ return TokenError("HLO parsing is not implemented for Scatter."); + } case HloOpcode::kDomain: { DomainData domain; attrs["domain"] = {/*required=*/true, AttrTy::kDomain, &domain}; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 25fa319faf..e4a5cd3af1 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -510,6 +510,15 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { gather->gather_dimension_numbers(), gather->gather_window_bounds())); } +Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + return CheckShape( + scatter, ShapeInference::InferScatterShape( + scatter->operand(0)->shape(), scatter->operand(1)->shape(), + scatter->operand(2)->shape(), + scatter->to_apply()->ComputeProgramShape(), + scatter->scatter_dimension_numbers())); +} + Status ShapeVerifier::HandleAfterAll(HloInstruction* token) { std::vector<const Shape*> operand_shapes; for (const HloInstruction* operand : token->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 79f7aa9f4c..7feddaeabf 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -83,6 +83,7 @@ class ShapeVerifier : public DfsHloVisitor { HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleGather(HloInstruction* gather) override; + Status HandleScatter(HloInstruction* scatter) override; Status HandleAfterAll(HloInstruction* token) override; Status FinishVisit(HloInstruction*) override { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index af07370135..e2191aedb7 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ 
b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -141,6 +141,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReduceWindow: case HloOpcode::kRemainder: case HloOpcode::kRng: + case HloOpcode::kScatter: case HloOpcode::kSelectAndScatter: case HloOpcode::kSend: case HloOpcode::kSendDone: diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 35df792b07..20314ca482 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2568,4 +2568,192 @@ static Status ValidateGatherDimensionNumbers( return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); } +namespace { + +Status ValidateScatterDimensionNumbers( + const Shape& operand_shape, + tensorflow::gtl::ArraySlice<int64> scatter_indices_shape, + const Shape& updates_shape, const ScatterDimensionNumbers& dim_numbers) { + // Validate update_window_dims in ScatterDimensionNumbers. + if (!c_is_sorted(dim_numbers.update_window_dims())) { + return InvalidArgument( + "update_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.update_window_dims()) != + dim_numbers.update_window_dims().end()) { + return InvalidArgument( + "update_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.update_window_dims(), ", ").c_str()); + } + const int64 updates_rank = ShapeUtil::Rank(updates_shape); + for (int64 window_dim : dim_numbers.update_window_dims()) { + if (window_dim < 0 || window_dim >= updates_rank) { + return InvalidArgument( + "Invalid update_window_dims set in scatter op; valid range is [0, " + "%lld). got: %lld.", + updates_rank, window_dim); + } + } + + // Validate inserted_window_dims in ScatterDimensionNumbers. 
+ if (!c_is_sorted(dim_numbers.inserted_window_dims())) { + return InvalidArgument( + "inserted_window_dims in scatter op must be sorted; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + if (c_adjacent_find(dim_numbers.inserted_window_dims()) != + dim_numbers.inserted_window_dims().end()) { + return InvalidArgument( + "inserted_window_dims in scatter op must not repeat; got: %s.", + Join(dim_numbers.inserted_window_dims(), ", ").c_str()); + } + for (int64 inserted_dim : dim_numbers.inserted_window_dims()) { + if (inserted_dim < 0 || inserted_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid inserted_window_dims set in scatter op; valid range is [0, " + "%d), got: %lld.", + operand_shape.dimensions_size(), inserted_dim); + } + } + + // Validate scatter_dims_to_operand_dims in ScatterDimensionNumbers. + if (dim_numbers.scatter_dims_to_operand_dims_size() != + scatter_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Scatter op has %d elements in scatter_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of scatter_indices is %lld. 
" + "These two numbers must be equal.", + dim_numbers.scatter_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + scatter_indices_shape[dim_numbers.index_vector_dim()]); + } + for (int i = 0; i < dim_numbers.scatter_dims_to_operand_dims_size(); ++i) { + int64 scatter_dim_to_operand_dim = + dim_numbers.scatter_dims_to_operand_dims(i); + if (scatter_dim_to_operand_dim < 0 || + scatter_dim_to_operand_dim >= operand_shape.dimensions_size()) { + return InvalidArgument( + "Invalid scatter_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld.", + operand_shape.dimensions_size(), i, scatter_dim_to_operand_dim); + } + } + std::vector<int64> sorted_scatter_dims_to_operand_dims( + dim_numbers.scatter_dims_to_operand_dims().begin(), + dim_numbers.scatter_dims_to_operand_dims().end()); + c_sort(sorted_scatter_dims_to_operand_dims); + if (c_adjacent_find(sorted_scatter_dims_to_operand_dims) != + sorted_scatter_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims; " + "got: %s.", + Join(dim_numbers.scatter_dims_to_operand_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +} // namespace + +/*static*/ StatusOr<Shape> ShapeInference::InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers) { + TF_RETURN_IF_ERROR( + ExpectArray(operand_shape, "operand tensor of scatter op")); + TF_RETURN_IF_ERROR( + ExpectArray(scatter_indices_shape, "scatter indices of scatter op")); + TF_RETURN_IF_ERROR(ExpectArray(updates_shape, "updates of scatter op")); + + if (!ShapeUtil::ElementIsIntegral(scatter_indices_shape)) { + return InvalidArgument( + "Scatter indices parameter must be an integral tensor; got %s.", + ShapeUtil::HumanString(scatter_indices_shape).c_str()); + } + + if (scatter_indices_shape.dimensions_size() < + 
scatter_dim_numbers.index_vector_dim() || + scatter_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Scatter index leaf dimension must be within [0, rank(scatter_indices)" + " + 1). rank(scatter_indices) is %d and scatter index leaf dimension " + "is %lld.", + scatter_indices_shape.dimensions_size(), + scatter_dim_numbers.index_vector_dim()); + } + + // Check if the update computation has a proper shape as a reduction. + TF_RETURN_IF_ERROR(VerifyReducerShape( + to_apply_shape, ShapeUtil::MakeShape(operand_shape.element_type(), {}), + updates_shape.element_type())); + + std::vector<int64> expanded_scatter_indices_shape = + ArraySliceToVector(AsInt64Slice(scatter_indices_shape.dimensions())); + if (expanded_scatter_indices_shape.size() == + scatter_dim_numbers.index_vector_dim()) { + expanded_scatter_indices_shape.push_back(1); + } + + int64 expected_updates_rank = expanded_scatter_indices_shape.size() - 1 + + scatter_dim_numbers.update_window_dims_size(); + if (ShapeUtil::Rank(updates_shape) != expected_updates_rank) { + return InvalidArgument("Updates tensor must be of rank %lld; got %lld.", + expected_updates_rank, + ShapeUtil::Rank(updates_shape)); + } + + TF_RETURN_IF_ERROR(ValidateScatterDimensionNumbers( + operand_shape, expanded_scatter_indices_shape, updates_shape, + scatter_dim_numbers)); + + int64 inserted_dims_seen = 0; + std::vector<int64> max_update_window_bounds; + for (int i = 0; i < operand_shape.dimensions_size(); ++i) { + if (inserted_dims_seen < scatter_dim_numbers.inserted_window_dims_size() && + scatter_dim_numbers.inserted_window_dims(inserted_dims_seen) == i) { + ++inserted_dims_seen; + } else { + max_update_window_bounds.push_back(operand_shape.dimensions(i)); + } + } + for (int i = 0; i < scatter_dim_numbers.update_window_dims_size(); ++i) { + auto update_window_dim = scatter_dim_numbers.update_window_dims(i); + if (updates_shape.dimensions(update_window_dim) > + max_update_window_bounds[i]) { + return InvalidArgument( + 
"Bounds of the window dimensions of updates must not exceed the " + "bounds of the corresponding dimensions of operand. For dimension " + "%lld, updates bound is %lld, operand bound is %lld.", + update_window_dim, updates_shape.dimensions(update_window_dim), + max_update_window_bounds[i]); + } + } + + int64 scatter_dims_seen = 0; + for (int64 i = 0; i < ShapeUtil::Rank(updates_shape); ++i) { + bool is_update_window_dim = + c_binary_search(scatter_dim_numbers.update_window_dims(), i); + if (is_update_window_dim) { + continue; + } + if (scatter_dims_seen == scatter_dim_numbers.index_vector_dim()) { + ++scatter_dims_seen; + } + if (updates_shape.dimensions(i) != + expanded_scatter_indices_shape[scatter_dims_seen]) { + return InvalidArgument( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices. For " + "scatter dimension %lld, updates bound is %lld, scatter_indices " + "bound is %lld.", + i, updates_shape.dimensions(i), + expanded_scatter_indices_shape[scatter_dims_seen]); + } + ++scatter_dims_seen; + } + + return operand_shape; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 1a5684e3c3..6adea7bc1f 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -268,6 +268,14 @@ class ShapeInference { const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds); + // Helper that validates the given input shape, scatter indices shape, updates + // shape, and scatter dimension numbers that constitute a scatter operation, + // and returns the result shape of the scatter operation. 
+ static StatusOr<Shape> InferScatterShape( + const Shape& operand_shape, const Shape& scatter_indices_shape, + const Shape& updates_shape, const ProgramShape& to_apply_shape, + const ScatterDimensionNumbers& scatter_dim_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 6046d50c6d..511d2c22f8 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1536,7 +1536,7 @@ TEST_F(ShapeInferenceTest, BadSort) { << statusor.status(); } -class GatherShapeInferenceTest : public ShapeInferenceTest { +class ScatterGatherShapeInferenceTest : public ShapeInferenceTest { protected: const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); @@ -1553,9 +1553,13 @@ class GatherShapeInferenceTest : public ShapeInferenceTest { ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); + const ProgramShape to_apply_ = + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_); }; -TEST_F(GatherShapeInferenceTest, TensorFlowGather) { +// Shape inference tests for Gather. 
+ +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, @@ -1570,7 +1574,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGather) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_vector_32_, @@ -1585,7 +1589,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) { TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, ShapeInference::InferGatherShape( matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, @@ -1600,7 +1604,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { +TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1617,7 +1621,7 @@ TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1635,7 +1639,7 @@ TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, ShapeInference::InferGatherShape( @@ -1653,7 +1657,7 @@ TEST_F(GatherShapeInferenceTest, 
NonDefaultGatherIndicesLeafDim_B) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { +TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) { // This is equivalent to a dynamic slice. TF_ASSERT_OK_AND_ASSIGN( Shape gather_shape, @@ -1671,7 +1675,7 @@ TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { +TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) { // The gather indices "tensor" is a scalar S here that's used to slice out // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, @@ -1689,7 +1693,7 @@ TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { << ShapeUtil::HumanString(gather_shape); } -TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( tuple_shape_, s64_vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1704,7 +1708,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, tuple_shape_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1719,7 +1723,7 @@ TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( s64_vector_32_, vector_32_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1734,7 +1738,7 @@ TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { << statusor.status(); } 
-TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1751,7 +1755,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowIndices) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1768,7 +1772,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1784,7 +1788,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1800,7 +1804,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1818,7 +1822,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1835,7 +1839,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, 
+TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1853,7 +1857,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1872,7 +1876,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1890,7 +1894,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1908,7 +1912,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1924,7 +1928,8 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { +TEST_F(ScatterGatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsTooLarge) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1940,7 +1945,7 @@ 
TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1958,7 +1963,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, +TEST_F(ScatterGatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, @@ -1975,7 +1980,7 @@ TEST_F(GatherShapeInferenceTest, << statusor.status(); } -TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { StatusOr<Shape> statusor = ShapeInference::InferGatherShape( f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, HloGatherInstruction::MakeGatherDimNumbers( @@ -1992,5 +1997,575 @@ TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { << statusor.status(); } +// Shape inference tests for Scatter. 
+ +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {64, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {10, 32}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, + ShapeUtil::MakeShape(F32, {32, 8}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + 
/*index_vector_dim=*/1))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndices) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates 
must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterWithUpdatesNotMatchingIndicesV2) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{1}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) { + TF_ASSERT_OK_AND_ASSIGN( + Shape 
scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Bounds of the window dimensions of updates must not exceed " + "the bounds of the corresponding dimensions of operand.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + TfScatterNdWithUpdatesNotMatchingIndices) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + 
/*update_window_dims=*/{4}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Bounds of the scatter dimensions of updates must be same as the " + "bounds of the corresponding dimensions of scatter indices.")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4))); + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}), + to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + 
/*update_window_dims=*/{4, 5, 6, 7, 8}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) { + // This is equivalent to a dynamic update slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3, 4}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) { + // The scalar indices "tensor" is a scalar S here that's used to update a + // [30,29,28,27] shaped tensor within the operand at position S. 
+ TF_ASSERT_OK_AND_ASSIGN( + Shape scatter_shape, + ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0, 1, 2, 3}, + /*inserted_window_dims=*/{0}, + /*scatter_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0))); + + EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_)) + << ShapeUtil::HumanString(scatter_shape); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for operand")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + ScatterWithTupleShapedScatterIndicesInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for scatter indices")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + 
ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected array argument for updates")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + s64_vector_32_, vector_32_, s64_vector_32_, to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{0}, + /*inserted_window_dims=*/{1}, + /*scatter_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/10)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Scatter index leaf dimension must be within [0, " + "rank(scatter_indices) + 1)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Updates tensor must be of rank 7; got 8.")) + << statusor.status(); 
+} + +TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) { + const ProgramShape invalid_update_computation = + ShapeUtil::MakeProgramShape({f32_}, f32_); + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), + invalid_update_computation, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Reduction function must take 2 parameters, but takes 1")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 8, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedUpdateWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 7}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + 
EXPECT_THAT(statusor.status().error_message(), + HasSubstr("update_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6, 7, 9}, + /*inserted_window_dims=*/{}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid update_window_dims set in scatter op; valid " + "range is [0, 9)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{2, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must be sorted")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedInsertedWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 1}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + 
ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("inserted_window_dims in scatter op must not repeat")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 5}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid inserted_window_dims set in scatter op; valid " + "range is [0, 5)")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of scatter_indices " + "is 5. 
These two numbers must be equal")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain " + "is [0, 5), got: 4->10")) + << statusor.status(); +} + +TEST_F(ScatterGatherShapeInferenceTest, + InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) { + StatusOr<Shape> statusor = ShapeInference::InferScatterShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_, + HloScatterInstruction::MakeScatterDimNumbers( + /*update_window_dims=*/{4, 5, 6}, + /*inserted_window_dims=*/{1, 2}, + /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3}, + /*index_vector_dim=*/4)); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in scatter_dims_to_operand_dims")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 0b300dc7b2..fd784e909c 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -447,6 +447,20 @@ message GatherDimensionNumbers { int64 index_vector_dim = 4; } +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. 
+message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + message ConvolutionDimensionNumbers { // The number of the dimension that represents batch in the input. int64 input_batch_dimension = 7; diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 5f7482f90f..3981aaaf75 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1801,6 +1801,138 @@ is implementation-defined. : : : limit of interval : | `shape` | `Shape` | Output shape of type T | +## Scatter + +The XLA scatter operation generates a result which is the value of the input +tensor `operand`, with several slices (at indices specified by +`scatter_indices`) updated with the values in `updates` using +`update_computation`. + +See also +[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_client/xla_builder.h). + +<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b> + +|Arguments | Type | Semantics | +|------------------|------------------------|----------------------------------| +|`operand` | `XlaOp` | Tensor to be scattered into. | +|`scatter_indices` | `XlaOp` | Tensor containing the starting | +: : : indices of the slices that must : +: : : be scattered to. : +|`updates` | `XlaOp` | Tensor containing the values that| +: : : must be used for scattering. 
: +|`update_computation`| `XlaComputation` | Computation to be used for | +: : : combining the existing values in : +: : : the input tensor and the updates : +: : : during scatter. This computation : +: : : should be of type `T, T -> T`. : +|`index_vector_dim`| `int64` | The dimension in | +: : : `scatter_indices` that contains : +: : : the starting indices. : +|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in | +: : : `updates` shape that are _window : +: : : dimensions_. : +|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ | +: : : that must be inserted into : +: : : `updates` shape. : +|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from | +: : : the scatter indices to the : +: : : operand index space. This array : +: : : is interpreted as mapping `i` to : +: : : `scatter_dims_to_operand_dims[i]`: +: : : . It has to be one-to-one and : +: : : total. : + +If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider +`scatter_indices` to have a trailing `1` dimension. + +We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of +dimensions in `updates` shape that are not in `update_window_dims`, in ascending +order. + +The arguments of scatter should follow these constraints: + + - `updates` tensor must be of rank `update_window_dims.size + + scatter_indices.rank - 1`. + + - Bounds of dimension `i` in `updates` must conform to the following: + - If `i` is present in `update_window_dims` (i.e. equal to + `update_window_dims`[`k`] for some `k`), then the bound of dimension + `i` in `updates` must not exceed the corresponding bound of `operand` + after accounting for the `inserted_window_dims` (i.e. + `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains + the bounds of `operand` with the bounds at indices + `inserted_window_dims` removed). + - If `i` is present in `update_scatter_dims` (i.e. 
equal to + `update_scatter_dims`[`k`] for some `k`), then the bound of dimension + `i` in `updates` must be equal to the corresponding bound of + `scatter_indices`, skipping `index_vector_dim` (i.e. + `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and + `scatter_indices.shape.dims`[`k+1`] otherwise). + + - `update_window_dims` must be in ascending order, not have any repeating + dimension numbers, and be in the range `[0, updates.rank)`. + + - `inserted_window_dims` must be in ascending order, not have any + repeating dimension numbers, and be in the range `[0, operand.rank)`. + + - `scatter_dims_to_operand_dims.size` must be equal to + `scatter_indices`[`index_vector_dim`], and its values must be in the range + `[0, operand.rank)`. + +For a given index `U` in the `updates` tensor, the corresponding index `I` in +the `operand` tensor into which this update has to be applied is computed as +follows: + + 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up + an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] = + `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at + positions `index_vector_dim` into A. + 2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering + `S` using the `scatter_dims_to_operand_dims` map. More formally: + 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if + `k` < `scatter_dims_to_operand_dims.size`. + 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise. + 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices + at `update_window_dims` in `U` according to `inserted_window_dims`. + More formally: + 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if + `k` < `update_window_dims.size`, where `window_dims_to_operand_dims` + is the monotonic function with domain [`0`, `update_window_dims.size`) + and range [`0`, `operand.rank`) \\ `inserted_window_dims`. 
(For + example, if `update_window_dims.size` is `4`, `operand.rank` is `6`, + and `inserted_window_dims` is {`0`, `2`} then + `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, + `3`→`5`}). + 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise. + 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise + addition. + +In summary, the scatter operation can be defined as follows. + + - Initialize `output` with `operand`, i.e. for all indices `O` in the + `operand` tensor:\ + `output`[`O`] = `operand`[`O`] + - For every index `U` in the `updates` tensor and the corresponding index `O` + in the `operand` tensor:\ + `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`]) + +The order in which updates are applied is non-deterministic. So, when multiple +indices in `updates` refer to the same index in `operand`, the corresponding +value in `output` will be non-deterministic. + +Note that the first parameter that is passed into the `update_computation` will +always be the current value from the `output` tensor and the second parameter +will always be the value from the `updates` tensor. This is important +specifically for cases when the `update_computation` is _not commutative_. + +Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e. +the scatter op updates the elements in the input that are extracted by the +corresponding gather op. + +For a detailed informal description and examples, refer to the +"Informal Description" section under `Gather`. + ## Select See also |