aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_cse_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_cse_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse_test.cc428
1 files changed, 428 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc
new file mode 100644
index 0000000000..ec8161f55f
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc
@@ -0,0 +1,428 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_cse.h"
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace xla {
+namespace {
+
+class HloCseTest : public HloTestBase {
+ protected:
+ HloCseTest() {}
+};
+
+TEST_F(HloCseTest, CombineTwoConstants) {
+ // Test that two identical constants are commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(2, computation->instruction_count());
+ HloInstruction* constant = computation->instructions().begin()->get();
+ EXPECT_EQ(42.0f, LiteralUtil::Get<float>(constant->literal(), {}));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR0<float>(84.0);
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) {
+ // Test that two identical constants with different layouts are commoned if
+ // the pass is not layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{0, 1})));
+ auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{1, 0})));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_NE(add->operand(0), add->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(2, computation->instruction_count());
+ EXPECT_EQ(add->operand(0), add->operand(1));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) {
+ // Test that two identical constants with different layouts are *not* commoned
+ // if the pass is layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{0, 1})));
+ auto constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ test_utils::CreateR2LiteralWithLayout<float>({{1.0, 2.0}, {3.0, 4.0}},
+ /*minor_to_major=*/{1, 0})));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, constant1, constant2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(constant1, add->operand(0));
+ EXPECT_EQ(constant2, add->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/true);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(constant1, add->operand(0));
+ EXPECT_EQ(constant2, add->operand(1));
+
+ auto result = ExecuteAndTransfer(std::move(module), {});
+ auto expected = LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}});
+ LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4));
+}
+
+TEST_F(HloCseTest, ConstantsSameValueDifferentType) {
+ // Test that constants with the same value but different type are *not*
+ // commoned.
+ auto builder = HloComputation::Builder(TestName());
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint32>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(42)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<uint64>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<int64>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<double>(42.0)));
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+ // Duplicate the float constant to verify something happens.
+ builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0f)));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(7, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(6, computation->instruction_count());
+}
+
+TEST_F(HloCseTest, NonscalarConstants) {
+ // Test that identical nonscalar constants are merged.
+ auto builder = HloComputation::Builder(TestName());
+ auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+ // Create a constant which has the same shape but a different value.
+ auto uncommon_constant =
+ builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{2.0, 4.0}, {6.0, 8.0}})));
+
+ // Tie the constants together with a tuple. This makes it easier to refer to
+ // the constant instructions via their use.
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple(
+ {common_constant1, common_constant2, uncommon_constant}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(uncommon_constant, tuple->operand(2));
+ EXPECT_TRUE(tuple->operand(0) == common_constant1 ||
+ tuple->operand(0) == common_constant2);
+}
+
+TEST_F(HloCseTest, IdenticalInstructions) {
+ // Test that three identical instructions are commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto exp3 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2, exp3}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(5, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+ EXPECT_NE(tuple->operand(1), tuple->operand(2));
+ EXPECT_NE(tuple->operand(0), tuple->operand(2));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(tuple->operand(1), tuple->operand(2));
+}
+
+TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) {
+ // Test that two identical instructions with different layouts are *not*
+ // commoned if the pass is layout sensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/true);
+ EXPECT_FALSE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+}
+
+TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) {
+ // Test that two identical instructions with different layouts are commoned if
+ // the pass is layout insensitive.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
+ LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
+
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp1->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1});
+
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ *exp2->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0});
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({exp1, exp2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+}
+
+TEST_F(HloCseTest, IdenticalExpressions) {
+ // Test that two identical expressions are commoned. Build the following
+ // computation:
+ //
+ // constant = 42.0
+ // negate1 = neg(constant)
+ // exp1 = exp(constant)
+ // add1 = add(negate1, exp1)
+ // negate2 = neg(constant)
+ // exp2 = exp(constant)
+ // add2 = add(negate2, exp2)
+ // tuple = tuple(add1, add2)
+ //
+ // The *1 instructions should be merged with the *2 instructions.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
+
+ auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, negate1, exp1));
+
+ auto negate2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kNegate, constant));
+ auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary(
+ constant->shape(), HloOpcode::kExp, constant));
+ auto add2 = builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, negate2, exp2));
+
+ auto tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add1, add2}));
+
+ HloModule module(TestName());
+ auto computation = module.AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(8, computation->instruction_count());
+ EXPECT_NE(tuple->operand(0), tuple->operand(1));
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(&module).ValueOrDie());
+
+ EXPECT_EQ(5, computation->instruction_count());
+ EXPECT_EQ(tuple->operand(0), tuple->operand(1));
+ EXPECT_EQ(HloOpcode::kAdd, tuple->operand(0)->opcode());
+}
+
+TEST_F(HloCseTest, DoNotCombineRng) {
+ // Test that two RNG ops are not commoned.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ auto rng1 = builder.AddInstruction(HloInstruction::CreateRng(
+ ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
+ {constant1, constant2}));
+ auto rng2 = builder.AddInstruction(HloInstruction::CreateRng(
+ ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM,
+ {constant1, constant2}));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant1->shape(), HloOpcode::kAdd, rng1, rng2));
+
+ auto module = MakeUnique<HloModule>(TestName());
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ uint32 count_before = computation->instruction_count();
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_FALSE(cse.Run(module.get()).ValueOrDie());
+
+ uint32 count_after = computation->instruction_count();
+ EXPECT_EQ(count_before, count_after);
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kRng);
+ EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kRng);
+ EXPECT_NE(root->operand(0), root->operand(1));
+}
+
+// TODO(b/28245743): Handle impure functions correctly in CSE.
+TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) {
+ // Test that two calls to an impure function are not commoned. RNG
+ // is the source of the impurity.
+
+ auto module = MakeUnique<HloModule>(TestName());
+
+ // rng_function is an impure function because it does RNG.
+ HloComputation* rng_function = nullptr;
+ {
+ Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
+ auto builder = HloComputation::Builder(TestName() + "_rng_fun");
+ auto constant1 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0.0f)));
+ auto constant2 = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
+ auto rng = builder.AddInstruction(HloInstruction::CreateRng(
+ scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2}));
+ auto param = builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {}), "param"));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, rng, param));
+ rng_function = module->AddEmbeddedComputation(builder.Build());
+ }
+
+ // Computation calls rng_function twice with the same parameter.
+ HloComputation* computation = nullptr;
+ {
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({5.0f})));
+ auto rng1 = builder.AddInstruction(
+ HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
+ auto rng2 = builder.AddInstruction(
+ HloInstruction::CreateMap(constant->shape(), {constant}, rng_function));
+ builder.AddInstruction(HloInstruction::CreateBinary(
+ constant->shape(), HloOpcode::kAdd, rng1, rng2));
+ computation = module->AddEntryComputation(builder.Build());
+ }
+
+ EXPECT_EQ(4, computation->instruction_count());
+
+ HloCSE cse(/*is_layout_sensitive=*/false);
+ EXPECT_TRUE(cse.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(4, computation->instruction_count());
+ HloInstruction* root = computation->root_instruction();
+ EXPECT_EQ(root->opcode(), HloOpcode::kAdd);
+ EXPECT_EQ(root->operand(0)->opcode(), HloOpcode::kMap);
+ EXPECT_EQ(root->operand(1)->opcode(), HloOpcode::kMap);
+ EXPECT_NE(root->operand(0), root->operand(1));
+}
+
+} // namespace
+} // namespace xla