aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/buffer_assignment_test.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-11-03 13:26:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-03 13:31:12 -0700
commit456929281592f14d50443cfbdaa2f6b36167a134 (patch)
tree8a1a18245a6ef9baed8bd9a9f35b7c250ab64901 /tensorflow/compiler/xla/service/buffer_assignment_test.cc
parent5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff)
Rollback copy insertion change because it results in a DCHECK with an internal model.
END_PUBLIC BEGIN_PUBLIC Automated g4 rollback of changelist 174423881 PiperOrigin-RevId: 174505237
Diffstat (limited to 'tensorflow/compiler/xla/service/buffer_assignment_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc78
1 files changed, 64 insertions, 14 deletions
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 4d4c5b953e..89410f42bd 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1538,6 +1538,8 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
auto output0 = builder.AddInstruction(
HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+ auto output1 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
auto cond0 =
module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
@@ -1554,8 +1556,10 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto body1 =
module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+ auto tuple1 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output1}));
auto while1 = builder.AddInstruction(
- HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, while0));
+ HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1));
module->AddEntryComputation(builder.Build());
RunCopyInsertion(module.get());
@@ -1672,37 +1676,34 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto while1 = builder.AddInstruction(
HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1));
- auto gte0 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(data_shape_, while0, 0));
- auto gte1 = builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(data_shape_, while1, 1));
auto root_add = builder.AddInstruction(HloInstruction::CreateBinary(
- while0->shape(), HloOpcode::kAdd, gte0, gte1));
-
+ while0->shape(), HloOpcode::kAdd, while0, while1));
module->AddEntryComputation(builder.Build());
+ RunCopyInsertion(module.get());
+
{
FlattenCallGraph flatten;
TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
EXPECT_TRUE(result);
}
- RunCopyInsertion(module.get());
-
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
// To trigger b/38494731, we want a specific Hlo sequence for the
// root computation, so we overwrite that entry with a manually
// crafted sequence.
- sequence[module->entry_computation()] = {
- input1, weights1, one, output1, while1->operand(0), while1,
- input0, weights0, zero, output0, while0->operand(0), while0,
- gte0, gte1, root_add};
+ std::vector<const HloInstruction*> sequence_for_buffer_assigment = {
+ input1, weights1, one, output1, tuple1, while1, input0,
+ weights0, zero, output0, tuple0, while0, root_add};
// If this ASSERT_TRUE fails, we constructed a bogus sequence above
// and this test itself is buggy.
- ASSERT_TRUE(IsPostOrderTraversal(sequence[module->entry_computation()]));
+ ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment));
+
+ sequence[module->entry_computation()] =
+ std::move(sequence_for_buffer_assigment);
auto assignment =
BufferAssigner::Run(
@@ -1714,6 +1715,55 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
+// Test buffer assignment for while nodes with multiple uses.
+// TODO(b/37245345): Fix buffer assignment for this case.
+TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) {
+ auto module = MakeUnique<HloModule>(TestName());
+ auto builder = HloComputation::Builder(TestName());
+
+ auto input0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, data_shape_, "input0"));
+ auto weights0 = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, data_shape_, "weights0"));
+
+ auto zero = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(0.0)));
+ auto output0 = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(data_shape_, zero, {1}));
+
+ auto cond0 =
+ module->AddEmbeddedComputation(BuildWhileConditionComputation("cond"));
+ auto body0 =
+ module->AddEmbeddedComputation(BuildWhileBodyComputation("body"));
+
+ auto tuple0 = builder.AddInstruction(
+ HloInstruction::CreateTuple({input0, weights0, output0}));
+ auto while0 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0));
+ auto while1 = builder.AddInstruction(
+ HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, while0));
+
+ auto get0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, while0, 2));
+ auto get1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(data_shape_, while1, 2));
+ builder.AddInstruction(
+ HloInstruction::CreateBinary(data_shape_, HloOpcode::kAdd, get0, get1));
+ module->AddEntryComputation(builder.Build());
+
+ RunCopyInsertion(module.get());
+
+ {
+ FlattenCallGraph flatten;
+ TF_ASSERT_OK_AND_ASSIGN(bool result, flatten.Run(module.get()));
+ EXPECT_TRUE(result);
+ }
+
+ auto assignment = RunBufferAssignment(module.get());
+
+ EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
+}
+
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto module = MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");