aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/batchnorm_expander_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander_test.cc14
1 files changed, 7 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
index aba0d9bb5b..f7ac8f5482 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc
@@ -29,14 +29,14 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
-using BatchNormExpanderTest = HloTestBase;
+using BatchNormExpanderTest = HloVerifiedTestBase;
// Test that we expand BatchNormTraining.
TEST_F(BatchNormExpanderTest, BatchNormTraining) {
@@ -66,7 +66,7 @@ TEST_F(BatchNormExpanderTest, BatchNormTraining) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -108,7 +108,7 @@ TEST_F(BatchNormExpanderTest, BatchNormGrad) {
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(module).ValueOrDie());
root = computation->root_instruction();
// Make sure this operation is expanded.
EXPECT_EQ(root->opcode(), HloOpcode::kTuple);
@@ -126,13 +126,13 @@ ENTRY entry {
epsilon=0.001, feature_index=1, sharding={maximal device=1}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(module_str));
+ ParseAndVerifyModule(module_str);
BatchNormExpander rewriter(/*rewrite_training_op=*/true,
/*rewrite_inference_op=*/true,
/*rewrite_grad_op=*/true);
- ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie());
+ ASSERT_TRUE(rewriter.Run(&module()).ValueOrDie());
- for (auto* instruction : module->entry_computation()->instructions()) {
+ for (auto* instruction : module().entry_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kParameter) {
continue;
}