about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author    Dimitris Vardoulakis <dimvar@google.com>       2018-06-08 10:02:44 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-08 10:05:35 -0700
commit 0ef76693fdab2a4d1a4923444a2593f79a6b7873 (patch)
tree   1fbfadc281719b4513091dde8d1062943dc75568
parent 8566ebe58ff5b08864ddef6fe743fdd80962465b (diff)
Automated g4 rollback of changelist 199308328
PiperOrigin-RevId: 199809082
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 47
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_test_base.h                | 17
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.cc      | 20
-rw-r--r--  tensorflow/compiler/xla/tests/hlo_verified_test_base.h       | 16
4 files changed, 60 insertions(+), 40 deletions(-)
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index cda157f9fa..27eb48181e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -1714,7 +1714,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1759,7 +1759,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) {
EXPECT_THAT(computation->root_instruction(), op::Pad(param, zero));
EXPECT_TRUE(has_negative_padding(pad));
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), op::Slice(op::Pad(param, zero)));
EXPECT_FALSE(
@@ -1781,7 +1781,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1804,7 +1804,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(), param);
}
@@ -1932,7 +1932,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) {
b.AddInstruction(HloInstruction::CreateConvolve(out_shape, input, filter,
window, dnums));
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
auto* computation = module->AddEntryComputation(b.Build());
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true,
@@ -2060,7 +2061,7 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2090,7 +2091,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2121,7 +2122,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Clamp(max_value, param0, min_value));
@@ -2151,7 +2152,7 @@ TEST_F(AlgebraicSimplifierTest, MinMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Maximum(param0, max_value), min_value));
@@ -2184,7 +2185,7 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ EXPECT_FALSE(simplifier.Run(module).ValueOrDie());
EXPECT_THAT(computation->root_instruction(),
op::Minimum(op::Add(op::Maximum(param0, max_value), max_value),
@@ -2200,10 +2201,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
HloInstruction::CreateParameter(0, r0f32, "scalar_param"));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, scalar_param,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, scalar_param, {}));
Shape slice_shape = ShapeUtil::MakeShape(F32, {2, 2, 3, 3});
HloInstruction* slice = builder.AddInstruction(HloInstruction::CreateSlice(
@@ -2219,10 +2218,10 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
// Running simplification again should not result in any further changes.
- ASSERT_FALSE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_FALSE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(scalar_param));
@@ -2237,10 +2236,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6});
- HloInstruction* broadcast =
- builder.AddInstruction(HloInstruction::CreateBroadcast(
- broadcast_shape, forty_two,
- AsInt64Slice(broadcast_shape.dimensions())));
+ HloInstruction* broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(broadcast_shape, forty_two, {}));
HloInstruction* transpose =
builder.AddInstruction(HloInstruction::CreateTranspose(
@@ -2259,7 +2256,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
root = computation->root_instruction();
EXPECT_THAT(root, op::Broadcast(forty_two));
@@ -2268,7 +2265,8 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) {
// Test that ReduceWindow(Pad(op, x), y) can simplify to ReduceWindow(op, x).
TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2349,7 +2347,8 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) {
// Test that ReduceWindow(Convert(Pad(op, x)), y) can simplify to
// ReduceWindow(Convert(op), x).
TEST_F(AlgebraicSimplifierTest, FoldConvertedPadIntoReduceWindow) {
- auto module = CreateNewModule();
+ // TODO(b/80488902): verify this module.
+ auto module = HloTestBase::CreateNewModule();
HloComputation::Builder builder(TestName());
// Create operand to the pad.
@@ -2444,7 +2443,7 @@ TEST_F(AlgebraicSimplifierTest, ReversalOfTrivialDimensionsToBitcast) {
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false,
non_bitcasting_callback());
- ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(simplifier.Run(module).ValueOrDie());
HloInstruction* root = computation->root_instruction();
EXPECT_EQ(a, root);
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h
index eb3a2ea76a..249da87f48 100644
--- a/tensorflow/compiler/xla/tests/hlo_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_test_base.h
@@ -66,6 +66,15 @@ namespace xla {
//
// For a more detailed example, see "../tests/sample_text_test.cc".
class HloTestBase : public ::testing::Test {
+ public:
+ // Creates a new HLO module for a test. The module created will have
+ // TestName() for its name; it will also automatically populate its debug
+ // options from command-line flags. If you want a fresh HloModule object and
+ // then add HloComputations to it, it's recommended to use this method in your
+ // tests.
+ static std::unique_ptr<HloModule> CreateNewModule(
+ const string& name = TestName());
+
protected:
// This uses the interpreter backend as the reference backend and
// automatically finds another supported backend as the test backend. If the
@@ -80,14 +89,6 @@ class HloTestBase : public ::testing::Test {
~HloTestBase() override {}
- // Creates a new HLO module for a test. The module created will have
- // TestName() for its name; it will also automatically populate its debug
- // options from command-line flags. If you want a fresh HloModule object and
- // then add HloComputations to it, it's recommended to use this method in your
- // tests.
- static std::unique_ptr<HloModule> CreateNewModule(
- const string& name = TestName());
-
// Populates debug options from command-line flags and adjusts the options for
// testing. It is recommended to use this when you need to pass in
// DebugOptions, e.g. when creating a module from a string or a file.
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
index c8a05c2e9e..22c664d142 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc
@@ -41,14 +41,17 @@ void HloVerifiedTestBase::TearDown() {
<< "TearDown called more than once; it should be called exactly once.";
tear_down_called_ = true;
if (module_) {
- VerifyModule();
+ VerifyModule(module_.get());
+ }
+ for (int i = 0; i < modules_.size(); ++i) {
+ VerifyModule(modules_.at(i).get());
}
HloTestBase::TearDown();
}
-void HloVerifiedTestBase::VerifyModule() {
- HloVerifier verifier;
- xla::StatusOr<bool> mutated = verifier.Run(module_.get());
+void HloVerifiedTestBase::VerifyModule(HloModule* module) {
+ HloVerifier verifier(/*allow_mixed_precision=*/true);
+ xla::StatusOr<bool> mutated = verifier.Run(module);
if (!mutated.ok()) {
ADD_FAILURE() << "HloVerifier failed: " << mutated.status();
} else {
@@ -59,15 +62,20 @@ void HloVerifiedTestBase::VerifyModule() {
HloModule& HloVerifiedTestBase::module() {
if (!module_) {
- module_ = CreateNewModule();
+ module_ = HloTestBase::CreateNewModule();
}
return *module_;
}
+HloModule* HloVerifiedTestBase::CreateNewModule(const string& name) {
+ modules_.emplace_back(HloTestBase::CreateNewModule());
+ return modules_.back().get();
+}
+
void HloVerifiedTestBase::ParseAndVerifyModule(
tensorflow::StringPiece hlo_text) {
CHECK(!module_) << "Called ParseModule when test already has a module.";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseHloString(hlo_text));
- VerifyModule();
+ VerifyModule(module_.get());
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
index e5bb14a883..5b59cc77f6 100644
--- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
+++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h
@@ -52,11 +52,23 @@ class HloVerifiedTestBase : public HloTestBase {
shape_verifier_ = std::move(shape_verifier);
}
+ // Creates a new module for a test, and stores it in modules_ so it can be
+ // verified. Intentionally hides HloTestBase::CreateNewModule, to prevent
+ // creation of unverified modules.
+ HloModule* CreateNewModule(const string& name = TestName());
+
+ // It is confusing to store modules created by module() and CreateNewModule()
+ // in different fields, but it allows us to migrate tests to
+ // HloVerifiedTestBase more easily, so it's a win because we can verify more
+ // modules. See b/80488902.
private:
- std::unique_ptr<HloModule> module_; // Lazily populated. Access via module().
+ // Lazily populated. Access via module().
+ std::unique_ptr<HloModule> module_;
+ // Populated by calls to CreateNewModule.
+ std::vector<std::unique_ptr<HloModule>> modules_;
std::unique_ptr<ShapeVerifier> shape_verifier_;
bool tear_down_called_ = false;
- void VerifyModule();
+ static void VerifyModule(HloModule* module);
};
} // namespace xla