aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 2047f894b4..42d1f337dc 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2133,16 +2133,20 @@ TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- auto values = builder.AddInstruction(
- HloInstruction::CreateParameter(1, values_shape, "values"));
+ auto values0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, values_shape, "values0"));
+ auto values1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, values_shape, "values1"));
builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys, values));
+ ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0,
+ keys, {values0, values1}));
auto module = CreateNewModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
- EXPECT_THAT(computation->root_instruction(), op::Tuple(keys, values));
+ EXPECT_THAT(computation->root_instruction(),
+ op::Tuple(keys, values0, values1));
}
// Used for TEST_Ps that test merging (or not) of a kPad instruction into a