diff options
author | 2018-02-26 14:19:56 -0800 | |
---|---|---|
committer | 2018-02-26 14:23:37 -0800 | |
commit | d98e7fc5720c1597b6f2034ba2ad62438ac5ef39 (patch) | |
tree | 3e9063ca7a9ce572b73475508b5a4060f6f887d3 /tensorflow/compiler/xla/service/heap_simulator_test.cc | |
parent | a05488be720fc803ac56738c8bc0222fb8a36d7f (diff) |
[XLA] GTE of a certain element of the tuple does not need not keep other elements alive.
This achieves two things:
1. Heap simulation runtime is no longer quadratic in the number of tuple elements (as we don't add each GetTupleElement to the liveset of each buffer defined by the tuple).
2. A reduction in the heap memory footprint.
PiperOrigin-RevId: 187079787
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/heap_simulator_test.cc | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 387b649a73..688a271712 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -410,6 +410,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramB = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32scalar_, "paramB")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kMultiply, paramA, paramB)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kAdd, paramA, paramB)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + auto element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0)); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec4_, element0, {0})); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kSubtract, paramA, paramB)); + auto element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1)); + auto output = builder.AddInstruction( + HloInstruction::CreateTuple({broadcast, sub, element1})); + + HeapSimulatorTracker tracker(TestName(), builder.Build(), + {paramA, paramB, mul, add, tuple, element0, + broadcast, sub, element1, output}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramB, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(add, {})}, + {kAlloc, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(broadcast, {})}, + // The mul can be freed right after the broadcast happens, even though + // The other GetTupleElement is still alive. + {kFree, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(sub, {})}, + // The temporary tuple is now dead. + {kFree, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(output, {})}, + // All params and outputs are freed at the end. + {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramB, {})}, + {kFree, tracker.BufferAt(add, {})}, + {kFree, tracker.BufferAt(broadcast, {})}, + {kFree, tracker.BufferAt(sub, {})}, + {kFree, tracker.BufferAt(output, {})}, + {kFinish, nullptr}, + }); +} + TEST_F(HeapSimulatorTest, WholeModule) { HeapSimulatorTracker tracker(TestName()); |