diff options
author | 2018-09-14 09:21:08 -0700 | |
---|---|---|
committer | 2018-09-14 09:21:08 -0700 | |
commit | 41aaed7751690b0b3137dad2620656a698b3ceae (patch) | |
tree | 00fc1a7f6be0c3968f3e674a65ca4907110ddf2d /tensorflow/compiler/xla/service/batchnorm_expander_test.cc | |
parent | c26c5e1217944448f1f4c2b97626fc4d7d6406d3 (diff) | |
parent | 95338704198205c1bdec1e344e103f1daf05df68 (diff) |
Merge branch 'master' into avijit/add-cpu-backend
Diffstat (limited to 'tensorflow/compiler/xla/service/batchnorm_expander_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/batchnorm_expander_test.cc | 14 |
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"

 namespace xla {
 namespace {

-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;

 // Test that we expand BatchNormTraining.
 TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
   root = computation->root_instruction();
   // Make sure this operation is expanded.
   EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
   root = computation->root_instruction();
   // Make sure this operation is expanded.
   EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
     epsilon=0.001, feature_index=1, sharding={maximal device=1}
 })";
-  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+  ParseAndVerifyModule(module_str);
   BatchNormExpander rewriter(/*rewrite_training_op=*/true,
                              /*rewrite_inference_op=*/true,
                              /*rewrite_grad_op=*/true);
-  ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+  ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
-  for (auto* instruction : module->entry_computation()->instructions()) {
+  for (auto* instruction : module().entry_computation()->instructions()) {
     if (instruction->opcode() == HloOpcode::kParameter) {
       continue;
     }