aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-04-28 22:19:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-28 22:21:47 -0700
commitd02745e20c02ba7506a920cc4c8b00415f82ee79 (patch)
tree1caeccbacfa521d13cf790537bfb24c2e7f7f081 /tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
parent3a9c513c3f4303e5194474d804367c1f4831e3ee (diff)
[TF:XLA]
- Require a module config when creating an HloModule. - All tests using HloTestBase create a module using CreateNewModule. PiperOrigin-RevId: 194684585
Diffstat (limited to 'tensorflow/compiler/xla/service/algebraic_simplifier_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/algebraic_simplifier_test.cc101
1 files changed, 51 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 20c549562d..d0c99bf818 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/window_util.h"
@@ -1699,14 +1700,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {2, 2}), param, zero, no_padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1732,8 +1733,8 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad(
ShapeUtil::MakeShape(F32, {11, 5}), param, zero, padding));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
@@ -1751,7 +1752,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1766,14 +1767,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
builder.AddInstruction(
HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {2, 3}), param));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Reshape(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1789,14 +1790,14 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
ShapeUtil::MakeShape(F32, {dim0, dim1}), param, /*start_indices=*/{0, 0},
/*limit_indices=*/{dim0, dim1}, /*strides=*/{1, 1}));
- HloModule module(TestName());
- HloComputation* computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ HloComputation* computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(), op::Slice(param));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1924,12 +1925,12 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- HloModule module(TestName());
- auto* computation = module.AddEntryComputation(b.Build());
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
bitcasting_callback());
- if (!simplifier.Run(&module).ValueOrDie()) {
+ if (!simplifier.Run(module.get()).ValueOrDie()) {
return "NO_CHANGE";
}
auto* root = computation->root_instruction();
@@ -2044,15 +2045,15 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMaximum, min, max_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Maximum(op::Minimum(param0, min_value), max_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2074,15 +2075,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2105,15 +2106,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r1f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2135,15 +2136,15 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32, HloOpcode::kMinimum, max, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2167,8 +2168,8 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
builder.AddInstruction(HloInstruction::CreateBinary(
r0f32, HloOpcode::kMinimum, fmax, min_value));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2176,7 +2177,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(&module).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2201,8 +2202,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
slice_shape, broadcast, {0, 1, 2, 3}, {2, 3, 5, 6}, {1, 1, 1, 1}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, slice);
@@ -2211,10 +2212,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2242,8 +2243,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction* reshape = builder.AddInstruction(
HloInstruction::CreateReshape(reshape_shape, transpose));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reshape);
@@ -2251,7 +2252,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2260,7 +2261,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2289,7 +2290,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
@@ -2312,15 +2313,15 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
@@ -2341,7 +2342,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- HloModule module(TestName());
+ auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2374,7 +2375,7 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
HloInstruction::CreateParameter(1, scalar_shape, "p1"));
builder.AddInstruction(
HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1));
- add_computation = module.AddEmbeddedComputation(builder.Build());
+ add_computation = module->AddEmbeddedComputation(builder.Build());
}
// Create the reduce-window.
@@ -2397,15 +2398,15 @@ TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
add_computation));
// Build the computation and run the simplifier.
- auto computation = module.AddEntryComputation(builder.Build());
+ auto computation = module->AddEntryComputation(builder.Build());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(root, reduce_window);
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
// Verify the result
root = computation->root_instruction();
@@ -2431,12 +2432,12 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
builder.AddInstruction(
HloInstruction::CreateReverse(shape, a, /*dimensions=*/{2, 3}));
- HloModule module(TestName());
- auto computation = module.AddEntryComputation(builder.Build());
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(&module).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);