diff options
author | Mark Heffernan <meheff@google.com> | 2017-08-02 08:23:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-02 08:28:43 -0700 |
commit | b1c10555afe9ad4ebebbd83eb31dbf8006d7980b (patch) | |
tree | 55cd7fad8d74663db40f203f6285c8af634f96bd /tensorflow/compiler/xla/service/hlo_computation_test.cc | |
parent | a619fcaf4eaf3bb3a4099206a3a0631167db7f49 (diff) |
Set the HloModule parent for HloComputations in all instances where
the parent field is currently null.
PiperOrigin-RevId: 163983198
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation_test.cc | 88 |
1 files changed, 50 insertions, 38 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 4a4a855669..203c28f0ab 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -65,45 +65,53 @@ class HloComputationTest : public HloTestBase { }; TEST_F(HloComputationTest, GetEmbeddedComputationsEmpty) { - auto negate_computation = CreateNegateComputation(); + auto module = CreateNewModule(); + auto negate_computation = + module->AddEntryComputation(CreateNegateComputation()); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); } TEST_F(HloComputationTest, GetEmbeddedComputationsOneComputation) { // Create computation which calls one other computation. - auto negate_computation = CreateNegateComputation(); - auto map_computation = CreateMapComputation(negate_computation.get()); + auto module = CreateNewModule(); + auto negate_computation = + module->AddEmbeddedComputation(CreateNegateComputation()); + auto map_computation = + module->AddEntryComputation(CreateMapComputation(negate_computation)); EXPECT_TRUE(negate_computation->MakeEmbeddedComputationsList().empty()); EXPECT_THAT(map_computation->MakeEmbeddedComputationsList(), - ElementsAre(negate_computation.get())); + ElementsAre(negate_computation)); } TEST_F(HloComputationTest, GetEmbeddedComputationsDiamond) { // Create computations with a diamond-shaped callgraph. - auto negate_computation = CreateNegateComputation(); - auto map1_computation = CreateMapComputation(negate_computation.get()); - auto map2_computation = CreateMapComputation(negate_computation.get()); + auto module = CreateNewModule(); + auto negate_computation = + module->AddEmbeddedComputation(CreateNegateComputation()); + auto map1_computation = + module->AddEmbeddedComputation(CreateMapComputation(negate_computation)); + auto map2_computation = + module->AddEmbeddedComputation(CreateMapComputation(negate_computation)); auto builder = HloComputation::Builder(TestName()); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32_, "param0")); auto map1 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map1_computation.get())); + HloInstruction::CreateMap(r0f32_, {param}, map1_computation)); auto map2 = builder.AddInstruction( - HloInstruction::CreateMap(r0f32_, {param}, map2_computation.get())); + HloInstruction::CreateMap(r0f32_, {param}, map2_computation)); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, map1, map2)); - auto computation = builder.Build(); + auto computation = module->AddEntryComputation(builder.Build()); auto embedded_computations = computation->MakeEmbeddedComputationsList(); EXPECT_EQ(3, embedded_computations.size()); // GetEmbeddedComputations returns a post order of the embedded computations, // so the negate computation must come first. - EXPECT_EQ(negate_computation.get(), *embedded_computations.begin()); - EXPECT_THAT( - embedded_computations, - UnorderedElementsAre(negate_computation.get(), map1_computation.get(), - map2_computation.get())); + EXPECT_EQ(negate_computation, *embedded_computations.begin()); + EXPECT_THAT(embedded_computations, + UnorderedElementsAre(negate_computation, map1_computation, + map2_computation)); } TEST_F(HloComputationTest, PostOrderSingleton) { @@ -111,8 +119,8 @@ TEST_F(HloComputationTest, PostOrderSingleton) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); } @@ -126,8 +134,8 @@ TEST_F(HloComputationTest, PostOrderSimple) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2)); } @@ -143,8 +151,8 @@ TEST_F(HloComputationTest, PostOrderTrace) { builder.AddInstruction(HloInstruction::CreateTrace("foobar", negate1)); auto negate2 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, negate1)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); // Trace instructions should be at the end of the sort. EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant, negate1, negate2, trace)); @@ -162,8 +170,8 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); auto constant4 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f))); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_THAT(computation->MakeInstructionPostOrder(), UnorderedElementsAre(constant1, constant2, constant3, constant4)); } @@ -184,8 +192,8 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { r0f32_, HloOpcode::kAdd, constant2, constant3)); auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); auto post_order = computation->MakeInstructionPostOrder(); EXPECT_EQ(6, post_order.size()); EXPECT_THAT(post_order, UnorderedElementsAre(constant1, constant2, constant3, @@ -209,8 +217,8 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { constant2, constant3)); builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant3)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); // Visitor which keeps track of which instructions have been visited. class TestVisitor : public DfsHloVisitorWithDefault { public: @@ -236,7 +244,7 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { HloInstruction* last_visited_ = nullptr; }; - TestVisitor visitor(computation.get()); + TestVisitor visitor(computation); EXPECT_IS_OK(computation->Accept(&visitor)); EXPECT_EQ(6, visitor.visited_set_.size()); @@ -249,8 +257,8 @@ TEST_F(HloComputationTest, DeepCopyArray) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR1<float>({1.0, 2.0, 3.0}))); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); EXPECT_THAT(copy, op::Copy(constant)); @@ -266,8 +274,8 @@ TEST_F(HloComputationTest, DeepCopyTuple) { auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); auto tuple_copy = computation->DeepCopyInstruction(tuple).ValueOrDie(); EXPECT_THAT(tuple_copy, op::Tuple(op::Copy(op::GetTupleElement(tuple)), @@ -285,8 +293,8 @@ TEST_F(HloComputationTest, CycleDetection) { HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, negate, negate)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); // Add a control dependency to create a cycle. ASSERT_IS_OK(add->AddControlDependencyTo(negate)); @@ -310,8 +318,8 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { r0f32_, HloOpcode::kAdd, dead_negate, dead_negate)); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); - auto computation = builder.Build(); - + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); EXPECT_EQ(4, computation->instruction_count()); EXPECT_THAT(computation->root_instruction(), op::Negate(constant)); EXPECT_EQ(negate, computation->root_instruction()); @@ -336,7 +344,9 @@ TEST_F(HloComputationTest, CloneWithControlDependency) { HloInstruction::CreateParameter(0, r0f32_, "param0")); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, param)); - auto computation = builder.Build(/*root_instruction=*/add); + auto module = CreateNewModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/add)); TF_CHECK_OK(negate->AddControlDependencyTo(add)); @@ -384,7 +394,9 @@ TEST_F(HloComputationTest, Reachability) { auto copy = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kCopy, exp)); - auto computation = builder.Build(/*root_instruction=*/mul); + auto module = CreateNewModule(); + auto computation = + module->AddEntryComputation(builder.Build(/*root_instruction=*/mul)); TF_CHECK_OK(add->AddControlDependencyTo(exp)); auto reachability = computation->ComputeReachability(); |