author | Michael Kuperstein <mkuper@google.com> | 2018-10-09 19:41:35 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 19:46:05 -0700
commit | 58fcfc98cd59ae3952399fc55380b8733df08df9 (patch)
tree | 24d5ac5d6691e73c227f5afa5ef68ba2ecba4ec0
parent | 93eef55c4d04af24a6c8080f34629db179634f07 (diff)
[XLA] Add documentation and HLO-level support for multi-value sort.
There is no support in any of the backends yet, and it is not yet exposed through XlaBuilder.
PiperOrigin-RevId: 216465753
15 files changed, 104 insertions, 64 deletions
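For orientation before the per-file diffs: a minimal sketch, not part of this change, of how the variadic `HloInstruction::CreateSort` overload introduced here is meant to be called when building an HLO computation. The function name, shapes, and parameter names below are illustrative assumptions, not code from this commit.

```cpp
// Sketch only: builds a sort with one keys operand and two values operands
// using the CreateSort overload added in this change. Names are illustrative.
#include <memory>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

std::unique_ptr<HloComputation> BuildMultiValueSort() {
  HloComputation::Builder builder("multi_value_sort");
  Shape keys_shape = ShapeUtil::MakeShape(F32, {1024});
  Shape values_shape = ShapeUtil::MakeShape(S32, {1024});

  // One keys operand plus any number of values operands.
  HloInstruction* keys = builder.AddInstruction(
      HloInstruction::CreateParameter(0, keys_shape, "keys"));
  HloInstruction* values0 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, values_shape, "values0"));
  HloInstruction* values1 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, values_shape, "values1"));

  // The result shape is a tuple of the keys shape followed by every values
  // shape; each values operand must have the same dimensions as the keys.
  builder.AddInstruction(HloInstruction::CreateSort(
      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}),
      /*dimension=*/0, keys, {values0, values1}));
  return builder.Build();
}

}  // namespace xla
```

A sort with N operands yields a tuple of N arrays, all reordered along the sort dimension according to the keys. As the commit message notes, this exists only at the HLO level for now: there is no backend lowering and no XlaBuilder entry point yet.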
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 86d9dbea90..ca71f2cc12 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -2209,7 +2209,7 @@ Status AlgebraicSimplifierVisitor::HandleSort(HloInstruction* sort) {
     }
     // If it is key/value sort, the output of sort is a tuple.
     return ReplaceWithNewInstruction(
-        sort, HloInstruction::CreateTuple({operand, sort->mutable_operand(1)}));
+        sort, HloInstruction::CreateTuple(sort->operands()));
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 2047f894b4..42d1f337dc 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2133,16 +2133,20 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
   Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
   auto keys = builder.AddInstruction(
       HloInstruction::CreateParameter(0, keys_shape, "keys"));
-  auto values = builder.AddInstruction(
-      HloInstruction::CreateParameter(1, values_shape, "values"));
+  auto values0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, values_shape, "values0"));
+  auto values1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, values_shape, "values1"));
   builder.AddInstruction(HloInstruction::CreateSort(
-      ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+      ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0,
+      keys, {values0, values1}));
   auto module = CreateNewModule();
   HloComputation* computation = module->AddEntryComputation(builder.Build());
   AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
                                  non_bitcasting_callback());
   ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
-  EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
+  EXPECT_THAT(computation->root_instruction(),
+              op::Tuple(keys, values0, values1));
 }
 
 // Used for TEST_Ps that test merging (or not) of a kPad instruction into a
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index cef0eba14e..2411fdcb20 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -284,7 +284,7 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleSort) {
       HloInstruction::CreateParameter(1, s32_shape, "value"));
 
   HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
-      ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, value));
+      ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value}));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
 
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index d27786d160..909853106d 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -2346,7 +2346,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
   auto values = builder.AddInstruction(
       HloInstruction::CreateParameter(1, values_shape, "values"));
   auto sort = builder.AddInstruction(HloInstruction::CreateSort(
-      ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+      ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
+      {values}));
 
   BuildModuleAndRunAnalysis(builder.Build());
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 09bcf8a9e7..c317e9e3b4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -195,17 +195,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
       }
       break;
     case HloOpcode::kSort: {
-      TF_RET_CHECK(proto.operand_ids_size() == 1 ||
-                   proto.operand_ids_size() == 2)
-          << "Sort instruction should have 1 or 2 operands but has "
+      TF_RET_CHECK(proto.operand_ids_size() >= 1)
+          << "Sort instruction should have at least 1 operand but has "
           << proto.operand_ids_size();
       TF_RET_CHECK(proto.dimensions().size() == 1)
           << "Sort instruction should have 1 dimension";
-      HloInstruction* keys = operands(0);
-      HloInstruction* values =
-          proto.operand_ids_size() == 2 ? operands(1) : nullptr;
-      instruction =
-          CreateSort(proto.shape(), proto.dimensions(0), keys, values);
+      auto sort_operands = all_operands();
+      HloInstruction* keys = sort_operands[0];
+      instruction = CreateSort(
+          proto.shape(), proto.dimensions(0), keys,
+          absl::Span<HloInstruction* const>(sort_operands).subspan(1));
       break;
     }
     case HloOpcode::kTranspose:
@@ -1078,7 +1077,7 @@ HloInstruction::CreateBroadcastSequence(
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
     const Shape& shape, int64 dimension, HloInstruction* keys,
-    HloInstruction* values) {
+    absl::Span<HloInstruction* const> values) {
   return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
 }
 
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 44f776ebac..93ff04b1e4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -670,10 +670,10 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand,
       absl::Span<const int64> dimensions);
 
-  // Creates a sort op, with a keys operand, and an optional values operand.
+  // Creates a sort op, with a keys operand, and optional values operands.
   static std::unique_ptr<HloInstruction> CreateSort(
      const Shape& shape, int64 dimension, HloInstruction* keys,
-      HloInstruction* values = nullptr);
+      absl::Span<HloInstruction* const> values = {});
 
   // Creates a while instruction, given a condition computation, a body
   // computation, and the initial value for the input of the computations. For
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 2ec233eaec..179ace2cdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -600,11 +600,11 @@ std::unique_ptr<HloInstruction> HloReduceInstruction::CloneWithNewOperandsImpl(
 
 HloSortInstruction::HloSortInstruction(const Shape& shape, int64 dimension,
                                        HloInstruction* keys,
-                                       HloInstruction* values)
+                                       absl::Span<HloInstruction* const> values)
     : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
   AppendOperand(keys);
-  if (values) {
-    AppendOperand(values);
+  for (auto* value : values) {
+    AppendOperand(value);
   }
 }
 
@@ -633,9 +633,8 @@ std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
     const Shape& shape, absl::Span<HloInstruction* const> new_operands,
     HloCloneContext* context) const {
   HloInstruction* keys = new_operands[0];
-  HloInstruction* values = new_operands.size() == 2 ? new_operands[1] : nullptr;
   return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
-                                               values);
+                                               new_operands.subspan(1));
 }
 
 HloTransposeInstruction::HloTransposeInstruction(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 4c5fc759a3..3a0b7490dc 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -418,7 +418,7 @@ class HloSortInstruction : public HloInstruction {
  public:
   explicit HloSortInstruction(const Shape& shape, int64 dimension,
                               HloInstruction* keys,
-                              HloInstruction* values = nullptr);
+                              absl::Span<HloInstruction* const> values = {});
   // Returns the dimension sizes or numbers associated with this instruction.
   const std::vector<int64>& dimensions() const override { return dimensions_; }
   int64 dimensions(int64 index) const override { return dimensions()[index]; }
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 96f9ff6654..128113f7a5 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -839,8 +839,6 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kSort: {
-      auto loc = lexer_.GetLoc();
-
       optional<std::vector<tensorflow::int64>> dimensions;
       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
                              &dimensions};
@@ -848,20 +846,10 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
           dimensions->size() != 1) {
         return false;
       }
-      switch (operands.size()) {
-        case 1:
-          instruction = builder->AddInstruction(HloInstruction::CreateSort(
-              shape, dimensions->at(0), /*keys=*/operands[0]));
-          break;
-        case 2:
-          instruction = builder->AddInstruction(HloInstruction::CreateSort(
-              shape, dimensions->at(0),
-              /*keys=*/operands[0], /*values=*/operands[1]));
-          break;
-        default:
-          return Error(loc, StrCat("expects either 1 or 2 operands, but has ",
-                                   operands.size(), " operands"));
-      }
+      instruction = builder->AddInstruction(HloInstruction::CreateSort(
+          shape, dimensions->at(0),
+          /*keys=*/operands[0],
+          /*values=*/absl::Span<HloInstruction* const>(operands).subspan(1)));
       break;
     }
     case HloOpcode::kTuple: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 17538c05bc..ef2e74588c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1005,6 +1005,21 @@ ENTRY Sort {
 
 )"
 },
+// Sort (Key, Value, Value, Value)
+{
+"SortManyValues",
+R"(HloModule sort
+
+ENTRY Sort {
+  keys = f32[1024,16]{0,1} parameter(0)
+  values.0 = s32[1024,16]{0,1} parameter(1)
+  values.1 = u32[1024,16]{0,1} parameter(2)
+  values.2 = f32[1024,16]{0,1} parameter(3)
+  ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}
+}
+
+)"
+},
 // Conditional
 {
 "Conditional",
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 620458855f..a1f668921d 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -266,18 +266,20 @@ Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
 }
 
 Status ShapeVerifier::HandleSort(HloInstruction* sort) {
-  if (sort->operand_count() < 1 || sort->operand_count() > 2) {
-    return InternalError("Expected 1 or 2 operands for %s instruction: %s",
+  if (sort->operand_count() < 1) {
+    return InternalError("Expected at least 1 operand for %s instruction: %s",
                          HloOpcodeString(sort->opcode()), sort->ToString());
   }
-  if (sort->operand_count() == 2 &&
-      !ShapeUtil::SameDimensions(sort->operand(0)->shape(),
-                                 sort->operand(1)->shape())) {
-    return InternalError(
-        "Expected sort to have to have the same dimensions for the keys and "
-        "the values. Keys shape is: %s\n, Values shape is: %s",
-        StringifyShape(sort->operand(0)->shape()),
-        StringifyShape(sort->operand(1)->shape()));
+  for (int64 operand = 1; operand < sort->operand_count(); ++operand) {
+    if (!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+                                   sort->operand(operand)->shape())) {
+      return InternalError(
+          "Expected sort to have to have the same dimensions for the keys "
+          "and the values. Keys shape is: %s\n, Values shape (operand index "
+          "%lld) is: %s",
+          StringifyShape(sort->operand(0)->shape()), operand,
+          StringifyShape(sort->operand(operand)->shape()));
+    }
   }
   return CheckVariadicShape(sort);
 }
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index e379911462..aa49f98bcf 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -1029,17 +1029,22 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
     case HloOpcode::kSort: {
       if (operand_shapes.size() == 1) {
         return *operand_shapes[0];
-      } else if (operand_shapes.size() == 2) {
-        if (!ShapeUtil::SameDimensions(*operand_shapes[0],
-                                       *operand_shapes[1])) {
-          return InvalidArgument(
-              "Sort keys and values dimensions must match. "
-              "Keys shape is: %s\n, Values shape is: %s",
-              ShapeUtil::HumanString(*operand_shapes[0]),
-              ShapeUtil::HumanString(*operand_shapes[1]));
+      } else {
+        for (int64 operand = 1; operand < operand_shapes.size(); ++operand) {
+          if (!ShapeUtil::SameDimensions(*operand_shapes[0],
+                                         *operand_shapes[operand])) {
+            return InvalidArgument(
+                "Sort keys and values dimensions must match. "
+                "Keys shape is: %s\n, Values shape (operand index %lld) is: %s",
+                ShapeUtil::HumanString(*operand_shapes[0]), operand,
+                ShapeUtil::HumanString(*operand_shapes[operand]));
+          }
+        }
+        std::vector<Shape> operand_shape_values;
+        for (const Shape* operand_shape : operand_shapes) {
+          operand_shape_values.push_back(*operand_shape);
        }
-        return ShapeUtil::MakeTupleShape(
-            {*operand_shapes[0], *operand_shapes[1]});
+        return ShapeUtil::MakeTupleShape(operand_shape_values);
       }
       return InvalidArgument("Unexpected number of operands for sort");
     }
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 864ed43118..7b65e8c1c9 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -1618,13 +1618,37 @@ TEST_F(ShapeInferenceTest, BadSort) {
   auto values = ShapeUtil::MakeShape(F32, {5});
   StatusOr<Shape> statusor =
       ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
-  ASSERT_FALSE(statusor.ok());
+  EXPECT_FALSE(statusor.ok());
+  EXPECT_THAT(statusor.status().error_message(),
+              HasSubstr("dimensions must match"))
+      << statusor.status();
+}
+
+TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
+  auto keys = ShapeUtil::MakeShape(F32, {4});
+  auto values_good = ShapeUtil::MakeShape(F32, {4});
+  auto values_bad = ShapeUtil::MakeShape(F32, {5});
+  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+      HloOpcode::kSort, {&keys, &values_good, &values_bad});
+  EXPECT_FALSE(statusor.ok());
   EXPECT_THAT(statusor.status().error_message(),
               HasSubstr("dimensions must match"))
       << statusor.status();
 }
 
+TEST_F(ShapeInferenceTest, SortManyValues) {
+  auto keys = ShapeUtil::MakeShape(F32, {4});
+  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
+  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
+  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
+      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
+  EXPECT_IS_OK(statusor);
+  Shape inferred_shape = statusor.ValueOrDie();
+  EXPECT_TRUE(ShapeUtil::Compatible(
+      inferred_shape,
+      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
+}
+
 class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
  protected:
   const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index a571bd571b..d9ebebf74e 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -1073,7 +1073,8 @@ TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
   auto values = builder.AddInstruction(
       HloInstruction::CreateParameter(1, values_shape, "values"));
   auto sort = builder.AddInstruction(HloInstruction::CreateSort(
-      ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+      ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
+      {values}));
 
   BuildModuleAndRunAnalysis(builder.Build());
 
diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc
index 5155f0c652..2f18036ff4 100644
--- a/tensorflow/compiler/xla/tests/test_utils.cc
+++ b/tensorflow/compiler/xla/tests/test_utils.cc
@@ -272,9 +272,11 @@ std::vector<HloInstruction*> FindConstrainedUses(
       constrained_uses.insert(constrained_uses.end(), converted_uses.begin(),
                               converted_uses.end());
     } else if (opcode == HloOpcode::kSort &&
-               instruction->operand_count() == 2 && op_num == 0) {
+               instruction->operand_count() >= 2 && op_num == 0) {
       // Operand 0 of sort is the array of keys used for key/value
-      // (two-operand) kSort instructions.
+      // (two-operand) kSort instructions. Since sort stability is not
+      // guaranteed, constrain keys of key-value sort not to have duplicates,
+      // since otherwise the value order may legitimately differ.
       constrained_uses.push_back(instruction);
     }
   }
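As a usage-level recap of the shape-inference rule exercised by the tests above (again a sketch, not code from this commit; the helper name is hypothetical): with a keys operand plus N-1 values operands, the inferred kSort shape is the tuple of all N operand shapes, and every values operand must have the same dimensions as the keys.

```cpp
// Sketch only: checks the multi-value sort inference rule described above.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

void CheckMultiValueSortInference() {  // hypothetical helper, illustration only
  Shape keys = ShapeUtil::MakeShape(F32, {4});
  Shape values_s32 = ShapeUtil::MakeShape(S32, {4});
  Shape values_u32 = ShapeUtil::MakeShape(U32, {4});

  // sort(keys, values_s32, values_u32) infers to the tuple (f32[4], s32[4], u32[4]).
  StatusOr<Shape> inferred = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  CHECK(inferred.ok());
  CHECK(ShapeUtil::Compatible(
      inferred.ValueOrDie(),
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));

  // A values operand whose dimensions differ from the keys is rejected.
  Shape values_bad = ShapeUtil::MakeShape(F32, {5});
  CHECK(!ShapeInference::InferVariadicOpShape(HloOpcode::kSort,
                                              {&keys, &values_bad})
             .ok());
}

}  // namespace xla
```

Because the inferred shape is simply the tuple of all operand shapes, passes such as the algebraic simplifier can now replace an effectively-scalar sort with `CreateTuple(sort->operands())`, as the first hunk above does.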