aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-08-02 08:23:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-02 08:28:43 -0700
commitb1c10555afe9ad4ebebbd83eb31dbf8006d7980b (patch)
tree55cd7fad8d74663db40f203f6285c8af634f96bd /tensorflow/compiler/xla/service/hlo_computation_test.cc
parenta619fcaf4eaf3bb3a4099206a3a0631167db7f49 (diff)
Set the HloModule parent for HloComputations in all instances where
the parent field is currently null. PiperOrigin-RevId: 163983198
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc88
1 files changed, 50 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index 4a4a855669..203c28f0ab 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -65,45 +65,53 @@ class HloComputationTest : public HloTestBase {
};
TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) {
- auto negate_computation = CreateNegateComputation();
+ auto module = CreateNewModule();
+ auto negate_computation =
+ module->AddEntryComputation(CreateNegateComputation());
EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
}
TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) {
// Create computation which calls one other computation.
- auto negate_computation = CreateNegateComputation();
- auto map_computation = CreateMapComputation(negate_computation.get());
+ auto module = CreateNewModule();
+ auto negate_computation =
+ module->AddEmbeddedComputation(CreateNegateComputation());
+ auto map_computation =
+ module->AddEntryComputation(CreateMapComputation(negate_computation));
EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty());
EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(),
- ElementsAre(negate_computation.get()));
+ ElementsAre(negate_computation));
}
TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) {
// Create computations with a diamond-shaped callgraph.
- auto negate_computation = CreateNegateComputation();
- auto map1_computation = CreateMapComputation(negate_computation.get());
- auto map2_computation = CreateMapComputation(negate_computation.get());
+ auto module = CreateNewModule();
+ auto negate_computation =
+ module->AddEmbeddedComputation(CreateNegateComputation());
+ auto map1_computation =
+ module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
+ auto map2_computation =
+ module->AddEmbeddedComputation(CreateMapComputation(negate_computation));
auto builder = HloComputation::Builder(TestName());
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, r0f32_, "param0"));
auto map1 = builder.AddInstruction(
- HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get()));
+ HloInstruction::CreateMap(r0f32_, {param}, map1_computation));
auto map2 = builder.AddInstruction(
- HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get()));
+ HloInstruction::CreateMap(r0f32_, {param}, map2_computation));
builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2));
- auto computation = builder.Build();
+ auto computation = module->AddEntryComputation(builder.Build());
auto embedded_computations = computation->MakeEmbeddedComputationsList();
EXPECT_EQ(3, embedded_computations.size());
// GetEmbeddedComputations returns a post order of the embedded computations,
// so the negate computation must come first.
- EXPECT_EQ(negate_computation.get(), *embedded_computations.begin());
- EXPECT_THAT(
- embedded_computations,
- UnorderedElementsAre(negate_computation.get(), map1_computation.get(),
- map2_computation.get()));
+ EXPECT_EQ(negate_computation, *embedded_computations.begin());
+ EXPECT_THAT(embedded_computations,
+ UnorderedElementsAre(negate_computation, map1_computation,
+ map2_computation));
}
TEST_F(HloComputationTest, PostOrderSingleton) {
@@ -111,8 +119,8 @@ TEST_F(HloComputationTest, PostOrderSingleton) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant));
}
@@ -126,8 +134,8 @@ TEST_F(HloComputationTest, PostOrderSimple) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(),
ElementsAre(constant, negate1, negate2));
}
@@ -143,8 +151,8 @@ TEST_F(HloComputationTest, PostOrderTrace) {
builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1));
auto negate2 = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
// Trace instructions should be at the end of the sort.
EXPECT_THAT(computation->MakeInstructionPostOrder(),
ElementsAre(constant, negate1, negate2, trace));
@@ -162,8 +170,8 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
auto constant4 = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_THAT(computation->MakeInstructionPostOrder(),
UnorderedElementsAre(constant1, constant2, constant3, constant4));
}
@@ -184,8 +192,8 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) {
r0f32_, HloOpcode::kAdd, constant2, constant3));
auto add3 = builder.AddInstruction(HloInstruction::CreateBinary(
r0f32_, HloOpcode::kAdd, constant1, constant3));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
auto post_order = computation->MakeInstructionPostOrder();
EXPECT_EQ(6, post_order.size());
EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3,
@@ -209,8 +217,8 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
constant2, constant3));
builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd,
constant1, constant3));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
// Visitor which keeps track of which instructions have been visited.
class TestVisitor : public DfsHloVisitorWithDefault {
public:
@@ -236,7 +244,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) {
HloInstruction* last_visited_ = nullptr;
};
- TestVisitor visitor(computation.get());
+ TestVisitor visitor(computation);
EXPECT_IS_OK(computation->Accept(&visitor));
EXPECT_EQ(6, visitor.visited_set_.size());
@@ -249,8 +257,8 @@ TEST_F(HloComputationTest, DeepCopyArray) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR1<float>({1.0, 2.0, 3.0})));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
auto copy = computation->DeepCopyInstruction(constant).ValueOrDie();
EXPECT_THAT(copy, op::Copy(constant));
@@ -266,8 +274,8 @@ TEST_F(HloComputationTest, DeepCopyTuple) {
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({constant1, constant2}));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie();
EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)),
@@ -285,8 +293,8 @@ TEST_F(HloComputationTest, CycleDetection) {
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
// Add a control dependency to create a cycle.
ASSERT_IS_OK(add->AddControlDependencyTo(negate));
@@ -310,8 +318,8 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) {
r0f32_, HloOpcode::kAdd, dead_negate, dead_negate));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant));
- auto computation = builder.Build();
-
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
EXPECT_EQ(4, computation->instruction_count());
EXPECT_THAT(computation->root_instruction(), op::Negate(constant));
EXPECT_EQ(negate, computation->root_instruction());
@@ -336,7 +344,9 @@ TEST_F(HloComputationTest, CloneWithControlDependency) {
HloInstruction::CreateParameter(0, r0f32_, "param0"));
auto negate = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param));
- auto computation = builder.Build(/*root_instruction=*/add);
+ auto module = CreateNewModule();
+ auto computation =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/add));
TF_CHECK_OK(negate->AddControlDependencyTo(add));
@@ -384,7 +394,9 @@ TEST_F(HloComputationTest, Reachability) {
auto copy = builder.AddInstruction(
HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp));
- auto computation = builder.Build(/*root_instruction=*/mul);
+ auto module = CreateNewModule();
+ auto computation =
+ module->AddEntryComputation(builder.Build(/*root_instruction=*/mul));
TF_CHECK_OK(add->AddControlDependencyTo(exp));
auto reachability = computation->ComputeReachability();