author Dimitris Vardoulakis <dimvar@google.com> 2018-06-04 22:09:11 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-04 22:11:43 -0700
commit d660ab0c392562be89f02400e492bd54a7f9d6b0 (patch)
tree 5ccc1c0acaca41ef84ff12f3b1398b7923d1930a
parent fedfc47ca6713adbbf82e10d4803c5fe94234bbd (diff)
[TF:XLA] Add method CreateNewModule to HloVerifiedTestBase, and remember all created modules, to verify at TearDown.
PiperOrigin-RevId: 199244092
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc  47
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.cc       20
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.h        16
3 files changed, 51 insertions(+), 32 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index cda157f9fa..27eb48181e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1714,7 +1714,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1759,7 +1759,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1781,7 +1781,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1804,7 +1804,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1932,7 +1932,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@@ -2060,7 +2061,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2090,7 +2091,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2121,7 +2122,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2151,7 +2152,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2184,7 +2185,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2200,10 +2201,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, scalar_param,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
@@ -2219,10 +2218,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2237,10 +2236,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, forty_two,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
HloInstruction* transpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -2259,7 +2256,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2268,7 +2265,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2349,7 +2347,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2444,7 +2443,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index c8a05c2e9e..22c664d142 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() {
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- VerifyModule();
+ VerifyModule(module_.get());
+ }
+ for (int i = 0; i < modules_.size(); ++i) {
+ VerifyModule(modules_.at(i).get());
}
HloTestBase::TearDown();
}
-void HloVerifiedTestBase::VerifyModule() {
- HloVerifier verifier;
- xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+void HloVerifiedTestBase::VerifyModule(HloModule* module) {
+ HloVerifier verifier(/*allow_mixed_precision=*/true);
+ xla::StatusOr<bool> mutated = verifier.Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() {
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = CreateNewModule();
+ module_ = HloTestBase::CreateNewModule();
}
return *module_;
}
+HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
+ modules_.emplace_back(HloTestBase::CreateNewModule());
+ return modules_.back().get();
+}
+
void HloVerifiedTestBase::ParseAndVerifyModule(
tensorflow::StringPiece hlo_text) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
- VerifyModule();
+ VerifyModule(module_.get());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index e5bb14a883..5b59cc77f6 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase {
shape_verifier_ = std::move(shape_verifier);
}
+ // Creates a new module for a test, and stores it in modules_ so it can be
+ // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
+ // creation of unverified modules.
+ HloModule* CreateNewModule(const string& name = TestName());
+
+ // It is confusing to store modules created by module() and CreateNewModule()
+ // in different fields, but it allows us to migrate tests to
+ // HloVerifiedTestBase more easily, so it's a win because we can verify more
+ // modules. See b/80488902.
private:
- std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
+ // Lazily populated. Access via module().
+ std::unique_ptr<HloModule> module_;
+ // Populated by calls to CreateNewModule.
+ std::vector<std::unique_ptr<HloModule>> modules_;
std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
- void VerifyModule();
+ static void VerifyModule(HloModule* module);
};
} // namespace xla
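
Usage note (not part of the commit): a minimal sketch of how a test deriving from HloVerifiedTestBase might use the new CreateNewModule(). The fixture name, test name, and HLO graph below are hypothetical and for illustration only; the fixture API and include paths follow the header change above for this revision.

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"

namespace xla {
namespace {

class VerifiedModuleTest : public HloVerifiedTestBase {};

TEST_F(VerifiedModuleTest, NegateParameter) {
  HloComputation::Builder builder(TestName());
  Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2});
  HloInstruction* param = builder.AddInstruction(
      HloInstruction::CreateParameter(0, r2f32, "param"));
  builder.AddInstruction(
      HloInstruction::CreateUnary(r2f32, HloOpcode::kNegate, param));

  // CreateNewModule() now returns a raw HloModule* owned by the fixture, so
  // passes would be run on `module` directly instead of `module.get()`, as in
  // the algebraic_simplifier_test.cc changes above.
  HloModule* module = CreateNewModule();
  module->AddEntryComputation(builder.Build());

  // No explicit verification call is needed: every module created through the
  // fixture's CreateNewModule() is checked by HloVerifier in TearDown().
}

}  // namespace
}  // namespace xla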