about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author Michael Kuperstein <mkuper@google.com> 2018-10-09 19:41:35 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-09 19:46:05 -0700
commit 58fcfc98cd59ae3952399fc55380b8733df08df9 (patch)
tree 24d5ac5d6691e73c227f5afa5ef68ba2ecba4ec0
parent 93eef55c4d04af24a6c8080f34629db179634f07 (diff)
[XLA] Add documentation and HLO-level support for multi-value sort.
No support in any of the backends, and not yet exposed through XlaBuilder.

PiperOrigin-RevId: 216465753
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier.cc2
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc12
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_normalization_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc3
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc17
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc22
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc25
-rw-r--r--tensorflow/compiler/xla/service/shape_inference_test.cc26
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/test_utils.cc6
15 files changed, 104 insertions, 64 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 86d9dbea90..ca71f2cc12 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2209,7 +2209,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
}
// If it is key/value sort, the output of sort is a tuple.
return ReplaceWithNewInstruction(
- sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)}));
+ sort, HloInstruction::CreateTuple(sort->operands()));
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 2047f894b4..42d1f337dc 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2133,16 +2133,20 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- auto values = builder.AddInstruction(
- HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto values0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values0"));
+ auto values1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, values_shape, "values1"));
builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0,
+ keys, {values0, values1}));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
- EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
+ EXPECT_THAT(computation->root_instruction(),
+ op::Tuple(keys, values0, values1));
}
// Used for TEST_Ps that test merging (or not) of a kPad instruction into a
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index cef0eba14e..2411fdcb20 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -284,7 +284,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
HloInstruction::CreateParameter(1, s32_shape, "value"));
HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value));
+ ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value}));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d27786d160..909853106d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2346,7 +2346,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
auto values = builder.AddInstruction(
HloInstruction::CreateParameter(1, values_shape, "values"));
auto sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
+ {values}));
BuildModuleAndRunAnalysis(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 09bcf8a9e7..c317e9e3b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -195,17 +195,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
break;
case HloOpcode::kSort: {
- TF_RET_CHECK(proto.operand_ids_size() == 1 ||
- proto.operand_ids_size() == 2)
- << "Sort instruction should have 1 or 2 operands but has "
+ TF_RET_CHECK(proto.operand_ids_size() >= 1)
+ << "Sort instruction should have at least 1 operand but has "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.dimensions().size() == 1)
<< "Sort instruction should have 1 dimension";
- HloInstruction* keys = operands(0);
- HloInstruction* values =
- proto.operand_ids_size() == 2 ? operands(1) : nullptr;
- instruction =
- CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+ auto sort_operands = all_operands();
+ HloInstruction* keys = sort_operands[0];
+ instruction = CreateSort(
+ proto.shape(), proto.dimensions(0), keys,
+ absl::Span<HloInstruction* const>(sort_operands).subspan(1));
break;
}
case HloOpcode::kTranspose:
@@ -1078,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence(
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
- HloInstruction* values) {
+ absl::Span<HloInstruction* const> values) {
return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 44f776ebac..93ff04b1e4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -670,10 +670,10 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand,
absl::Span<const int64> dimensions);
- // Creates a sort op, with a keys operand, and an optional values operand.
+ // Creates a sort op, with a keys operand, and optional values operands.
static std::unique_ptr<HloInstruction> CreateSort(
const Shape& shape, int64 dimension, HloInstruction* keys,
- HloInstruction* values = nullptr);
+ absl::Span<HloInstruction* const> values = {});
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 2ec233eaec..179ace2cdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -600,11 +600,11 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
HloInstruction* keys,
- HloInstruction* values)
+ absl::Span<HloInstruction* const> values)
: HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
AppendOperand(keys);
- if (values) {
- AppendOperand(values);
+ for (auto* value : values) {
+ AppendOperand(value);
}
}
@@ -633,9 +633,8 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
HloInstruction* keys = new_operands[0];
- HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
- values);
+ new_operands.subspan(1));
}
HloTransposeInstruction::HloTransposeInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 4c5fc759a3..3a0b7490dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -418,7 +418,7 @@ class HloSortInstruction : public HloInstruction {
public:
explicit HloSortInstruction(const Shape& shape, int64 dimension,
HloInstruction* keys,
- HloInstruction* values = nullptr);
+ absl::Span<HloInstruction* const> values = {});
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 96f9ff6654..128113f7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -839,8 +839,6 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
break;
}
case HloOpcode::kSort: {
- auto loc = lexer_.GetLoc();
-
optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
@@ -848,20 +846,10 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
dimensions->size() != 1) {
return false;
}
- switch (operands.size()) {
- case 1:
- instruction = builder->AddInstruction(HloInstruction::CreateSort(
- shape, dimensions->at(0), /*keys=*/operands[0]));
- break;
- case 2:
- instruction = builder->AddInstruction(HloInstruction::CreateSort(
- shape, dimensions->at(0),
- /*keys=*/operands[0], /*values=*/operands[1]));
- break;
- default:
- return Error(loc, StrCat("expects either 1 or 2 operands, but has ",
- operands.size(), " operands"));
- }
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0),
+ /*keys=*/operands[0],
+ /*values=*/absl::Span<HloInstruction* const>(operands).subspan(1)));
break;
}
case HloOpcode::kTuple: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 17538c05bc..ef2e74588c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1005,6 +1005,21 @@ ENTRY Sort {
)"
},
+// Sort (Key, Value, Value, Value)
+{
+"SortManyValues",
+R"(HloModule sort
+
+ENTRY Sort {
+ keys = f32[1024,16]{0,1} parameter(0)
+ values.0 = s32[1024,16]{0,1} parameter(1)
+ values.1 = u32[1024,16]{0,1} parameter(2)
+ values.2 = f32[1024,16]{0,1} parameter(3)
+ ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}
+}
+
+)"
+},
// Conditional
{
"Conditional",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 620458855f..a1f668921d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -266,18 +266,20 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
- if (sort->operand_count() < 1 || sort->operand_count() > 2) {
- return InternalError("Expected 1 or 2 operands for %s instruction: %s",
+ if (sort->operand_count() < 1) {
+ return InternalError("Expected at least 1 operand for %s instruction: %s",
HloOpcodeString(sort->opcode()), sort->ToString());
}
- if (sort->operand_count() == 2 &&
- !ShapeUtil::SameDimensions(sort->operand(0)->shape(),
- sort->operand(1)->shape())) {
- return InternalError(
- "Expected sort to have to have the same dimensions for the keys and "
- "the values. Keys shape is: %s\n, Values shape is: %s",
- StringifyShape(sort->operand(0)->shape()),
- StringifyShape(sort->operand(1)->shape()));
+ for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
+ if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+ sort->operand(operand)->shape())) {
+ return InternalError(
+ "Expected sort to have to have the same dimensions for the keys "
+ "and the values. Keys shape is: %s\n, Values shape (operand index "
+ "%lld) is: %s",
+ StringifyShape(sort->operand(0)->shape()), operand,
+ StringifyShape(sort->operand(operand)->shape()));
+ }
}
return CheckVariadicShape(sort);
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index e379911462..aa49f98bcf 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1029,17 +1029,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
case HloOpcode::kSort: {
if (operand_shapes.size() == 1) {
return *operand_shapes[0];
- } else if (operand_shapes.size() == 2) {
- if (!ShapeUtil::SameDimensions(*operand_shapes[0],
- *operand_shapes[1])) {
- return InvalidArgument(
- "Sort keys and values dimensions must match. "
- "Keys shape is: %s\n, Values shape is: %s",
- ShapeUtil::HumanString(*operand_shapes[0]),
- ShapeUtil::HumanString(*operand_shapes[1]));
+ } else {
+ for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
+ if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+ *operand_shapes[operand])) {
+ return InvalidArgument(
+ "Sort keys and values dimensions must match. "
+ "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
+ ShapeUtil::HumanString(*operand_shapes[0]), operand,
+ ShapeUtil::HumanString(*operand_shapes[operand]));
+ }
+ }
+ std::vector<Shape> operand_shape_values;
+ for (const Shape* operand_shape : operand_shapes) {
+ operand_shape_values.push_back(*operand_shape);
}
- return ShapeUtil::MakeTupleShape(
- {*operand_shapes[0], *operand_shapes[1]});
+ return ShapeUtil::MakeTupleShape(operand_shape_values);
}
return InvalidArgument("Unexpected number of operands for sort");
}
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 864ed43118..7b65e8c1c9 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) {
auto values = ShapeUtil::MakeShape(F32, {5});
StatusOr<Shape> statusor =
ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
- ASSERT_FALSE(statusor.ok());
+ EXPECT_FALSE(statusor.ok());
+ EXPECT_THAT(statusor.status().error_message(),
+ HasSubstr("dimensions must match"))
+ << statusor.status();
+}
+TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_good = ShapeUtil::MakeShape(F32, {4});
+ auto values_bad = ShapeUtil::MakeShape(F32, {5});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_good, &values_bad});
+ EXPECT_FALSE(statusor.ok());
EXPECT_THAT(statusor.status().error_message(),
HasSubstr("dimensions must match"))
<< statusor.status();
}
+TEST_F(ShapeInferenceTest, SortManyValues) {
+ auto keys = ShapeUtil::MakeShape(F32, {4});
+ auto values_s32 = ShapeUtil::MakeShape(S32, {4});
+ auto values_u32 = ShapeUtil::MakeShape(U32, {4});
+ StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+ HloOpcode::kSort, {&keys, &values_s32, &values_u32});
+ EXPECT_IS_OK(statusor);
+ Shape inferred_shape = statusor.ValueOrDie();
+ EXPECT_TRUE(ShapeUtil::Compatible(
+ inferred_shape,
+ ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
+}
+
class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
protected:
const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a571bd571b..d9ebebf74e 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1073,7 +1073,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
auto values = builder.AddInstruction(
HloInstruction::CreateParameter(1, values_shape, "values"));
auto sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
+ {values}));
BuildModuleAndRunAnalysis(builder.Build());
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 5155f0c652..2f18036ff4 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -272,9 +272,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
converted_uses.end());
} else if (opcode == HloOpcode::kSort &&
- instruction->operand_count() == 2 && op_num == 0) {
+ instruction->operand_count() >= 2 && op_num == 0) {
// Operand 0 of sort is the array of keys used for key/value
- // (two-operand) kSort instructions.
+ // (two-operand) kSort instructions. Since sort stability is not
+ // guaranteed, constrain keys of key-value sort not to have duplicates,
+ // since otherwise the value order may legitimately differ.
constrained_uses.push_back(instruction);
}
}