diff options
author | Dimitris Vardoulakis <dimvar@google.com> | 2018-04-28 22:19:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-28 22:21:47 -0700 |
commit | d02745e20c02ba7506a920cc4c8b00415f82ee79 (patch) | |
tree | 1caeccbacfa521d13cf790537bfb24c2e7f7f081 /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | |
parent | 3a9c513c3f4303e5194474d804367c1f4831e3ee (diff) |
[TF:XLA]
- Require a module config when creating an HloModule.
- All tests using HloTestBase create a module using CreateNewModule.
PiperOrigin-RevId: 194684585
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 101 |
1 files changed, 51 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 20c549562d..d0c99bf818 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" @@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); @@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero)); EXPECT_TRUE(has_negative_padding(pad)); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero))); EXPECT_FALSE( @@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param)); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) { ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0}, /*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1})); - HloModule module(TestName()); - HloComputation* computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Slice(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param); } @@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter, window, dnums)); - HloModule module(TestName()); - auto* computation = module.AddEntryComputation(b.Build()); + auto module = CreateNewModule(); + auto* computation = module->AddEntryComputation(b.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - if (!simplifier.Run(&module).ValueOrDie()) { + if (!simplifier.Run(module.get()).ValueOrDie()) { return "NO_CHANGE"; } auto* root = computation->root_instruction(); @@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Minimum(param0, min_value), max_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Clamp(max_value, param0, min_value)); @@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Maximum(param0, max_value), min_value)); @@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMinimum, fmax, min_value)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(&module).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Minimum(op::Add(op::Maximum(param0, max_value), max_value), @@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice( slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, slice); @@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(scalar_param)); @@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloInstruction* reshape = builder.AddInstruction( HloInstruction::CreateReshape(reshape_shape, transpose)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reshape); @@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast(forty_two)); @@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { // Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x). TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { // Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to // ReduceWindow(Convert(op), x). TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { - HloModule module(TestName()); + auto module = CreateNewModule(); HloComputation::Builder builder(TestName()); // Create operand to the pad. @@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { HloInstruction::CreateParameter(1, scalar_shape, "p1")); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); - add_computation = module.AddEmbeddedComputation(builder.Build()); + add_computation = module->AddEmbeddedComputation(builder.Build()); } // Create the reduce-window. @@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) { add_computation)); // Build the computation and run the simplifier. - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, reduce_window); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); // Running simplification again should not result in any further changes. - ASSERT_FALSE(simplifier.Run(&module).ValueOrDie()); + ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie()); // Verify the result root = computation->root_instruction(); @@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) { builder.AddInstruction( HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(&module).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(a, root); |