diff options
Diffstat (limited to 'tensorflow/compiler/xla/service')
22 files changed, 779 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 9f5f2f96b7..0ceb9ca9e0 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -145,7 +145,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 5b09e4931e..56723e7650 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -214,6 +214,7 @@ class DfsHloVisitorBase { virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; + virtual Status HandleGather(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index ffc4f3bb79..ecda5288ee 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSendDone(HloInstructionPtr send_done) override { return DefaultAction(send_done); } + Status HandleGather(HloInstructionPtr gather) override { + return DefaultAction(gather); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index aa2a0a9800..30c88c0a5d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2064,6 +2064,11 @@ GetHloBufferSlices(const HloInstruction* hlo, return slices; } +Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { + // TODO(b/72710576): Gather is not implemented on GPUs + return Unimplemented("Gather is not implemented on GPUs."); +} + std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { const BufferAssignment& buffer_assn = diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 688760efbd..b83a2337e2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,6 +67,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; + Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 36db711c6c..a43785b4a9 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -129,6 +129,10 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + + // Gather dimension numbers. + xla.GatherDimensionNumbers gather_dimension_numbers = 33; + repeated int64 gather_window_bounds = 34; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6a4651d83f..4ec2ef27bf 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -533,6 +533,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { return Status::OK(); } +Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather does not issue any flops. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index af52ea06ca..d17678d20f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -100,6 +100,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; + Status HandleGather(const HloInstruction* gather) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 9b0e2fd7d6..2861fec39e 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -940,6 +940,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0d925ad00d..b7dd055d7c 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1155,6 +1155,38 @@ bool HloInstruction::HasSideEffect() const { return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds) { + std::unique_ptr<HloInstruction> instruction = + WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(gather_indices); + instruction->gather_dimension_numbers_ = + MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); + return instruction; +} + +/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + return gather_dim_numbers; +} + std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> new_operands, @@ -1397,6 +1429,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( CHECK_EQ(new_operands.size(), 1); clone = CreateRecvDone(new_operands[0]); break; + case HloOpcode::kGather: + CHECK_EQ(new_operands.size(), 2); + clone = CreateGather(shape, new_operands[0], new_operands[1], + *gather_dimension_numbers_, gather_window_bounds_); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1740,6 +1777,11 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); + case HloOpcode::kGather: + return protobuf_util::ProtobufEquals(gather_dimension_numbers(), + other.gather_dimension_numbers()) && + gather_window_bounds() == other.gather_window_bounds(); + // FFT has various types & lengths. case HloOpcode::kFft: return fft_type() == other.fft_type() && @@ -2171,6 +2213,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } + if (gather_dimension_numbers_ != nullptr) { + extra.push_back(GatherDimensionNumbersToString()); + extra.push_back( + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); + } if (opcode() == HloOpcode::kFft) { extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); @@ -2302,6 +2349,14 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + if (opcode() == HloOpcode::kGather) { + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2618,6 +2673,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleSend(this); case HloOpcode::kSendDone: return visitor->HandleSendDone(this); + case HloOpcode::kGather: + return visitor->HandleGather(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3301,6 +3358,23 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +string HloInstruction::GatherDimensionNumbersToString() const { + CHECK_NE(gather_dimension_numbers_.get(), nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + + return Join<std::initializer_list<string>>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims}, + ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index e898a83739..1762d227be 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -451,6 +451,12 @@ class HloInstruction { HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation); + static std::unique_ptr<HloInstruction> CreateGather( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -492,6 +498,12 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions); + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice<int64> output_window_dims, + tensorflow::gtl::ArraySlice<int64> elided_window_dims, + tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -1102,6 +1114,19 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + + tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const { + CHECK_EQ(opcode(), HloOpcode::kGather); + return gather_window_bounds_; + } + + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1366,6 +1391,9 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_; + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_; + std::vector<int64> gather_window_bounds_; + // Describes FFT type for an FFT instruction. FftType fft_type_ = FftType::FFT; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 94e9bfe56e..32d3ed272b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1271,5 +1271,40 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, StringifyGather) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "window_bounds={30,29,28,27,26}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 088dd15dbf..af24604c39 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -76,6 +76,7 @@ namespace xla { V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index f3378309c2..b1fd068115 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -424,6 +424,14 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace +Status ShapeVerifier::HandleGather(HloInstruction* gather) { + return CheckShape( + gather, + ShapeInference::InferGatherShape( + gather->operand(0)->shape(), gather->operand(1)->shape(), + gather->gather_dimension_numbers(), gather->gather_window_bounds())); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, const Shape& inferred_shape) { // If allow_mixed_precision_ is false, check if there are operands with diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index f9f898c236..1dd7ec3c51 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -80,6 +80,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormInference( HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + Status HandleGather(HloInstruction* gather) override; Status FinishVisit(HloInstruction*) override { return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f08d809d79..f494748e17 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,7 @@ namespace xla { case HloOpcode::kExp: case HloOpcode::kFft: case HloOpcode::kFusion: + case HloOpcode::kGather: case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 95c853b5c4..e278eab690 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -1446,6 +1446,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kFftRequest: handle_status = computation->AddFftInstruction(arg->fft_request()); break; + case OpRequest::kGatherRequest: + handle_status = computation->AddGatherInstruction(arg->gather_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 004889b5f2..c9692757b2 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2448,4 +2448,197 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return to_apply.result(); } +static Status ValidateGatherDimensionNumbers( + const Shape& input_shape, + tensorflow::gtl::ArraySlice<int64> gather_indices_shape, + const GatherDimensionNumbers& dim_numbers) { + if (!c_is_sorted(dim_numbers.output_window_dims())) { + return InvalidArgument( + "Output window dimensions in gather op must be ascending; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.output_window_dims()) != + dim_numbers.output_window_dims().end()) { + return InvalidArgument( + "Output window dimensions in gather op must not repeat; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_shape_rank = + output_window_dim_count + gather_indices_shape.size(); + + for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { + int64 window_index = dim_numbers.output_window_dims(i); + if (window_index < 0 || window_index >= output_shape_rank) { + return InvalidArgument( + "Window index %d in gather op is out of bounds; got %lld, but should " + "have been in" + "[0,%lld)", + i, window_index, output_shape_rank); + } + } + + if (dim_numbers.gather_dims_to_operand_dims_size() != + gather_indices_shape.back()) { + return InvalidArgument( + "There must be exactly as many elements in gather_dims_to_operand_dims " + "as there are elements in the last dimension of %%gather_indices; got: " + "%d, expected %lld", + dim_numbers.gather_dims_to_operand_dims_size(), + gather_indices_shape.back()); + } + + for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { + int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); + if (gather_dim_to_input_dim < 0 || + gather_dim_to_input_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld", + input_shape.dimensions_size(), i, gather_dim_to_input_dim); + } + } + + std::vector<int64> sorted_gather_dims_to_operand_dims( + dim_numbers.gather_dims_to_operand_dims().begin(), + dim_numbers.gather_dims_to_operand_dims().end()); + + c_sort(sorted_gather_dims_to_operand_dims); + + if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != + sorted_gather_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "got: %s", + Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + } + + for (int64 elided_dim : dim_numbers.elided_window_dims()) { + if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid elided_window_dims set in gather op; valid range is [0, " + "%d), got: %lld", + input_shape.dimensions_size(), elided_dim); + } + } + + if (!c_is_sorted(dim_numbers.elided_window_dims())) { + return InvalidArgument( + "elided_window_dims in gather op must be sorted; got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.elided_window_dims()) != + dim_numbers.elided_window_dims().end()) { + return InvalidArgument( + "Repeated dimensions not allowed in elided_window_dims in gather op; " + "got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + gather_indices_shape, "gather indices operand of gather op")); + + if (gather_indices_shape.dimensions_size() < 1) { + return InvalidArgument( + "Gather indices parameter must at least of rank 1; got %s", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + return InvalidArgument( + "Gather indices parameter must be an integral tensor; got %s", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + std::vector<int64> expanded_gather_indices_shape; + // We implicitly reshape gather indices of shape P[N] to P[N,1]. + expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(expanded_gather_indices_shape)); + if (expanded_gather_indices_shape.size() == 1) { + expanded_gather_indices_shape.push_back(1); + } + + TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( + input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + + if (window_bounds.size() != input_shape.dimensions_size()) { + return InvalidArgument( + "Gather op must have one window bound for every input dimension; got: " + "len(window_bounds)=%lu, input_shape.rank=%d", + window_bounds.size(), input_shape.dimensions_size()); + } + + if (window_bounds.size() != + gather_dim_numbers.output_window_dims_size() + + gather_dim_numbers.elided_window_dims_size()) { + return InvalidArgument( + "All components of the window index in a gather op must either be a " + "output window index or explicitly elided; got len(window_bounds)=%lu, " + "output_window_bounds=%s, elided_window_bounds=%s", + window_bounds.size(), + Join(gather_dim_numbers.output_window_dims(), ",").c_str(), + Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + } + + for (int i = 0; i < window_bounds.size(); i++) { + int64 window_bound = window_bounds[i]; + int64 corresponding_input_bound = input_shape.dimensions(i); + if (window_bound < 0 || window_bound > corresponding_input_bound) { + return InvalidArgument( + "Window bound at index %d in gather op is out of range, must be " + "within " + "[0, %lld), got %lld", + i, corresponding_input_bound + 1, window_bound); + } + } + + for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { + if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + return InvalidArgument( + "Gather op can only elide window indices with bound 1, but bound is " + "%lld for index %lld at position %d", + window_bounds[gather_dim_numbers.elided_window_dims(i)], + gather_dim_numbers.elided_window_dims(i), i); + } + } + + int64 result_rank = gather_dim_numbers.output_window_dims_size() + + (expanded_gather_indices_shape.size() - 1); + int64 window_dims_seen = 0; + int64 gather_dims_seen = 0; + std::vector<int64> output_dim_bounds; + output_dim_bounds.reserve(result_rank); + for (int64 i = 0; i < result_rank; i++) { + int64 current_bound; + bool is_window_index = + c_binary_search(gather_dim_numbers.output_window_dims(), i); + if (is_window_index) { + while (c_binary_search(gather_dim_numbers.elided_window_dims(), + window_dims_seen)) { + window_dims_seen++; + } + current_bound = window_bounds[window_dims_seen++]; + } else { + current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + } + + output_dim_bounds.push_back(current_bound); + } + + return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index c4a1da28f3..0d3045213d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -253,6 +253,14 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers); + // Helper that infers the shape of the tensor produced by a gather operation + // with the given input shape, gather indices shape and gather dimension + // numbers. + static StatusOr<Shape> InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice<int64> window_bounds); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 026c021165..7eb120843f 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -18,15 +18,16 @@ limitations under the License. #include <string> #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -1527,5 +1528,341 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } +class GatherShapeInferenceTest : public ShapeInferenceTest { + protected: + const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); + const Shape s64_4d_tensor_10_9_8_7_1_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); + const Shape s64_4d_tensor_10_9_8_7_5_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape f32_5d_tensor_50_49_48_47_46_ = + ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); +}; + +TEST_F(GatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26})); + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + tuple_shape_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for input")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + s64_vector_32_, tuple_shape_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for gather indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + s64_vector_32_, s32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must at least of rank 1")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + s64_vector_32_, vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingWindowIndices) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 8, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must be ascending")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowIndices) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must not repeat")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexOutOfBounds) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 99, 100, 101}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 2 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingElidedWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{4}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("All components of the window index in a gather op must either " + "be a output window index or explicitly elided")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 19}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid elided_window_dims set in gather op; valid " + "range is [0, 5), got: 19")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 3}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in elided_window_dims in gather op")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "There must be exactly as many elements in " + "gather_dims_to_operand_dims " + "as there are elements in the last dimension of %gather_indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " + "[0, 5), got: 4->7")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{2, 1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{1, 1, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("elided_window_dims in gather op must be sorted")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{2}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 1, 300, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window bound at index 3 in gather op is out of range, " + "must be within [0, 48), got 300")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Gather op must have one window bound for every input dimension")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { + StatusOr<Shape> statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}), + /*window_bounds=*/{30, 29, 28, 26, 20}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op can only elide window indices with bound 1, " + "but bound is 29 for index 1 at position 0")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index d42cb6cdf3..4a55e4095a 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -315,6 +315,36 @@ StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction( return handle; } +StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction( + const GatherRequest& gather_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, + LookUpRequest(gather_request.input())); + TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, + LookUpRequest(gather_request.gather_indices())); + + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferGatherShape( + input_request->output_shape(), gather_indices_request->output_shape(), + gather_request.dimension_numbers(), + AsInt64Slice(gather_request.window_bounds()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_gather_request() = gather_request; + + VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << gather_request.ShortDebugString(); + return handle; +} + StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction( const GetTupleElementRequest& get_tuple_element_request) { tensorflow::mutex_lock lock(mutex_); @@ -2018,6 +2048,16 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kGatherRequest: { + PureFunctionalVisitor(session_computation, + request.request().gather_request().input(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + request.request().gather_request().gather_indices(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -2720,6 +2760,13 @@ static void ForEachOperand( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + apply(gather_request.input()); + apply(gather_request.gather_indices()); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -3453,6 +3500,20 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + HloInstruction* input_operand = + lookup_instruction(gather_request.input()); + HloInstruction* gather_indices_operand = + lookup_instruction(gather_request.gather_indices()); + std::vector<int64> window_bounds; + c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); + hlo_instruction = add_instruction(HloInstruction::CreateGather( + request.output_shape(), input_operand, gather_indices_operand, + gather_request.dimension_numbers(), window_bounds)); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 81a72583f7..fd5a2ace9b 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -242,6 +242,10 @@ class UserComputation { StatusOr<ComputationDataHandle> AddRecvInstruction( const RecvRequest& recv_request); + // Enqueues a Gather instruction onto this user computation. + StatusOr<ComputationDataHandle> AddGatherInstruction( + const GatherRequest& gather_request); + // Returns the user-provided name of this user computation, which is provided // via the XLA computation-building API. const string& name() const { return name_; } |