aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2017-12-20 14:13:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-20 14:17:02 -0800
commit1279bb10b9bd76f15637074c6518a3464916e007 (patch)
tree8baf3b1a58b5294f95eece3be61a7f59b4e95fb3 /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
parentbd9f9d71df99fcdaf47326b0c81d79d2b2681fee (diff)
[XLA] Add reassociation rule for adds, and add canonicalization rules for add/sub of a constant.
This patch adds a new algebraic simplification: * (A + C0) + C1 => A + (C0 + C1), where C0 and C1 are constants. This allows us to constant-fold C0 + C1. In service of this rule, this patch also adds two new canonicalizations: * Const + A => A + Const * A - Const => A + (-1 * Const) PiperOrigin-RevId: 179731747
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc71
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 175d4d8d7f..48e822e3af 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -71,6 +71,55 @@ TEST_F(AlgebraicSimplifierTest, AddZero) {
EXPECT_EQ(root, param0);
}
+// Test that Const + A is canonicalized to A + Const.
+TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Add(param0, op::Constant()));
+}
+
+// Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2.
+TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(42.0f)));
+ HloInstruction* constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0(3.14159f)));
+
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2)));
+}
+
TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) {
Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2});
HloComputation::Builder builder(TestName());
@@ -139,6 +188,28 @@ TEST_F(AlgebraicSimplifierTest, SubZero) {
EXPECT_EQ(root, param0);
}
+// Test that A - Const is canonicalized to A + (-Const).
+TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) {
+ Shape r0f32 = ShapeUtil::MakeShape(F32, {});
+ HloComputation::Builder builder(TestName());
+ HloInstruction* param0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, r0f32, "param0"));
+ HloInstruction* constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ r0f32, HloOpcode::kSubtract, param0, constant));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kSubtract);
+ AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
+ non_bitcasting_callback());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ root = computation->root_instruction();
+ EXPECT_THAT(root, op::Add(param0, op::Negate(constant)));
+}
+
// Test that (A/B)/C is simplified to A/(B*C).
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) {
Shape r0f32 = ShapeUtil::MakeShape(F32, {});