diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/call_inliner_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/call_inliner_test.cc | 12 |
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 5d85a3f173..e6b5665435 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -40,7 +40,7 @@ namespace { // Tests for call inlining that are most tractable at the HLO level (vs // ComputationBuilder API in call_test.cc). -using CallInlinerTest = HloTestBase; +using CallInlinerTest = HloVerifiedTestBase; TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { // "inner" computation just has a control dependency from the "zero" value to @@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(), @@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { HloComputation::Builder call_false_builder(TestName() + ".call_false"); call_false_builder.AddInstruction( + HloInstruction::CreateParameter(0, pred, "param")); + call_false_builder.AddInstruction( HloInstruction::CreateCall(pred, {}, false_computation)); HloComputation* call_false = module->AddEmbeddedComputation(call_false_builder.Build()); @@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { auto computation = module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); EXPECT_THAT( computation->root_instruction()->while_condition()->root_instruction(), @@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { module->AddEntryComputation(outer.Build()); CallInliner call_inliner; - TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module)); ASSERT_TRUE(mutated); } |