diff options
author | Justin Lebar <jlebar@google.com> | 2017-12-20 14:13:37 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-20 14:17:02 -0800 |
commit | 1279bb10b9bd76f15637074c6518a3464916e007 (patch) | |
tree | 8baf3b1a58b5294f95eece3be61a7f59b4e95fb3 /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | |
parent | bd9f9d71df99fcdaf47326b0c81d79d2b2681fee (diff) |
[XLA] Add reassociation rule for adds, and add canonicalization rules for add/sub of a constant.
This patch adds a new algebraic simplification:
* (A + C0) + C1 => A + (C0 + C1), where C0 and C1 are constants. This
allows us to constant-fold C0 + C1.
In service of this rule, this patch also adds two new canonicalizations:
* Const + A => A + Const
* A - Const => A + (-1 * Const)
PiperOrigin-RevId: 179731747
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 71 |
1 files changed, 71 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 175d4d8d7f..48e822e3af 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -71,6 +71,55 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { EXPECT_EQ(root, param0); } +// Test that Const + A is canonicalized to A + Const. +TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Constant())); +} + +// Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. +TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -139,6 +188,28 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { EXPECT_EQ(root, param0); } +// Test that A - Const is canonicalized to A + (-Const). +TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kSubtract, param0, constant)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); +} + // Test that (A/B)/C is simplified to A/(B*C). TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); |