aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-10-09 19:41:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 19:46:05 -0700
commit58fcfc98cd59ae3952399fc55380b8733df08df9 (patch)
tree24d5ac5d6691e73c227f5afa5ef68ba2ecba4ec0 /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
parent93eef55c4d04af24a6c8080f34629db179634f07 (diff)
[XLA] Add documentation and HLO-level support for multi-value sort.
No support in any of the backends, and not yet exposed through XlaBuilder. PiperOrigin-RevId: 216465753
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 2047f894b4..42d1f337dc 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2133,16 +2133,20 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- auto values = builder.AddInstruction(
- HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto values0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values0"));
+ auto values1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, values_shape, "values1"));
builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0,
+ keys, {values0, values1}));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
- EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
+ EXPECT_THAT(computation->root_instruction(),
+ op::Tuple(keys, values0, values1));
}
// Used for TEST_Ps that test merging (or not) of a kPad instruction into a