diff options
author | Mark Heffernan <meheff@google.com> | 2017-11-03 13:26:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-03 13:31:12 -0700 |
commit | 456929281592f14d50443cfbdaa2f6b36167a134 (patch) | |
tree | 8a1a18245a6ef9baed8bd9a9f35b7c250ab64901 /tensorflow/compiler/xla/service/buffer_assignment_test.cc | |
parent | 5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff) |
Rollback copy insertion change because it results in a DCHECK with an internal model.
END_PUBLIC
BEGIN_PUBLIC
Automated g4 rollback of changelist 174423881
PiperOrigin-RevId: 174505237
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment_test.cc | 78 |
1 files changed, 64 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 4d4c5b953e..89410f42bd 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1538,6 +1538,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto cond0 = module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); @@ -1554,8 +1556,10 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto body1 = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output1})); auto while1 = builder.AddInstruction( - HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0)); + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); module->AddEntryComputation(builder.Build()); RunCopyInsertion(module.get()); @@ -1672,37 +1676,34 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { auto while1 = builder.AddInstruction( HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); - auto gte0 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while0, 0)); - auto gte1 = builder.AddInstruction( - HloInstruction::CreateGetTupleElement(data_shape_, while1, 1)); auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( - while0->shape(), HloOpcode::kAdd, gte0, gte1)); - + while0->shape(), HloOpcode::kAdd, while0, while1)); module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + { FlattenCallGraph flatten; TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); EXPECT_TRUE(result); } - RunCopyInsertion(module.get()); - auto sequence = CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); // To trigger b/38494731, we want a specific Hlo sequence for the // root computation, so we overwrite that entry with a manually // crafted sequence. - sequence[module->entry_computation()] = { - input1, weights1, one, output1, while1->operand(0), while1, - input0, weights0, zero, output0, while0->operand(0), while0, - gte0, gte1, root_add}; + std::vector<const HloInstruction*> sequence_for_buffer_assigment = { + input1, weights1, one, output1, tuple1, while1, input0, + weights0, zero, output0, tuple0, while0, root_add}; // If this ASSERT_TRUE fails, we constructed a bogus sequence above // and this test itself is buggy. - ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()])); + ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); + + sequence[module->entry_computation()] = + std::move(sequence_for_buffer_assigment); auto assignment = BufferAssigner::Run( @@ -1714,6 +1715,55 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } +// Test buffer assignment for while nodes with multiple uses. +// TODO(b/37245345): Fix buffer assignment for this case. +TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { + auto module = MakeUnique<HloModule>(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0)); + + auto get0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + auto get1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto assignment = RunBufferAssignment(module.get()); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { auto module = MakeUnique<HloModule>(TestName()); auto builder = HloComputation::Builder("entry"); |