aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc102
1 files changed, 102 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 92bbcbd740..ddf0a513c0 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -74,6 +74,26 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
+// Test that A * 0 is simplified to 0
+TEST_F(AlgebraicSimplifierTest, MulZero) {
+ Shape r0s32 = ShapeUtil::MakeShape(S32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0s32, "param0"));
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0s32, HloOpcode::kMultiply, param0, zero));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kMultiply);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+ EXPECT_EQ(computation->root_instruction(), zero);
+}
+
// Test that Reduce(Reduce(A)) -> Reduce(A)
TEST_F(AlgebraicSimplifierTest, TwoReducesToOne) {
HloComputation::Builder builder(TestName());
@@ -1230,6 +1250,55 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) {
op::Concatenate(param0, param0, param1));
}
+// Test that reduce of concat is simplified.
+TEST_F(AlgebraicSimplifierTest, SimplifyReduceOfConcat) {
+ const int kParamLength = 100;
+ Shape r3f32 =
+ ShapeUtil::MakeShape(F32, {kParamLength, kParamLength, kParamLength});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r3f32, "param0"));
+ HloInstruction* param1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, r3f32, "param1"));
+ HloInstruction* param2 = builder.AddInstruction(
+ HloInstruction::CreateParameter(2, r3f32, "param2"));
+ Shape concat_shape =
+ ShapeUtil::MakeShape(F32, {kParamLength, 3 * kParamLength, kParamLength});
+ HloInstruction* Concatenate =
+ builder.AddInstruction(HloInstruction::CreateConcatenate(
+ concat_shape, {param0, param1, param2}, 1));
+ HloComputation* add_computation = nullptr;
+ {
+ HloComputation::Builder builder(TestName() + ".add");
+ const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ HloInstruction* p0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "p0"));
+ HloInstruction* p1 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "p1"));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
+ add_computation = module().AddEmbeddedComputation(builder.Build());
+ }
+ Shape r4f32 = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
+ Shape reduce_shape = ShapeUtil::MakeShape(F32, {kParamLength});
+
+ HloInstruction* zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
+ builder.AddInstruction(HloInstruction::CreateReduce(
+ reduce_shape, Concatenate, zero, {1, 2}, add_computation));
+
+ auto computation = module().AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie());
+
+ EXPECT_THAT(
+ computation->root_instruction(),
+ op::Map(op::Map(op::Reduce(param0, zero), op::Reduce(param1, zero)),
+ op::Reduce(param2, zero)));
+}
+
// Test a concatenate with only empty operands is removed.
TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) {
const int kParamLength = 100;
@@ -1839,6 +1908,39 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
EXPECT_THAT(computation->root_instruction(), param);
}
+TEST_F(AlgebraicSimplifierTest, SliceOfSliceToSlice) {
+ HloComputation::Builder builder(TestName());
+ const int64 dim0 = 11;
+ const int64 dim1 = 12;
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {dim0, dim1}), "param"));
+ HloInstruction* original_slice =
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0 - 2, dim1 - 4}), param,
+ /*start_indices=*/{1, 2},
+ /*limit_indices=*/{dim0 - 1, dim1 - 2}, /*strides=*/{1, 1}));
+
+ builder.AddInstruction(HloInstruction::CreateSlice(
+ ShapeUtil::MakeShape(F32, {dim0 - 5, dim1 - 9}), original_slice,
+ /*start_indices=*/{2, 3},
+ /*limit_indices=*/{dim0 - 3, dim1 - 6}, /*strides=*/{1, 1}));
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_THAT(computation->root_instruction(), op::Slice(op::Slice(param)));
+
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
+
+ EXPECT_THAT(computation->root_instruction(), op::Slice(param));
+ EXPECT_EQ(computation->root_instruction()->slice_starts(0), 3);
+ EXPECT_EQ(computation->root_instruction()->slice_starts(1), 5);
+ EXPECT_EQ(computation->root_instruction()->slice_limits(0), dim0 - 2);
+ EXPECT_EQ(computation->root_instruction()->slice_limits(1), dim1 - 4);
+}
+
TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
struct ConvTestOptions {
int in_batch = 10;