aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/call_inliner_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/call_inliner_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/call_inliner_test.cc12
1 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 5d85a3f173..e6b5665435 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -40,7 +40,7 @@ namespace {
// Tests for call inlining that are most tractable at the HLO level (vs
// ComputationBuilder API in call_test.cc).
-using CallInlinerTest = HloTestBase;
+using CallInlinerTest = HloVerifiedTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
@@ -64,7 +64,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement<float>(),
@@ -92,6 +92,8 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, pred, "param"));
+ call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
@@ -105,7 +107,7 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
@@ -161,7 +163,7 @@ TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
- TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module));
ASSERT_TRUE(mutated);
}