aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h3
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_cost_analysis.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_graph_dumper.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc74
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h28
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction_test.cc35
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.h1
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h1
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc1
-rw-r--r--tensorflow/compiler/xla/service/service.cc3
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc193
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.h8
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc341
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc61
-rw-r--r--tensorflow/compiler/xla/service/user_computation.h4
22 files changed, 779 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9f5f2f96b7..0ceb9ca9e0 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -145,7 +145,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ "//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 5b09e4931e..56723e7650 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -214,6 +214,7 @@ class DfsHloVisitorBase {
virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0;
virtual Status HandleWhile(HloInstructionPtr hlo) = 0;
virtual Status HandleConditional(HloInstructionPtr hlo) = 0;
+ virtual Status HandleGather(HloInstructionPtr hlo) = 0;
virtual Status HandlePad(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index ffc4f3bb79..ecda5288ee 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -188,6 +188,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleSendDone(HloInstructionPtr send_done) override {
return DefaultAction(send_done);
}
+ Status HandleGather(HloInstructionPtr gather) override {
+ return DefaultAction(gather);
+ }
// Invoked to inform the visitor that the traversal has completed, and that
// the root was "root".
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index aa2a0a9800..30c88c0a5d 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2064,6 +2064,11 @@ GetHloBufferSlices(const HloInstruction* hlo,
return slices;
}
+Status IrEmitterUnnested::HandleGather(HloInstruction* gather) {
+ // TODO(b/72710576): Gather is not implemented on GPUs
+ return Unimplemented("Gather is not implemented on GPUs.");
+}
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildKernelThunk(
const HloInstruction* inst) {
const BufferAssignment& buffer_assn =
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 688760efbd..b83a2337e2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -67,6 +67,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleDot(HloInstruction* dot) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleFusion(HloInstruction* fusion) override;
+ Status HandleGather(HloInstruction* gather) override;
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override;
Status HandleReduce(HloInstruction* reduce) override;
Status HandleSelectAndScatter(HloInstruction* instruction) override;
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 36db711c6c..a43785b4a9 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -129,6 +129,10 @@ message HloInstructionProto {
// FFT length.
repeated int64 fft_length = 32;
+
+ // Gather dimension numbers.
+ xla.GatherDimensionNumbers gather_dimension_numbers = 33;
+ repeated int64 gather_window_bounds = 34;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 6a4651d83f..4ec2ef27bf 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -533,6 +533,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) {
return Status::OK();
}
+Status HloCostAnalysis::HandleGather(const HloInstruction* gather) {
+ // Gather does not issue any flops.
+ return Status::OK();
+}
+
Status HloCostAnalysis::FinishVisit(const HloInstruction*) {
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index af52ea06ca..d17678d20f 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -100,6 +100,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleTranspose(const HloInstruction* transpose) override;
Status HandleWhile(const HloInstruction* xla_while) override;
Status HandleConditional(const HloInstruction* conditional) override;
+ Status HandleGather(const HloInstruction* gather) override;
Status FinishVisit(const HloInstruction* root) override;
Status Preprocess(const HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 9b0e2fd7d6..2861fec39e 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -940,6 +940,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kConcatenate:
case HloOpcode::kCopy:
case HloOpcode::kDynamicSlice:
+ case HloOpcode::kGather:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 0d925ad00d..b7dd055d7c 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1155,6 +1155,38 @@ bool HloInstruction::HasSideEffect() const {
return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGather(
+ const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ std::unique_ptr<HloInstruction> instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kGather, shape));
+ instruction->AppendOperand(operand);
+ instruction->AppendOperand(gather_indices);
+ instruction->gather_dimension_numbers_ =
+ MakeUnique<GatherDimensionNumbers>(gather_dim_numbers);
+ c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_));
+ return instruction;
+}
+
+/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims) {
+ GatherDimensionNumbers gather_dim_numbers;
+ for (int64 output_window_dim : output_window_dims) {
+ gather_dim_numbers.add_output_window_dims(output_window_dim);
+ }
+ for (int64 elided_window_dim : elided_window_dims) {
+ gather_dim_numbers.add_elided_window_dims(elided_window_dim);
+ }
+ for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) {
+ gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim);
+ }
+
+ return gather_dim_numbers;
+}
+
std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
@@ -1397,6 +1429,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
CHECK_EQ(new_operands.size(), 1);
clone = CreateRecvDone(new_operands[0]);
break;
+ case HloOpcode::kGather:
+ CHECK_EQ(new_operands.size(), 2);
+ clone = CreateGather(shape, new_operands[0], new_operands[1],
+ *gather_dimension_numbers_, gather_window_bounds_);
+ break;
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1740,6 +1777,11 @@ bool HloInstruction::IdenticalSlowPath(
return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
other.dot_dimension_numbers());
+ case HloOpcode::kGather:
+ return protobuf_util::ProtobufEquals(gather_dimension_numbers(),
+ other.gather_dimension_numbers()) &&
+ gather_window_bounds() == other.gather_window_bounds();
+
// FFT has various types & lengths.
case HloOpcode::kFft:
return fft_type() == other.fft_type() &&
@@ -2171,6 +2213,11 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
if (dot_dimension_numbers_ != nullptr) {
extra.push_back(DotDimensionNumbersToString());
}
+ if (gather_dimension_numbers_ != nullptr) {
+ extra.push_back(GatherDimensionNumbersToString());
+ extra.push_back(
+ StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}"));
+ }
if (opcode() == HloOpcode::kFft) {
extra.push_back(StrCat("fft_type=", FftType_Name(fft_type())));
extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}"));
@@ -2302,6 +2349,14 @@ HloInstructionProto HloInstruction::ToProto() const {
if (dot_dimension_numbers_ != nullptr) {
*proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
}
+ if (gather_dimension_numbers_ != nullptr) {
+ *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_;
+ }
+ if (opcode() == HloOpcode::kGather) {
+ for (int64 bound : gather_window_bounds()) {
+ proto.add_gather_window_bounds(bound);
+ }
+ }
for (int i = 0; i < slice_starts_.size(); ++i) {
auto* slice_dimension = proto.add_slice_dimensions();
slice_dimension->set_start(slice_starts_[i]);
@@ -2618,6 +2673,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleSend(this);
case HloOpcode::kSendDone:
return visitor->HandleSendDone(this);
+ case HloOpcode::kGather:
+ return visitor->HandleGather(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -3301,6 +3358,23 @@ string HloInstruction::DotDimensionNumbersToString() const {
return Join(result, ", ");
}
+string HloInstruction::GatherDimensionNumbersToString() const {
+ CHECK_NE(gather_dimension_numbers_.get(), nullptr);
+ string output_window_dims =
+ StrCat("output_window_dims={",
+ Join(gather_dimension_numbers_->output_window_dims(), ","), "}");
+ string elided_window_dims =
+ StrCat("elided_window_dims={",
+ Join(gather_dimension_numbers_->elided_window_dims(), ","), "}");
+ string gather_dims_to_operand_dims = StrCat(
+ "gather_dims_to_operand_dims={",
+ Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}");
+
+ return Join<std::initializer_list<string>>(
+ {output_window_dims, elided_window_dims, gather_dims_to_operand_dims},
+ ", ");
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index e898a83739..1762d227be 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -451,6 +451,12 @@ class HloInstruction {
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);
+ static std::unique_ptr<HloInstruction> CreateGather(
+ const Shape& shape, HloInstruction* operand,
+ HloInstruction* gather_indices,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
@@ -492,6 +498,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
+ // Creates an instance of GatherDimensionNumbers.
+ static GatherDimensionNumbers MakeGatherDimNumbers(
+ tensorflow::gtl::ArraySlice<int64> output_window_dims,
+ tensorflow::gtl::ArraySlice<int64> elided_window_dims,
+ tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims);
+
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
@@ -1102,6 +1114,19 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
+ const GatherDimensionNumbers& gather_dimension_numbers() const {
+ CHECK(gather_dimension_numbers_ != nullptr);
+ return *gather_dimension_numbers_;
+ }
+
+ tensorflow::gtl::ArraySlice<int64> gather_window_bounds() const {
+ CHECK_EQ(opcode(), HloOpcode::kGather);
+ return gather_window_bounds_;
+ }
+
+ // Returns the dump string of the gather dimension numbers.
+ string GatherDimensionNumbersToString() const;
+
// Returns the random distribution for this rng node.
//
// Precondition: opcode() == HloOpcode::kRng
@@ -1366,6 +1391,9 @@ class HloInstruction {
// Describes the dimension numbers used for a dot.
std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
+ std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
+ std::vector<int64> gather_window_bounds_;
+
// Describes FFT type for an FFT instruction.
FftType fft_type_ = FftType::FFT;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 94e9bfe56e..32d3ed272b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1271,5 +1271,40 @@ TEST_F(HloInstructionTest, Stringification) {
"true_computation=%TransposeDot, false_computation=%TransposeDot");
}
+TEST_F(HloInstructionTest, StringifyGather) {
+ Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ Shape gather_indices_tensor_shape =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ Shape gather_result_shape =
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26});
+
+ HloComputation::Builder builder("Gather");
+ HloInstruction* input = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor"));
+ HloInstruction* gather_indices =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, gather_indices_tensor_shape, "gather_indices"));
+
+ HloInstruction* gather_instruction =
+ builder.AddInstruction(HloInstruction::CreateGather(
+ gather_result_shape, input, gather_indices,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+
+ HloModule module(TestName());
+ module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(gather_instruction->ToString(),
+ "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} "
+ "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, "
+ "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), "
+ "output_window_dims={4,5,6,7,8}, elided_window_dims={}, "
+ "gather_dims_to_operand_dims={0,1,2,3,4}, "
+ "window_bounds={30,29,28,27,26}");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 088dd15dbf..af24604c39 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -76,6 +76,7 @@ namespace xla {
V(kFft, "fft") \
V(kFloor, "floor") \
V(kFusion, "fusion", kHloOpcodeIsVariadic) \
+ V(kGather, "gather") \
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index f3378309c2..b1fd068115 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -424,6 +424,14 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
} // namespace
+Status ShapeVerifier::HandleGather(HloInstruction* gather) {
+ return CheckShape(
+ gather,
+ ShapeInference::InferGatherShape(
+ gather->operand(0)->shape(), gather->operand(1)->shape(),
+ gather->gather_dimension_numbers(), gather->gather_window_bounds()));
+}
+
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index f9f898c236..1dd7ec3c51 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -80,6 +80,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleBatchNormInference(
HloInstruction* batch_norm_inference) override;
Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override;
+ Status HandleGather(HloInstruction* gather) override;
Status FinishVisit(HloInstruction*) override {
return tensorflow::Status::OK();
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index f08d809d79..f494748e17 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -102,6 +102,7 @@ namespace xla {
case HloOpcode::kExp:
case HloOpcode::kFft:
case HloOpcode::kFusion:
+ case HloOpcode::kGather:
case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kMap:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 95c853b5c4..e278eab690 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1446,6 +1446,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
case OpRequest::kFftRequest:
handle_status = computation->AddFftInstruction(arg->fft_request());
break;
+ case OpRequest::kGatherRequest:
+ handle_status = computation->AddGatherInstruction(arg->gather_request());
+ break;
case OpRequest::kGetTupleElementRequest:
handle_status = computation->AddGetTupleElementInstruction(
arg->get_tuple_element_request());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 004889b5f2..c9692757b2 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -2448,4 +2448,197 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return to_apply.result();
}
+static Status ValidateGatherDimensionNumbers(
+ const Shape& input_shape,
+ tensorflow::gtl::ArraySlice<int64> gather_indices_shape,
+ const GatherDimensionNumbers& dim_numbers) {
+ if (!c_is_sorted(dim_numbers.output_window_dims())) {
+ return InvalidArgument(
+ "Output window dimensions in gather op must be ascending; got: %s",
+ Join(dim_numbers.output_window_dims(), ", ").c_str());
+ }
+
+ if (c_adjacent_find(dim_numbers.output_window_dims()) !=
+ dim_numbers.output_window_dims().end()) {
+ return InvalidArgument(
+ "Output window dimensions in gather op must not repeat; got: %s",
+ Join(dim_numbers.output_window_dims(), ", ").c_str());
+ }
+
+ const int64 output_window_dim_count = dim_numbers.output_window_dims_size();
+ const int64 output_shape_rank =
+ output_window_dim_count + gather_indices_shape.size();
+
+ for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) {
+ int64 window_index = dim_numbers.output_window_dims(i);
+ if (window_index < 0 || window_index >= output_shape_rank) {
+ return InvalidArgument(
+ "Window index %d in gather op is out of bounds; got %lld, but should "
+ "have been in "
+ "[0,%lld)",
+ i, window_index, output_shape_rank);
+ }
+ }
+
+ if (dim_numbers.gather_dims_to_operand_dims_size() !=
+ gather_indices_shape.back()) {
+ return InvalidArgument(
+ "There must be exactly as many elements in gather_dims_to_operand_dims "
+ "as there are elements in the last dimension of %%gather_indices; got: "
+ "%d, expected %lld",
+ dim_numbers.gather_dims_to_operand_dims_size(),
+ gather_indices_shape.back());
+ }
+
+ for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) {
+ int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i);
+ if (gather_dim_to_input_dim < 0 ||
+ gather_dim_to_input_dim >= input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), "
+ "got: %d->%lld",
+ input_shape.dimensions_size(), i, gather_dim_to_input_dim);
+ }
+ }
+
+ std::vector<int64> sorted_gather_dims_to_operand_dims(
+ dim_numbers.gather_dims_to_operand_dims().begin(),
+ dim_numbers.gather_dims_to_operand_dims().end());
+
+ c_sort(sorted_gather_dims_to_operand_dims);
+
+ if (c_adjacent_find(sorted_gather_dims_to_operand_dims) !=
+ sorted_gather_dims_to_operand_dims.end()) {
+ return InvalidArgument(
+ "Repeated dimensions are not allowed in gather_dims_to_operand_dims; "
+ "got: %s",
+ Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str());
+ }
+
+ for (int64 elided_dim : dim_numbers.elided_window_dims()) {
+ if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Invalid elided_window_dims set in gather op; valid range is [0, "
+ "%d), got: %lld",
+ input_shape.dimensions_size(), elided_dim);
+ }
+ }
+
+ if (!c_is_sorted(dim_numbers.elided_window_dims())) {
+ return InvalidArgument(
+ "elided_window_dims in gather op must be sorted; got: %s",
+ Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ }
+
+ if (c_adjacent_find(dim_numbers.elided_window_dims()) !=
+ dim_numbers.elided_window_dims().end()) {
+ return InvalidArgument(
+ "Repeated dimensions not allowed in elided_window_dims in gather op; "
+ "got: %s",
+ Join(dim_numbers.elided_window_dims(), ", ").c_str());
+ }
+
+ return Status::OK();
+}
+
+/*static*/ StatusOr<Shape> ShapeInference::InferGatherShape(
+ const Shape& input_shape, const Shape& gather_indices_shape,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds) {
+ TF_RETURN_IF_ERROR(
+ ExpectNotTupleOrOpaque(input_shape, "input tensor operand of gather op"));
+ TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(
+ gather_indices_shape, "gather indices operand of gather op"));
+
+ if (gather_indices_shape.dimensions_size() < 1) {
+ return InvalidArgument(
+ "Gather indices parameter must at least be of rank 1; got %s",
+ ShapeUtil::HumanString(gather_indices_shape).c_str());
+ }
+
+ if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) {
+ return InvalidArgument(
+ "Gather indices parameter must be an integral tensor; got %s",
+ ShapeUtil::HumanString(gather_indices_shape).c_str());
+ }
+
+ std::vector<int64> expanded_gather_indices_shape;
+ // We implicitly reshape gather indices of shape P[N] to P[N,1].
+ expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size());
+ c_copy(gather_indices_shape.dimensions(),
+ std::back_inserter(expanded_gather_indices_shape));
+ if (expanded_gather_indices_shape.size() == 1) {
+ expanded_gather_indices_shape.push_back(1);
+ }
+
+ TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers(
+ input_shape, expanded_gather_indices_shape, gather_dim_numbers));
+
+ if (window_bounds.size() != input_shape.dimensions_size()) {
+ return InvalidArgument(
+ "Gather op must have one window bound for every input dimension; got: "
+ "len(window_bounds)=%lu, input_shape.rank=%d",
+ window_bounds.size(), input_shape.dimensions_size());
+ }
+
+ if (window_bounds.size() !=
+ gather_dim_numbers.output_window_dims_size() +
+ gather_dim_numbers.elided_window_dims_size()) {
+ return InvalidArgument(
+ "All components of the window index in a gather op must either be an "
+ "output window index or explicitly elided; got len(window_bounds)=%lu, "
+ "output_window_bounds=%s, elided_window_bounds=%s",
+ window_bounds.size(),
+ Join(gather_dim_numbers.output_window_dims(), ",").c_str(),
+ Join(gather_dim_numbers.elided_window_dims(), ",").c_str());
+ }
+
+ for (int i = 0; i < window_bounds.size(); i++) {
+ int64 window_bound = window_bounds[i];
+ int64 corresponding_input_bound = input_shape.dimensions(i);
+ if (window_bound < 0 || window_bound > corresponding_input_bound) {
+ return InvalidArgument(
+ "Window bound at index %d in gather op is out of range, must be "
+ "within "
+ "[0, %lld), got %lld",
+ i, corresponding_input_bound + 1, window_bound);
+ }
+ }
+
+ for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) {
+ if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) {
+ return InvalidArgument(
+ "Gather op can only elide window indices with bound 1, but bound is "
+ "%lld for index %lld at position %d",
+ window_bounds[gather_dim_numbers.elided_window_dims(i)],
+ gather_dim_numbers.elided_window_dims(i), i);
+ }
+ }
+
+ int64 result_rank = gather_dim_numbers.output_window_dims_size() +
+ (expanded_gather_indices_shape.size() - 1);
+ int64 window_dims_seen = 0;
+ int64 gather_dims_seen = 0;
+ std::vector<int64> output_dim_bounds;
+ output_dim_bounds.reserve(result_rank);
+ for (int64 i = 0; i < result_rank; i++) {
+ int64 current_bound;
+ bool is_window_index =
+ c_binary_search(gather_dim_numbers.output_window_dims(), i);
+ if (is_window_index) {
+ while (c_binary_search(gather_dim_numbers.elided_window_dims(),
+ window_dims_seen)) {
+ window_dims_seen++;
+ }
+ current_bound = window_bounds[window_dims_seen++];
+ } else {
+ current_bound = expanded_gather_indices_shape[gather_dims_seen++];
+ }
+
+ output_dim_bounds.push_back(current_bound);
+ }
+
+ return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index c4a1da28f3..0d3045213d 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -253,6 +253,14 @@ class ShapeInference {
const Shape& lhs, const Shape& rhs,
const DotDimensionNumbers& dimension_numbers);
+ // Helper that infers the shape of the tensor produced by a gather operation
+ // with the given input shape, gather indices shape and gather dimension
+ // numbers.
+ static StatusOr<Shape> InferGatherShape(
+ const Shape& input_shape, const Shape& gather_indices_shape,
+ const GatherDimensionNumbers& gather_dim_numbers,
+ tensorflow::gtl::ArraySlice<int64> window_bounds);
+
private:
// Helper that infers the shape produced by performing an element-wise binary
// operation with the given LHS and RHS shapes.
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 026c021165..7eb120843f 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -18,15 +18,16 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
namespace xla {
namespace {
+using ::tensorflow::gtl::ArraySlice;
using ::testing::ContainsRegex;
using ::testing::HasSubstr;
@@ -1527,5 +1528,341 @@ TEST_F(ShapeInferenceTest, BadSlice) {
<< statusor.status();
}
+class GatherShapeInferenceTest : public ShapeInferenceTest {
+ protected:
+ const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
+ const Shape s64_4d_tensor_10_9_8_7_1_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
+ const Shape s64_4d_tensor_10_9_8_7_5_ =
+ ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
+ const Shape f32_5d_tensor_50_49_48_47_46_ =
+ ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
+ const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
+ {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
+};
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGather) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{1},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(
+ ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4},
+ /*elided_window_dims=*/{0},
+ /*gather_dims_to_operand_dims=*/{0}),
+ /*window_bounds=*/{1, 48}));
+ EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
+ TF_ASSERT_OK_AND_ASSIGN(
+ Shape gather_shape,
+ ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26}));
+ EXPECT_TRUE(ShapeUtil::Equal(
+ gather_shape,
+ ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
+ << ShapeUtil::HumanString(gather_shape);
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ tuple_shape_, s64_vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for input"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, tuple_shape_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Expected non-tuple argument for gather indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, ScalarGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, s32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must at least of rank 1"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ s64_vector_32_, vector_32_,
+ HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{1}),
+ /*window_bounds=*/{64, 1});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather indices parameter must be an integral tensor"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 8, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must be ascending"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowIndices) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 7},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Output window dimensions in gather op must not repeat"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 99, 100, 101},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window index 2 in gather op is out of bounds"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{4},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("All components of the window index in a gather op must either "
+ "be a output window index or explicitly elided"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 19},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Invalid elided_window_dims set in gather op; valid "
+ "range is [0, 5), got: 19"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{0, 1, 2, 3, 3},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions not allowed in elided_window_dims in gather op"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "There must be exactly as many elements in "
+ "gather_dims_to_operand_dims "
+ "as there are elements in the last dimension of %gather_indices"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is "
+ "[0, 5), got: 4->7"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}),
+ /*window_bounds=*/{30, 29, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Repeated dimensions are not allowed in gather_dims_to_operand_dims"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{2, 1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{1, 1, 28, 27, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("elided_window_dims in gather op must be sorted"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{2},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 1, 300, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Window bound at index 3 in gather op is out of range, "
+ "must be within [0, 48), got 300"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7, 8},
+ /*elided_window_dims=*/{},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(
+ statusor.status().error_message(),
+ HasSubstr(
+ "Gather op must have one window bound for every input dimension"))
+ << statusor.status();
+}
+
+TEST_F(GatherShapeInferenceTest,
+ InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
+ StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
+ f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
+ HloInstruction::MakeGatherDimNumbers(
+ /*output_window_dims=*/{4, 5, 6, 7},
+ /*elided_window_dims=*/{1},
+ /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}),
+ /*window_bounds=*/{30, 29, 28, 26, 20});
+ ASSERT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("Gather op can only elide window indices with bound 1, "
+ "but bound is 29 for index 1 at position 0"))
+ << statusor.status();
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index d42cb6cdf3..4a55e4095a 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -315,6 +315,36 @@ StatusOr<ComputationDataHandle> UserComputation::AddConstantInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddGatherInstruction(
+ const GatherRequest& gather_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* input_request,
+ LookUpRequest(gather_request.input()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request,
+ LookUpRequest(gather_request.gather_indices()));
+
+ TF_ASSIGN_OR_RETURN(
+ Shape shape,
+ ShapeInference::InferGatherShape(
+ input_request->output_shape(), gather_indices_request->output_shape(),
+ gather_request.dimension_numbers(),
+ AsInt64Slice(gather_request.window_bounds())));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_gather_request() = gather_request;
+
+ VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << gather_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddGetTupleElementInstruction(
const GetTupleElementRequest& get_tuple_element_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -2018,6 +2048,16 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kGatherRequest: {
+ PureFunctionalVisitor(session_computation,
+ request.request().gather_request().input(),
+ num_parameters, visited, is_functional);
+ PureFunctionalVisitor(session_computation,
+ request.request().gather_request().gather_indices(),
+ num_parameters, visited, is_functional);
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
@@ -2720,6 +2760,13 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kGatherRequest: {
+ const GatherRequest& gather_request = request.request().gather_request();
+ apply(gather_request.input());
+ apply(gather_request.gather_indices());
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
@@ -3453,6 +3500,20 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kGatherRequest: {
+ const GatherRequest& gather_request = request.request().gather_request();
+ HloInstruction* input_operand =
+ lookup_instruction(gather_request.input());
+ HloInstruction* gather_indices_operand =
+ lookup_instruction(gather_request.gather_indices());
+ std::vector<int64> window_bounds;
+ c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds));
+ hlo_instruction = add_instruction(HloInstruction::CreateGather(
+ request.output_shape(), input_operand, gather_indices_operand,
+ gather_request.dimension_numbers(), window_bounds));
+ break;
+ }
+
case OpRequest::OP_NOT_SET:
LOG(FATAL) << "OperationRequest doesn't contain a request";
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 81a72583f7..fd5a2ace9b 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -242,6 +242,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddRecvInstruction(
const RecvRequest& recv_request);
+ // Enqueues a Gather instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddGatherInstruction(
+ const GatherRequest& gather_request);
+
// Returns the user-provided name of this user computation, which is provided
// via the XLA computation-building API.
const string& name() const { return name_; }