aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/flatten_call_graph_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/flatten_call_graph_test.cc22
1 files changed, 11 insertions, 11 deletions
diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
index 8f6608241e..5fbd73a536 100644
--- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
+++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
-#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -30,7 +30,7 @@ limitations under the License.
namespace xla {
namespace {
-class FlattenCallGraphTest : public HloTestBase {
+class FlattenCallGraphTest : public HloVerifiedTestBase {
protected:
// Build and return a trivial computation taking and returning a scalar.
std::unique_ptr<HloComputation> MakeScalarComputation() {
@@ -139,9 +139,9 @@ TEST_F(FlattenCallGraphTest, ComplexGraph) {
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> flat_call_graph = CallGraph::Build(module);
const CallGraphNode& c_node = flat_call_graph->GetNode(c_computation);
EXPECT_EQ(1, c_node.caller_callsites().size());
}
@@ -176,15 +176,15 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) {
}
{
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(2, cond_node.caller_callsites().size());
}
{
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
const CallGraphNode& cond_node = call_graph->GetNode(cond_computation);
EXPECT_EQ(1, cond_node.caller_callsites().size());
}
@@ -211,9 +211,9 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) {
module->AddEntryComputation(
MakeCallingComputation(b_computation, /*callsites=*/2, ".Entry"));
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
EXPECT_EQ(7, module->computation_count());
const CallGraphNode& c_node = call_graph->GetNode(c_computation);
@@ -243,9 +243,9 @@ TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) {
module->AddEntryComputation(builder.Build());
EXPECT_EQ(2, module->computation_count());
- TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get()));
+ TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module));
EXPECT_TRUE(result);
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module.get());
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// The true and false computations must now be different.
EXPECT_EQ(3, module->computation_count());