aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/heap_simulator_test.cc
diff options
context:
space:
mode:
authorGravatar Dimitris Vardoulakis <dimvar@google.com>2018-06-12 11:52:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 11:54:33 -0700
commite47701d1d30c744b8bffc263b640c401d611bc0e (patch)
tree5131b9235bdb450e541ee51a058387720a93e52a /tensorflow/compiler/xla/service/heap_simulator_test.cc
parentc5436b90adff058500e88b497fc4f7a0b0379d28 (diff)
[TF:XLA] Move methods MinimumMemoryFor... from hlo_scheduling to heap_simulator.
These methods have nothing to do with scheduling. Also, rename methods CreateMemoryMinimizingSequence in hlo_scheduling. PiperOrigin-RevId: 200254100
Diffstat (limited to 'tensorflow/compiler/xla/service/heap_simulator_test.cc')
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 6271652412..309ab85f78 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -34,6 +34,64 @@ limitations under the License.
namespace xla {
namespace {
+class MinimumMemoryForSequenceTest : public HloTestBase {};
+
+TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+ HloInstruction* cond_iter = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+ HloInstruction* cond_data = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+ // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
+ HloInstruction* cond_lt = cond_builder.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kLt, cond_iter, cond_data));
+ HloComputation* cond_computation =
+ module->AddEmbeddedComputation(cond_builder.Build());
+
+ auto body_builder = HloComputation::Builder("WhileBody");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+ HloComputation* body_computation =
+ module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ // Entry params: 8 bytes (4 bytes per param), TOTAL=8
+ HloInstruction* iter = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
+ HloInstruction* data = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
+ // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
+ // While: 8 bytes (4 bytes per element), TOTAL=32
+ // Both cond and body use a max of 24 bytes, TOTAL=56
+ HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+ tuple_shape, cond_computation, body_computation, tuple));
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const BufferValue& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
+ };
+
+ SequentialHloOrdering::HloModuleSequence module_sequence;
+ module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
+ cond_lt};
+ module_sequence[body_computation] = {body_param};
+ module_sequence[entry_computation] = {iter, data, tuple, while_op};
+ EXPECT_EQ(56, MinimumMemoryForModule(module_sequence, size_fn).ValueOrDie());
+}
+
const char kAlloc[] = "Alloc";
const char kFree[] = "Free";
const char kFinish[] = "Finish";