aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-02-16 15:29:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 15:33:12 -0800
commitba019dc689d6393d8dba04ca57e8b01b374db14f (patch)
tree6bd132cd6a1d3b6c8c833cb3e575db571ebd19a1 /tensorflow/compiler/xla/service
parent1873ed4faab980ad239c06e8b92b8f4a85154fe3 (diff)
[XLA] Add some plumbing, documentation, verification and shape inference for Gather
Pretty much everything other than HLO verification and shape inference will fail for Gather with Unimplemented. Note that this CL is intentionally incomplete -- I figured it would be nicer to get some of the boiler-platey stuff out of the way early. Let me know if you want me to send in a larger but more complete CL instead. PiperOrigin-RevId: 186055521
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc74
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h28
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/service.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc193
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h8
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc341
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc61
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h4
22 files changed, 779 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9f5f2f96b7..0ceb9ca9e0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -145,7 +145,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 5b09e4931e..56723e7650 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -214,6 +214,7 @@ class DfsHloVisitorBase {
virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
+ virtual Status HandleGather(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index ffc4f3bb79..ecda5288ee 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleSendDone(HloInstructionPtr send_done) override {
return DefaultAction(send_done);
}
+ Status HandleGather(HloInstructionPtr gather) override {
+ return DefaultAction(gather);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index aa2a0a9800..30c88c0a5d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2064,6 +2064,11 @@ GetHloBufferSlices(const HloInstruction* hlo,
return slices;
}
+Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
+ // TODO(b/72710576): Gather is not implemented on GPUs
+ return Unimplemented("Gather is not implemented on GPUs.");
+}
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
const HloInstruction* inst) {
const BufferAssignment& buffer_assn =
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 688760efbd..b83a2337e2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -67,6 +67,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleDot(HloInstruction* dot) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleGather(HloInstruction* gather) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleSelectAndScatter(HloInstruction* instruction) override;
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 36db711c6c..a43785b4a9 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -129,6 +129,10 @@ message HloInstructionProto {
// FFT length.
repeated int64 fft_length = 32;
+
+ // Gather dimension numbers.
+ xla.GatherDimensionNumbers gather_dimension_numbers = 33;
+ repeated int64 gather_window_bounds = 34;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 6a4651d83f..4ec2ef27bf 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -533,6 +533,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
return Status::OK();
}
+Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
+ // Gather does not issue any flops.
+ return Status::OK();
+}
+
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index af52ea06ca..d17678d20f 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -100,6 +100,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
+ Status HandleGather(const HloInstruction* gather) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 9b0e2fd7d6..2861fec39e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -940,6 +940,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kConcatenate:
case HloOpcode::kCopy:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kGather:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 0d925ad00d..b7dd055d7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1155,6 +1155,38 @@ bool HloInstruction::HasSideEffect() const {
return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
+ const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ std::unique_ptr<HloInstruction> instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(gather_indices);
+ instruction->gather_dimension_numbers_ =
+ MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
+ c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
+ return instruction;
+}
+
+/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) {
+ GatherDimensionNumbers gather_dim_numbers;
+ for (int64 output_window_dim : output_window_dims) {
+ gather_dim_numbers.add_output_window_dims(output_window_dim);
+ }
+ for (int64 elided_window_dim : elided_window_dims) {
+ gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ }
+ for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
+ gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ }
+
+ return gather_dim_numbers;
+}
+
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
@@ -1397,6 +1429,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateRecvDone(new_operands[0]);
break;
+ case HloOpcode::kGather:
+ CHECK_EQ(new_operands.size(), 2);
+ clone = CreateGather(shape, new_operands[0], new_operands[1],
+ *gather_dimension_numbers_, gather_window_bounds_);
+ break;
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1740,6 +1777,11 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
+ case HloOpcode::kGather:
+ return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
+ other.gather_dimension_numbers()) &&
+ gather_window_bounds() == other.gather_window_bounds();
+
// FFT has various types & lengths.
case HloOpcode::kFft:
return fft_type() == other.fft_type() &&
@@ -2171,6 +2213,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
+ if (gather_dimension_numbers_ != nullptr) {
+ extra.push_back(GatherDimensionNumbersToString());
+ extra.push_back(
+ StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
+ }
if (opcode() == HloOpcode::kFft) {
extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
@@ -2302,6 +2349,14 @@ HloInstructionProto HloInstruction::ToProto() const {
if (dot_dimension_numbers_ != nullptr) {
*proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
}
+ if (gather_dimension_numbers_ != nullptr) {
+ *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
+ }
+ if (opcode() == HloOpcode::kGather) {
+ for (int64 bound : gather_window_bounds()) {
+ proto.add_gather_window_bounds(bound);
+ }
+ }
for (int i = 0; i < slice_starts_.size(); ++i) {
auto* slice_dimension = proto.add_slice_dimensions();
slice_dimension->set_start(slice_starts_[i]);
@@ -2618,6 +2673,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSend(this);
case HloOpcode::kSendDone:
return visitor->HandleSendDone(this);
+ case HloOpcode::kGather:
+ return visitor->HandleGather(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -3301,6 +3358,23 @@ string HloInstruction::DotDimensionNumbersToString() const {
return Join(result, ", ");
}
+string HloInstruction::GatherDimensionNumbersToString() const {
+ CHECK_NE(gather_dimension_numbers_.get(), nullptr);
+ string output_window_dims =
+ StrCat("output_window_dims={",
+ Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
+ string elided_window_dims =
+ StrCat("elided_window_dims={",
+ Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
+ string gather_dims_to_operand_dims = StrCat(
+ "gather_dims_to_operand_dims={",
+ Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+
+ return Join<std::initializer_list<string>>(
+ {output_window_dims, elided_window_dims, gather_dims_to_operand_dims},
+ ", ");
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e898a83739..1762d227be 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -451,6 +451,12 @@ class HloInstruction {
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);
+ static std::unique_ptr<HloInstruction> CreateGather(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
@@ -492,6 +498,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates an instance of GatherDimensionNumbers.
+ static GatherDimensionNumbers MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);
+
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
@@ -1102,6 +1114,19 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
+ const GatherDimensionNumbers& gather_dimension_numbers() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ return *gather_dimension_numbers_;
+ }
+
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
+ CHECK_EQ(opcode(), HloOpcode::kGather);
+ return gather_window_bounds_;
+ }
+
+ // Returns the dump string of the gather dimension numbers.
+ string GatherDimensionNumbersToString() const;
+
// Returns the random distribution for this rng node.
//
// Precondition: opcode() == HloOpcode::kRng
@@ -1366,6 +1391,9 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
+ std::vector<int64> gather_window_bounds_;
+
// Describes FFT type for an FFT instruction.
FftType fft_type_ = FftType::FFT;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 94e9bfe56e..32d3ed272b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1271,5 +1271,40 @@ TEST_F(HloInstructionTest, Stringification) {
"true_computation=%TransposeDot, false_computation=%TransposeDot");
}
+TEST_F(HloInstructionTest, StringifyGather) {
+ Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ Shape gather_indices_tensor_shape =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ Shape gather_result_shape =
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
+
+ HloComputation::Builder builder("Gather");
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+ HloInstruction* gather_indices =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, gather_indices_tensor_shape, "gather_indices"));
+
+ HloInstruction* gather_instruction =
+ builder.AddInstruction(HloInstruction::CreateGather(
+ gather_result_shape, input, gather_indices,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(gather_instruction->ToString(),
+ "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
+ "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+ "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
+ "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
+ "gather_dims_to_operand_dims={0,1,2,3,4}, "
+ "window_bounds={30,29,28,27,26}");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 088dd15dbf..af24604c39 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -76,6 +76,7 @@ namespace xla {
V(kFft, "fft") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
+ V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f3378309c2..b1fd068115 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -424,6 +424,14 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
} // namespace
+Status ShapeVerifier::HandleGather(HloInstruction* gather) {
+ return CheckShape(
+ gather,
+ ShapeInference::InferGatherShape(
+ gather->operand(0)->shape(), gather->operand(1)->shape(),
+ gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+}
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index f9f898c236..1dd7ec3c51 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -80,6 +80,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleBatchNormInference(
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
+ Status HandleGather(HloInstruction* gather) override;
Status FinishVisit(HloInstruction*) override {
return tensorflow::Status::OK();
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index f08d809d79..f494748e17 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -102,6 +102,7 @@ namespace xla {
case HloOpcode::kExp:
case HloOpcode::kFft:
case HloOpcode::kFusion:
+ case HloOpcode::kGather:
case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kMap:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 95c853b5c4..e278eab690 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1446,6 +1446,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
case OpRequest::kFftRequest:
handle_status = computation->AddFftInstruction(arg->fft_request());
break;
+ case OpRequest::kGatherRequest:
+ handle_status = computation->AddGatherInstruction(arg->gather_request());
+ break;
case OpRequest::kGetTupleElementRequest:
handle_status = computation->AddGetTupleElementInstruction(
arg->get_tuple_element_request());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 004889b5f2..c9692757b2 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2448,4 +2448,197 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return to_apply.result();
}
+static Status ValidateGatherDimensionNumbers(
+ const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ const GatherDimensionNumbers& dim_numbers) {
+ if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ return InvalidArgument(
+ "Output window dimensions in gather op must be ascending; got: %s",
+ Join(dim_numbers.output_window_dims(), ", ").c_str());
+ }
+
+ if (c_adjacent_find(dim_numbers.output_window_dims()) !=
+ dim_numbers.output_window_dims().end()) {
+ return InvalidArgument(
+ "Output window dimensions in gather op must not repeat; got: %s",
+ Join(dim_numbers.output_window_dims(), ", ").c_str());
+ }
+
+ const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_shape_rank =
+ output_window_dim_count + gather_indices_shape.size();
+
+ for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
+ int64 window_index = dim_numbers.output_window_dims(i);
+ if (window_index < 0 || window_index >= output_shape_rank) {
+ return InvalidArgument(
+ "Window index %d in gather op is out of bounds; got %lld, but should "
+ "have been in"
+ "[0,%lld)",
+ i, window_index, output_shape_rank);
+ }
+ }
+
+ if (dim_numbers.gather_dims_to_operand_dims_size() !=
+ gather_indices_shape.back()) {
+ return InvalidArgument(
+ "There must be exactly as many elements in gather_dims_to_operand_dims "
+ "as there are elements in the last dimension of %%gather_indices; got: "
+ "%d, expected %lld",
+ dim_numbers.gather_dims_to_operand_dims_size(),
+ gather_indices_shape.back());
+ }
+
+ for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
+ int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
+ if (gather_dim_to_input_dim < 0 ||
+ gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
+ "got: %d->%lld",
+ input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ }
+ }
+
+ std::vector<int64> sorted_gather_dims_to_operand_dims(
+ dim_numbers.gather_dims_to_operand_dims().begin(),
+ dim_numbers.gather_dims_to_operand_dims().end());
+
+ c_sort(sorted_gather_dims_to_operand_dims);
+
+ if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
+ sorted_gather_dims_to_operand_dims.end()) {
+ return InvalidArgument(
+ "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "got: %s",
+ Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ }
+
+ for (int64 elided_dim : dim_numbers.elided_window_dims()) {
+ if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "%d), got: %lld",
+ input_shape.dimensions_size(), elided_dim);
+ }
+ }
+
+ if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ return InvalidArgument(
+ "elided_window_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ }
+
+ if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
+ dim_numbers.elided_window_dims().end()) {
+ return InvalidArgument(
+ "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "got: %s",
+ Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ }
+
+ return Status::OK();
+}
+
+/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
+ const Shape& input_shape, const Shape& gather_indices_shape,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ gather_indices_shape, "gather indices operand of gather op"));
+
+ if (gather_indices_shape.dimensions_size() < 1) {
+ return InvalidArgument(
+ "Gather indices parameter must at least of rank 1; got %s",
+ ShapeUtil::HumanString(gather_indices_shape).c_str());
+ }
+
+ if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ return InvalidArgument(
+ "Gather indices parameter must be an integral tensor; got %s",
+ ShapeUtil::HumanString(gather_indices_shape).c_str());
+ }
+
+ std::vector<int64> expanded_gather_indices_shape;
+ // We implicitly reshape gather indices of shape P[N] to P[N,1].
+ expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
+ c_copy(gather_indices_shape.dimensions(),
+ std::back_inserter(expanded_gather_indices_shape));
+ if (expanded_gather_indices_shape.size() == 1) {
+ expanded_gather_indices_shape.push_back(1);
+ }
+
+ TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
+ input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+
+ if (window_bounds.size() != input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Gather op must have one window bound for every input dimension; got: "
+ "len(window_bounds)=%lu, input_shape.rank=%d",
+ window_bounds.size(), input_shape.dimensions_size());
+ }
+
+ if (window_bounds.size() !=
+ gather_dim_numbers.output_window_dims_size() +
+ gather_dim_numbers.elided_window_dims_size()) {
+ return InvalidArgument(
+ "All components of the window index in a gather op must either be a "
+ "output window index or explicitly elided; got len(window_bounds)=%lu, "
+ "output_window_bounds=%s, elided_window_bounds=%s",
+ window_bounds.size(),
+ Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
+ Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ }
+
+ for (int i = 0; i < window_bounds.size(); i++) {
+ int64 window_bound = window_bounds[i];
+ int64 corresponding_input_bound = input_shape.dimensions(i);
+ if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ return InvalidArgument(
+ "Window bound at index %d in gather op is out of range, must be "
+ "within "
+ "[0, %lld), got %lld",
+ i, corresponding_input_bound + 1, window_bound);
+ }
+ }
+
+ for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
+ if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ return InvalidArgument(
+ "Gather op can only elide window indices with bound 1, but bound is "
+ "%lld for index %lld at position %d",
+ window_bounds[gather_dim_numbers.elided_window_dims(i)],
+ gather_dim_numbers.elided_window_dims(i), i);
+ }
+ }
+
+ int64 result_rank = gather_dim_numbers.output_window_dims_size() +
+ (expanded_gather_indices_shape.size() - 1);
+ int64 window_dims_seen = 0;
+ int64 gather_dims_seen = 0;
+ std::vector<int64> output_dim_bounds;
+ output_dim_bounds.reserve(result_rank);
+ for (int64 i = 0; i < result_rank; i++) {
+ int64 current_bound;
+ bool is_window_index =
+ c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ if (is_window_index) {
+ while (c_binary_search(gather_dim_numbers.elided_window_dims(),
+ window_dims_seen)) {
+ window_dims_seen++;
+ }
+ current_bound = window_bounds[window_dims_seen++];
+ } else {
+ current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ }
+
+ output_dim_bounds.push_back(current_bound);
+ }
+
+ return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c4a1da28f3..0d3045213d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -253,6 +253,14 @@ class ShapeInference {
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers);
+ // Helper that infers the shape of the tensor produced by a gather operation
+ // with the given input shape, gather indices shape and gather dimension
+ // numbers.
+ static StatusOr<Shape> InferGatherShape(
+ const Shape& input_shape, const Shape& gather_indices_shape,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
private:
// Helper that infers the shape produced by performing an element-wise binary
// operation with the given LHS and RHS shapes.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 026c021165..7eb120843f 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -18,15 +18,16 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
+using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -1527,5 +1528,341 @@ TEST_F(ShapeInferenceTest, BadSlice) {
<< statusor.status();
}
+class GatherShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+ const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
+ const Shape s64_4d_tensor_10_9_8_7_1_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
+ const Shape s64_4d_tensor_10_9_8_7_5_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ const Shape f32_5d_tensor_50_49_48_47_46_ =
+ ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
+ {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
+};
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{1},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ tuple_shape_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for input"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, tuple_shape_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for gather indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, s32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must at least of rank 1"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must be an integral tensor"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 8, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must be ascending"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 99, 100, 101},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window index 2 in gather op is out of bounds"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{4},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("All components of the window index in a gather op must either "
+ "be a output window index or explicitly elided"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 19},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ "range is [0, 5), got: 19"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 3},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "There must be exactly as many elements in "
+ "gather_dims_to_operand_dims "
+ "as there are elements in the last dimension of %gather_indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
+ "[0, 5), got: 4->7"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{2, 1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{1, 1, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("elided_window_dims in gather op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{2},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 1, 300, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window bound at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Gather op must have one window bound for every input dimension"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26, 20});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op can only elide window indices with bound 1, "
+ "but bound is 29 for index 1 at position 0"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index d42cb6cdf3..4a55e4095a 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -315,6 +315,36 @@ StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction(
+ const GatherRequest& gather_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* input_request,
+ LookUpRequest(gather_request.input()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request,
+ LookUpRequest(gather_request.gather_indices()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape shape,
+ ShapeInference::InferGatherShape(
+ input_request->output_shape(), gather_indices_request->output_shape(),
+ gather_request.dimension_numbers(),
+ AsInt64Slice(gather_request.window_bounds())));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_gather_request() = gather_request;
+
+ VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << gather_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
const GetTupleElementRequest& get_tuple_element_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -2018,6 +2048,16 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kGatherRequest: {
+ PureFunctionalVisitor(session_computation,
+ request.request().gather_request().input(),
+ num_parameters, visited, is_functional);
+ PureFunctionalVisitor(session_computation,
+ request.request().gather_request().gather_indices(),
+ num_parameters, visited, is_functional);
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
@@ -2720,6 +2760,13 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kGatherRequest: {
+ const GatherRequest& gather_request = request.request().gather_request();
+ apply(gather_request.input());
+ apply(gather_request.gather_indices());
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
@@ -3453,6 +3500,20 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kGatherRequest: {
+ const GatherRequest& gather_request = request.request().gather_request();
+ HloInstruction* input_operand =
+ lookup_instruction(gather_request.input());
+ HloInstruction* gather_indices_operand =
+ lookup_instruction(gather_request.gather_indices());
+ std::vector<int64> window_bounds;
+ c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds));
+ hlo_instruction = add_instruction(HloInstruction::CreateGather(
+ request.output_shape(), input_operand, gather_indices_operand,
+ gather_request.dimension_numbers(), window_bounds));
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 81a72583f7..fd5a2ace9b 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -242,6 +242,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddRecvInstruction(
const RecvRequest& recv_request);
+ // Enqueues a Gather instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddGatherInstruction(
+ const GatherRequest& gather_request);
+
// Returns the user-provided name of this user computation, which is provided
// via the XLA computation-building API.
const string& name() const { return name_; }