aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/heap_simulator_test.cc
diff options
context:
space:
mode:
authorGravatar Michael Kuperstein <mkuper@google.com>2018-02-26 14:19:56 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 14:23:37 -0800
commitd98e7fc5720c1597b6f2034ba2ad62438ac5ef39 (patch)
tree3e9063ca7a9ce572b73475508b5a4060f6f887d3 /tensorflow/compiler/xla/service/heap_simulator_test.cc
parenta05488be720fc803ac56738c8bc0222fb8a36d7f (diff)
[XLA] GTE of a certain element of the tuple does not need to keep other elements alive.
This achieves two things: 1. Heap simulation runtime is no longer quadratic in the number of tuple elements (as we don't add each GetTupleElement to the liveset of each buffer defined by the tuple). 2. A reduction in the heap memory footprint. PiperOrigin-RevId: 187079787
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc50
1 files changed, 50 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 387b649a73..688a271712 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -410,6 +410,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
});
}
+TEST_F(HeapSimulatorTest, IndependentTupleElements) {
+ auto builder = HloComputation::Builder(TestName());
+ auto paramA = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, f32scalar_, "paramA"));
+ auto paramB = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, f32scalar_, "paramB"));
+ auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32scalar_, HloOpcode::kMultiply, paramA, paramB));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32scalar_, HloOpcode::kAdd, paramA, paramB));
+ auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add}));
+ auto element0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0));
+ auto broadcast = builder.AddInstruction(
+ HloInstruction::CreateBroadcast(f32vec4_, element0, {0}));
+ auto sub = builder.AddInstruction(HloInstruction::CreateBinary(
+ f32scalar_, HloOpcode::kSubtract, paramA, paramB));
+ auto element1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1));
+ auto output = builder.AddInstruction(
+ HloInstruction::CreateTuple({broadcast, sub, element1}));
+
+ HeapSimulatorTracker tracker(TestName(), builder.Build(),
+ {paramA, paramB, mul, add, tuple, element0,
+ broadcast, sub, element1, output});
+ tracker.ExpectCallSequence({
+ {kAlloc, tracker.BufferAt(paramA, {})},
+ {kAlloc, tracker.BufferAt(paramB, {})},
+ {kAlloc, tracker.BufferAt(mul, {})},
+ {kAlloc, tracker.BufferAt(add, {})},
+ {kAlloc, tracker.BufferAt(tuple, {})},
+ {kAlloc, tracker.BufferAt(broadcast, {})},
+ // The mul can be freed right after the broadcast happens, even though
+ // the other GetTupleElement (element1) is still alive.
+ {kFree, tracker.BufferAt(mul, {})},
+ {kAlloc, tracker.BufferAt(sub, {})},
+ // The temporary tuple is now dead.
+ {kFree, tracker.BufferAt(tuple, {})},
+ {kAlloc, tracker.BufferAt(output, {})},
+ // All params and outputs are freed at the end.
+ {kFree, tracker.BufferAt(paramA, {})},
+ {kFree, tracker.BufferAt(paramB, {})},
+ {kFree, tracker.BufferAt(add, {})},
+ {kFree, tracker.BufferAt(broadcast, {})},
+ {kFree, tracker.BufferAt(sub, {})},
+ {kFree, tracker.BufferAt(output, {})},
+ {kFinish, nullptr},
+ });
+}
+
TEST_F(HeapSimulatorTest, WholeModule) {
HeapSimulatorTracker tracker(TestName());