aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-21 13:06:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-21 13:10:23 -0700
commit9e8005d7771e3f98b0a2ce74e4b0bc3765410a27 (patch)
treec8e24aa23cbc8aa644047fab7fbbcce96a3a3d82
parenta24366fa00e5ac0b70c8871d459f5569459329d5 (diff)
[XLA:HLO] Move sequence functions from hlo_ordering.h to hlo_scheduling.h.
This is required for upcoming changes to convert the sequence creation functions (and HeapSimulator and BufferAssignment) over to using the new Hlo{Dataflow,Alias}Analysis. It's required because otherwise there's a dependency cycle: Hlo{Dataflow,Alias}Analysis depends on HloOrdering CreateMemoryMinimizingSequence will depend on Hlo{Dataflow,Alias}Analysis There's already a cycle here, if both HloOrdering and CreateMemoryMinimizingSequence are in the same file. Also note that: MinimumMemoryForSequence depends on HeapSimulator HeapSimulator will depend on Hlo{Dataflow,Alias}Analysis Hlo{Dataflow,Alias}Analysis depends on HloOrdering Splitting out the sequence functions resolves the cycle. Refactoring only; no functional changes. PiperOrigin-RevId: 159731836
-rw-r--r--tensorflow/compiler/xla/service/BUILD81
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc1
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment_test.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc1
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD3
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.cc1
-rw-r--r--tensorflow/compiler/xla/service/gpu/hlo_schedule.h2
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc355
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h22
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc61
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.cc388
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling.h50
-rw-r--r--tensorflow/compiler/xla/service/hlo_scheduling_test.cc97
15 files changed, 611 insertions, 454 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 778c740b1d..150cd8a678 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -712,9 +712,11 @@ cc_library(
],
deps = [
":buffer_liveness",
+ ":heap_simulator",
":hlo",
":hlo_ordering",
":hlo_proto",
+ ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -741,6 +743,7 @@ cc_test(
":flatten_call_graph",
":hlo",
":hlo_ordering",
+ ":hlo_scheduling",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
@@ -753,13 +756,67 @@ cc_test(
],
)
+cc_library(
+ name = "hlo_ordering",
+ srcs = ["hlo_ordering.cc"],
+ hdrs = ["hlo_ordering.h"],
+ deps = [
+ ":call_graph",
+ ":hlo",
+ ":hlo_proto",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_test(
+ name = "hlo_ordering_test",
+ size = "small",
+ srcs = ["hlo_ordering_test.cc"],
+ deps = [
+ ":hlo",
+ ":hlo_ordering",
+ ":hlo_scheduling",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ ],
+)
+
+cc_library(
+ name = "heap_simulator",
+ srcs = ["heap_simulator.cc"],
+ hdrs = ["heap_simulator.h"],
+ deps = [
+ ":hlo",
+ ":hlo_ordering",
+ ":hlo_proto",
+ ":liveness_util",
+ ":logical_buffer",
+ ":tuple_points_to_analysis",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_test(
name = "heap_simulator_test",
size = "small",
srcs = ["heap_simulator_test.cc"],
deps = [
+ ":heap_simulator",
":hlo",
":hlo_ordering",
+ ":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:literal_util",
@@ -770,23 +827,15 @@ cc_test(
],
)
-# The hlo_ordering library contains both hlo_ordering and heap_simulator because
-# they are mutually dependent.
cc_library(
- name = "hlo_ordering",
- srcs = [
- "heap_simulator.cc",
- "hlo_ordering.cc",
- ],
- hdrs = [
- "heap_simulator.h",
- "hlo_ordering.h",
- ],
+ name = "hlo_scheduling",
+ srcs = ["hlo_scheduling.cc"],
+ hdrs = ["hlo_scheduling.h"],
deps = [
- ":call_graph",
+ ":heap_simulator",
":hlo",
+ ":hlo_ordering",
":hlo_proto",
- ":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:shape_util",
@@ -799,12 +848,13 @@ cc_library(
)
cc_test(
- name = "hlo_ordering_test",
+ name = "hlo_scheduling_test",
size = "small",
- srcs = ["hlo_ordering_test.cc"],
+ srcs = ["hlo_scheduling_test.cc"],
deps = [
":hlo",
":hlo_ordering",
+ ":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
@@ -1426,6 +1476,7 @@ cc_library(
":hlo",
":hlo_dce",
":hlo_ordering",
+ ":hlo_scheduling",
":liveness_util",
":logical_buffer",
":tuple_points_to_analysis",
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 44b4f4e3d8..3ba010ac43 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 10021b2513..c498b86dd4 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index de6660e3b5..68cd545695 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -68,6 +68,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/service:hlo_proto_util",
+ "//tensorflow/compiler/xla/service:hlo_scheduling",
"//tensorflow/compiler/xla/service:hlo_subcomputation_unification",
"//tensorflow/compiler/xla/service:hlo_verifier",
"//tensorflow/compiler/xla/service:inliner",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 4786e75fa7..0905855ec2 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -69,6 +69,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/service/inliner.h"
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 52b4a13296..1e15ce32ee 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -498,8 +498,9 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_ordering",
+ "//tensorflow/compiler/xla/service:hlo_scheduling",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
index d16a1d4ee5..f76f8ca668 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h"
#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
index 773973010a..1ce7a48ac8 100644
--- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h
@@ -19,9 +19,9 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/gpu/stream_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 61e5efa5b6..32a2abed92 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -15,13 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
-#include <set>
#include <utility>
#include <vector>
-#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -252,358 +249,6 @@ string SequentialHloOrdering::ToString() const {
return tensorflow::str_util::Join(pieces, "\n");
}
-StatusOr<int64> MinimumMemoryForSequence(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function) {
- if (module_sequence.empty()) {
- return 0;
- }
-
- const HloModule* module = module_sequence.begin()->first->parent();
- TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
- TuplePointsToAnalysis::Run(module));
-
- // The absolute minimum memory required for a given sequence of instructions
- // is determined by the sequence of Alloc and Free calls on a simulated heap,
- // ignoring fragmentation. We run the heap simulation on the whole module,
- // rather than summing each computation, since it gives us a better lower
- // bound, by minimizing the liveness of sub-computations.
- TF_ASSIGN_OR_RETURN(
- HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
- module_sequence, *points_to_analysis, size_function));
- return result.heap_size;
-}
-
-namespace {
-
-// Class implementing a list scheduler of HLO instructions which produces a
-// sequence which minimizes memory usage.
-class ListScheduler {
- public:
- // Construct and return a memory-minimizing sequence of HLO instructions
- // containing the given HLO computation.
- static StatusOr<std::vector<const HloInstruction*>> Run(
- const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
- ListScheduler scheduler(computation, points_to_analysis, size_function);
- return scheduler.CreateSchedule();
- }
-
- private:
- // The scheduling priority of an instruction is first the number of bytes
- // freed by scheduling the instruction, and second (tie-breaker) by the number
- // of users. This is represented as a std::pair containing these two values
- // (first element is the bytes freed). std::pair provides the necessary
- // comparison operators.
- using Priority = std::pair<int64, int64>;
-
- ListScheduler(const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function)
- : computation_(computation),
- points_to_analysis_(points_to_analysis),
- size_function_(size_function) {
- // Create a map containing the LogicalBuffer uses for each HLO
- // instruction. An HLO instruction "uses" a LogicalBuffer if the
- // LogicalBuffer is in an operand of the instruction as indicated by
- // points-to analysis.
- for (auto& instruction : computation.instructions()) {
- buffer_uses_.insert(
- {instruction.get(), std::unordered_set<const LogicalBuffer*>()});
- for (auto* operand : instruction->operands()) {
- for (const LogicalBuffer* buffer :
- points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
- buffer_uses_[instruction.get()].insert(buffer);
- }
- }
- }
-
- // Create map containing the number of unscheduled uses (hlo instructions)
- // of each logical buffer.
- for (auto& instruction : computation.instructions()) {
- for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(
- instruction.get())) {
- unscheduled_use_count_[buffer] = 0;
- }
- }
- for (auto& instruction : computation.instructions()) {
- for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) {
- ++unscheduled_use_count_[buffer];
- }
- }
-
- // Buffers live out of the computation have an implicit use at the end of
- // the computation.
- for (const LogicalBuffer* live_out_buffer :
- points_to_analysis.GetPointsToSet(computation.root_instruction())
- .CreateFlattenedSet()) {
- ++unscheduled_use_count_[live_out_buffer];
- }
- }
-
- // Returns whether the memory used by the given buffer should be ignored by
- // the scheduling heuristic.
- bool IgnoreBuffer(const LogicalBuffer& buffer) {
- return buffer.instruction()->opcode() == HloOpcode::kParameter ||
- buffer.instruction()->opcode() == HloOpcode::kConstant;
- }
-
- // Return the number of bytes freed if the HLO instruction is scheduled.
- int64 BytesFreedIfScheduled(const HloInstruction* instruction) {
- int64 freed_bytes = 0;
- // Sum the total size of the values last used by this instruction.
- for (auto* buffer : buffer_uses_.at(instruction)) {
- if (IgnoreBuffer(*buffer)) {
- continue;
- }
- CHECK_GE(unscheduled_use_count_.at(buffer), 1);
- if (unscheduled_use_count_.at(buffer) == 1) {
- // This is the last use of the logical buffer.
- freed_bytes += size_function_(*buffer);
- }
- }
- // Then subtract the size of the value(s) defined by this instruction.
- for (auto* buffer :
- points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
- if (!IgnoreBuffer(*buffer)) {
- freed_bytes -= size_function_(*buffer);
- }
- }
- return freed_bytes;
- }
-
- // Construct the scheduling priority of the given instruction.
- Priority GetPriority(const HloInstruction* instruction) {
- return {BytesFreedIfScheduled(instruction), instruction->user_count()};
- }
-
- std::vector<const HloInstruction*> CreateSchedule() {
- std::vector<const HloInstruction*> schedule;
-
- // Populate the ready list with instructions which have no operands or
- // control predecessors.
- std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
- std::list<const HloInstruction*> ready_list;
- for (auto& instruction : computation_.instructions()) {
- // TODO(b/34466113): Replace this and above with successors() or
- // predecessors() when these methods are added to HloInstruction.
- for (const HloInstruction* user : instruction->users()) {
- unscheduled_pred_count[user]++;
- }
- for (const HloInstruction* succ : instruction->control_successors()) {
- unscheduled_pred_count[succ]++;
- }
- }
- for (auto& instruction : computation_.instructions()) {
- // Instruction with no operands or control predecessors will
- // not be in the map.
- if (unscheduled_pred_count.count(instruction.get()) == 0) {
- ready_list.push_back(instruction.get());
- }
- }
-
- while (!ready_list.empty()) {
- // Select the highest priority HLO instruction from the ready list.
- auto best_it = ready_list.begin();
- Priority best_priority = GetPriority(*best_it);
- for (auto ready_it = std::next(ready_list.begin());
- ready_it != ready_list.end(); ++ready_it) {
- Priority priority = GetPriority(*ready_it);
- if (priority > best_priority) {
- best_it = ready_it;
- best_priority = priority;
- }
- }
-
- // Remove the selected instruction from the ready list and add it to the
- // schedule.
- const HloInstruction* best = *best_it;
- ready_list.erase(best_it);
- schedule.push_back(best);
- scheduled_instructions_.insert(best);
-
- // Update the unscheduled uses of the logical buffers.
- for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
- CHECK_GT(unscheduled_use_count_.at(buffer), 0);
- --unscheduled_use_count_[buffer];
- }
-
- // Add new instructions to ready list.
- auto update_pred_count = [&unscheduled_pred_count,
- &ready_list](HloInstruction* inst) {
- int64 pred_count = --unscheduled_pred_count.at(inst);
- CHECK_GE(pred_count, 0);
- if (pred_count == 0) {
- ready_list.push_back(inst);
- }
- };
- // TODO(b/34466113): Replace this and above with successors() or
- // predecessors() when these methods are added to HloInstruction.
- for (HloInstruction* user : best->users()) {
- update_pred_count(user);
- }
- for (HloInstruction* succ : best->control_successors()) {
- update_pred_count(succ);
- }
- }
- CHECK_EQ(schedule.size(), computation_.instructions().size());
- CHECK_EQ(scheduled_instructions_.size(),
- computation_.instructions().size());
-
- return schedule;
- }
-
- const HloComputation& computation_;
- const TuplePointsToAnalysis& points_to_analysis_;
- const LogicalBuffer::SizeFunction& size_function_;
-
- // A map containing the LogicalBuffers that each instruction uses.
- std::unordered_map<const HloInstruction*,
- std::unordered_set<const LogicalBuffer*>>
- buffer_uses_;
-
- // A map containing the count of unscheduled HLOs which using a particular
- // LogicalBuffer.
- std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
-
- // Set of instructions which have been scheduled.
- std::unordered_set<const HloInstruction*> scheduled_instructions_;
-};
-
-int64 SumLogicalBufferSizes(const std::vector<const LogicalBuffer*>& buffers,
- const LogicalBuffer::SizeFunction& size_function) {
- int64 size = 0;
- for (const LogicalBuffer* buffer : buffers) {
- size += size_function(*buffer);
- }
- return size;
-}
-
-StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
- const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
- // This ordering is based on DFS post-order, with a heuristic to decide which
- // operand to visit first. The heuristic is based on 'extra_users', which is
- // simply users-1 for each instruction. By subtracting 1, we're saying that
- // instructions with no users or a single user don't count; instructions with
- // lots of fan-out will be visited earlier.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
- for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
- extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
- total_sizes[hlo] = SumLogicalBufferSizes(
- points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
- tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
- hlo->operands().begin(), hlo->operands().end());
- for (const HloInstruction* operand : unique_operands) {
- extra_users[hlo] += extra_users[operand];
- total_sizes[hlo] += total_sizes[operand];
- }
- }
- CHECK_EQ(extra_users.size(), computation.instructions().size());
- CHECK_EQ(total_sizes.size(), computation.instructions().size());
-
- // Construct a total order based on DFS post-order, visiting operands in
- // decreasing cumulative extra user order, and next by cumulative size, with a
- // tiebreaker by name for determinism.
- std::vector<const HloInstruction*> sequence;
- FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
- sequence.push_back(hlo);
- return Status::OK();
- });
- TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
- &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
- const HloInstruction* b) {
- if (extra_users[a] != extra_users[b]) {
- return extra_users[a] > extra_users[b];
- }
- if (total_sizes[a] != total_sizes[b]) {
- return total_sizes[a] > total_sizes[b];
- }
- return a->name() < b->name();
- }));
- CHECK_EQ(sequence.size(), computation.instructions().size());
- return sequence;
-}
-
-StatusOr<int64> MinimumMemoryForComputation(
- const HloComputation& computation,
- const std::vector<const HloInstruction*>& sequence,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
- TF_ASSIGN_OR_RETURN(
- HeapSimulator::Result result,
- HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
- sequence, points_to_analysis, size_function));
- return result.heap_size;
-}
-
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
- const HloComputation& computation,
- const TuplePointsToAnalysis& points_to_analysis,
- const LogicalBuffer::SizeFunction& size_function) {
- // We try both a list-scheduler based ordering and a DFS based ordering, and
- // choose whichever returns a lower min-memory, not accounting for
- // fragmentation.
- //
- // Note that this is just a heuristic. One obvious inaccuracy is that the
- // memory required for sub-computations might be different when considered
- // within the caller's context. But it's good enough for now.
- TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> list_sequence,
- ListScheduler::Run(computation, points_to_analysis, size_function));
- TF_ASSIGN_OR_RETURN(
- const int64 list_memory,
- MinimumMemoryForComputation(computation, list_sequence,
- points_to_analysis, size_function));
- VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes";
-
- TF_ASSIGN_OR_RETURN(
- std::vector<const HloInstruction*> dfs_sequence,
- RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
- TF_ASSIGN_OR_RETURN(
- const int64 dfs_memory,
- MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
- size_function));
- VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes";
-
- if (list_memory <= dfs_memory) {
- VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes";
- return list_sequence;
- } else {
- VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes";
- return dfs_sequence;
- }
-}
-
-} // namespace
-
-StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(
- const HloModule& module, const LogicalBuffer::SizeFunction& size_function) {
- SequentialHloOrdering::HloModuleSequence sequence;
- TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
- TuplePointsToAnalysis::Run(&module));
- for (const auto& computation : module.computations()) {
- TF_ASSIGN_OR_RETURN(sequence[computation.get()],
- CreateMemoryMinimizingSequence(
- *computation, *points_to_analysis, size_function));
- }
- return sequence;
-}
-
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
- const HloComputation& computation,
- const LogicalBuffer::SizeFunction& size_function) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
- TuplePointsToAnalysis::Run(computation.parent()));
- return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
- size_function);
-}
-
std::ostream& operator<<(
std::ostream& out,
const SequentialHloOrdering::HloModuleSequence& module_sequence) {
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b59e1ea5eb..ff84f887f7 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -24,12 +24,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
-#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -191,24 +187,6 @@ std::ostream& operator<<(
std::ostream& out,
const SequentialHloOrdering::HloModuleSequence& module_sequence);
-// Returns the minimum memory required to compute the given module sequence,
-// assuming no fragmentation.
-StatusOr<int64> MinimumMemoryForSequence(
- const SequentialHloOrdering::HloModuleSequence& module_sequence,
- const LogicalBuffer::SizeFunction& size_function);
-
-// Returns an HloModuleSequence which seeks to minimize the memory required for
-// the computation. size_function is the function returning the number of bytes
-// required for a LogicalBuffer.
-StatusOr<SequentialHloOrdering::HloModuleSequence>
-CreateMemoryMinimizingSequence(
- const HloModule& module, const LogicalBuffer::SizeFunction& size_function);
-
-// Overload of above that computes the sequence for a single computation.
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
- const HloComputation& computation,
- const LogicalBuffer::SizeFunction& size_function);
-
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 56e36bd705..a1e38803c4 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
@@ -217,67 +218,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
-class MinimumMemoryForSequenceTest : public HloTestBase {};
-
-TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
- auto module = CreateNewModule();
- const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
- const Shape tuple_shape =
- ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
-
- auto cond_builder = HloComputation::Builder("WhileCond");
- // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
- HloInstruction* cond_param = cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
- HloInstruction* cond_iter = cond_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
- HloInstruction* cond_data = cond_builder.AddInstruction(
- HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
- // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
- HloInstruction* cond_lt = cond_builder.AddInstruction(
- HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
- HloOpcode::kLt, cond_iter, cond_data));
- HloComputation* cond_computation =
- module->AddEmbeddedComputation(cond_builder.Build());
-
- auto body_builder = HloComputation::Builder("WhileBody");
- // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
- HloInstruction* body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
- HloComputation* body_computation =
- module->AddEmbeddedComputation(body_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- // Entry params: 8 bytes (4 bytes per param), TOTAL=8
- HloInstruction* iter = builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
- HloInstruction* data = builder.AddInstruction(
- HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
- // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
- HloInstruction* tuple =
- builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
- // While: 8 bytes (4 bytes per element), TOTAL=32
- // Both cond and body use a max of 24 bytes, TOTAL=56
- HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
- tuple_shape, cond_computation, body_computation, tuple));
- HloComputation* entry_computation =
- module->AddEntryComputation(builder.Build());
-
- auto size_fn = [](const LogicalBuffer& buffer) {
- return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
- };
-
- SequentialHloOrdering::HloModuleSequence module_sequence;
- module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
- cond_lt};
- module_sequence[body_computation] = {body_param};
- module_sequence[entry_computation] = {iter, data, tuple, while_op};
- EXPECT_EQ(56,
- MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
-}
-
} // namespace
-
} // namespace xla
int main(int argc, char** argv) {
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index fb6d8674b6..d19e8034ac 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/status_macros.h"
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc
new file mode 100644
index 0000000000..f8e05448da
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc
@@ -0,0 +1,388 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+
+#include <utility>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/heap_simulator.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+StatusOr<int64> MinimumMemoryForSequence(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function) {
+ if (module_sequence.empty()) {
+ return 0;
+ }
+
+ const HloModule* module = module_sequence.begin()->first->parent();
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+
+ // The absolute minimum memory required for a given sequence of instructions
+ // is determined by the sequence of Alloc and Free calls on a simulated heap,
+ // ignoring fragmentation. We run the heap simulation on the whole module,
+ // rather than summing each computation, since it gives us a better lower
+ // bound, by minimizing the liveness of sub-computations.
+ TF_ASSIGN_OR_RETURN(
+ HeapSimulator::Result result,
+ HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module,
+ module_sequence, *points_to_analysis, size_function));
+ return result.heap_size;
+}
+
+namespace {
+
+// Class implementing a list scheduler of HLO instructions which produces a
+// sequence which minimizes memory usage.
+class ListScheduler {
+ public:
+ // Construct and return a memory-minimizing sequence of HLO instructions
+ // containing the given HLO computation.
+ static StatusOr<std::vector<const HloInstruction*>> Run(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ ListScheduler scheduler(computation, points_to_analysis, size_function);
+ return scheduler.CreateSchedule();
+ }
+
+ private:
+ // The scheduling priority of an instruction is first the number of bytes
+ // freed by scheduling the instruction, and second (tie-breaker) by the number
+ // of users. This is represented as a std::pair containing these two values
+ // (first element is the bytes freed). std::pair provides the necessary
+ // comparison operators.
+ using Priority = std::pair<int64, int64>;
+
+ ListScheduler(const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function)
+ : computation_(computation),
+ points_to_analysis_(points_to_analysis),
+ size_function_(size_function) {
+ // Create a map containing the LogicalBuffer uses for each HLO
+ // instruction. An HLO instruction "uses" a LogicalBuffer if the
+ // LogicalBuffer is in an operand of the instruction as indicated by
+ // points-to analysis.
+ for (auto& instruction : computation.instructions()) {
+ buffer_uses_.insert(
+ {instruction.get(), std::unordered_set<const LogicalBuffer*>()});
+ for (auto* operand : instruction->operands()) {
+ for (const LogicalBuffer* buffer :
+ points_to_analysis.GetBuffersDefinedByInstruction(operand)) {
+ buffer_uses_[instruction.get()].insert(buffer);
+ }
+ }
+ }
+
+ // Create map containing the number of unscheduled uses (hlo instructions)
+ // of each logical buffer.
+ for (auto& instruction : computation.instructions()) {
+ for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction(
+ instruction.get())) {
+ unscheduled_use_count_[buffer] = 0;
+ }
+ }
+ for (auto& instruction : computation.instructions()) {
+ for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) {
+ ++unscheduled_use_count_[buffer];
+ }
+ }
+
+ // Buffers live out of the computation have an implicit use at the end of
+ // the computation.
+ for (const LogicalBuffer* live_out_buffer :
+ points_to_analysis.GetPointsToSet(computation.root_instruction())
+ .CreateFlattenedSet()) {
+ ++unscheduled_use_count_[live_out_buffer];
+ }
+ }
+
+ // Returns whether the memory used by the given buffer should be ignored by
+ // the scheduling heuristic.
+ bool IgnoreBuffer(const LogicalBuffer& buffer) {
+ return buffer.instruction()->opcode() == HloOpcode::kParameter ||
+ buffer.instruction()->opcode() == HloOpcode::kConstant;
+ }
+
+ // Return the number of bytes freed if the HLO instruction is scheduled.
+ int64 BytesFreedIfScheduled(const HloInstruction* instruction) {
+ int64 freed_bytes = 0;
+ // Sum the total size of the values last used by this instruction.
+ for (auto* buffer : buffer_uses_.at(instruction)) {
+ if (IgnoreBuffer(*buffer)) {
+ continue;
+ }
+ CHECK_GE(unscheduled_use_count_.at(buffer), 1);
+ if (unscheduled_use_count_.at(buffer) == 1) {
+ // This is the last use of the logical buffer.
+ freed_bytes += size_function_(*buffer);
+ }
+ }
+ // Then subtract the size of the value(s) defined by this instruction.
+ for (auto* buffer :
+ points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) {
+ if (!IgnoreBuffer(*buffer)) {
+ freed_bytes -= size_function_(*buffer);
+ }
+ }
+ return freed_bytes;
+ }
+
+ // Construct the scheduling priority of the given instruction.
+ Priority GetPriority(const HloInstruction* instruction) {
+ return {BytesFreedIfScheduled(instruction), instruction->user_count()};
+ }
+
+ std::vector<const HloInstruction*> CreateSchedule() {
+ std::vector<const HloInstruction*> schedule;
+
+ // Populate the ready list with instructions which have no operands or
+ // control predecessors.
+ std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count;
+ std::list<const HloInstruction*> ready_list;
+ for (auto& instruction : computation_.instructions()) {
+ // TODO(b/34466113): Replace this and above with successors() or
+ // predecessors() when these methods are added to HloInstruction.
+ for (const HloInstruction* user : instruction->users()) {
+ unscheduled_pred_count[user]++;
+ }
+ for (const HloInstruction* succ : instruction->control_successors()) {
+ unscheduled_pred_count[succ]++;
+ }
+ }
+ for (auto& instruction : computation_.instructions()) {
+ // Instruction with no operands or control predecessors will
+ // not be in the map.
+ if (unscheduled_pred_count.count(instruction.get()) == 0) {
+ ready_list.push_back(instruction.get());
+ }
+ }
+
+ while (!ready_list.empty()) {
+ // Select the highest priority HLO instruction from the ready list.
+ auto best_it = ready_list.begin();
+ Priority best_priority = GetPriority(*best_it);
+ for (auto ready_it = std::next(ready_list.begin());
+ ready_it != ready_list.end(); ++ready_it) {
+ Priority priority = GetPriority(*ready_it);
+ if (priority > best_priority) {
+ best_it = ready_it;
+ best_priority = priority;
+ }
+ }
+
+ // Remove the selected instruction from the ready list and add it to the
+ // schedule.
+ const HloInstruction* best = *best_it;
+ ready_list.erase(best_it);
+ schedule.push_back(best);
+ scheduled_instructions_.insert(best);
+
+ // Update the unscheduled uses of the logical buffers.
+ for (const LogicalBuffer* buffer : buffer_uses_.at(best)) {
+ CHECK_GT(unscheduled_use_count_.at(buffer), 0);
+ --unscheduled_use_count_[buffer];
+ }
+
+ // Add new instructions to ready list.
+ auto update_pred_count = [&unscheduled_pred_count,
+ &ready_list](HloInstruction* inst) {
+ int64 pred_count = --unscheduled_pred_count.at(inst);
+ CHECK_GE(pred_count, 0);
+ if (pred_count == 0) {
+ ready_list.push_back(inst);
+ }
+ };
+ // TODO(b/34466113): Replace this and above with successors() or
+ // predecessors() when these methods are added to HloInstruction.
+ for (HloInstruction* user : best->users()) {
+ update_pred_count(user);
+ }
+ for (HloInstruction* succ : best->control_successors()) {
+ update_pred_count(succ);
+ }
+ }
+ CHECK_EQ(schedule.size(), computation_.instructions().size());
+ CHECK_EQ(scheduled_instructions_.size(),
+ computation_.instructions().size());
+
+ return schedule;
+ }
+
+ const HloComputation& computation_;
+ const TuplePointsToAnalysis& points_to_analysis_;
+ const LogicalBuffer::SizeFunction& size_function_;
+
+ // A map containing the LogicalBuffers that each instruction uses.
+ std::unordered_map<const HloInstruction*,
+ std::unordered_set<const LogicalBuffer*>>
+ buffer_uses_;
+
+ // A map containing the count of unscheduled HLOs which using a particular
+ // LogicalBuffer.
+ std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
+
+ // Set of instructions which have been scheduled.
+ std::unordered_set<const HloInstruction*> scheduled_instructions_;
+};
+
+int64 SumLogicalBufferSizes(const std::vector<const LogicalBuffer*>& buffers,
+ const LogicalBuffer::SizeFunction& size_function) {
+ int64 size = 0;
+ for (const LogicalBuffer* buffer : buffers) {
+ size += size_function(*buffer);
+ }
+ return size;
+}
+
+StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ // This ordering is based on DFS post-order, with a heuristic to decide which
+ // operand to visit first. The heuristic is based on 'extra_users', which is
+ // simply users-1 for each instruction. By subtracting 1, we're saying that
+ // instructions with no users or a single user don't count; instructions with
+ // lots of fan-out will be visited earlier.
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
+ tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
+ for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
+ extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1;
+ total_sizes[hlo] = SumLogicalBufferSizes(
+ points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
+ tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
+ hlo->operands().begin(), hlo->operands().end());
+ for (const HloInstruction* operand : unique_operands) {
+ extra_users[hlo] += extra_users[operand];
+ total_sizes[hlo] += total_sizes[operand];
+ }
+ }
+ CHECK_EQ(extra_users.size(), computation.instructions().size());
+ CHECK_EQ(total_sizes.size(), computation.instructions().size());
+
+ // Construct a total order based on DFS post-order, visiting operands in
+ // decreasing cumulative extra user order, and next by cumulative size, with a
+ // tiebreaker by name for determinism.
+ std::vector<const HloInstruction*> sequence;
+ FunctionVisitor visitor([&sequence](HloInstruction* hlo) {
+ sequence.push_back(hlo);
+ return Status::OK();
+ });
+ TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder(
+ &visitor, [&extra_users, &total_sizes](const HloInstruction* a,
+ const HloInstruction* b) {
+ if (extra_users[a] != extra_users[b]) {
+ return extra_users[a] > extra_users[b];
+ }
+ if (total_sizes[a] != total_sizes[b]) {
+ return total_sizes[a] > total_sizes[b];
+ }
+ return a->name() < b->name();
+ }));
+ CHECK_EQ(sequence.size(), computation.instructions().size());
+ return sequence;
+}
+
+StatusOr<int64> MinimumMemoryForComputation(
+ const HloComputation& computation,
+ const std::vector<const HloInstruction*>& sequence,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ TF_ASSIGN_OR_RETURN(
+ HeapSimulator::Result result,
+ HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation,
+ sequence, points_to_analysis, size_function));
+ return result.heap_size;
+}
+
+StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+ const HloComputation& computation,
+ const TuplePointsToAnalysis& points_to_analysis,
+ const LogicalBuffer::SizeFunction& size_function) {
+ // We try both a list-scheduler based ordering and a DFS based ordering, and
+ // choose whichever returns a lower min-memory, not accounting for
+ // fragmentation.
+ //
+ // Note that this is just a heuristic. One obvious inaccuracy is that the
+ // memory required for sub-computations might be different when considered
+ // within the caller's context. But it's good enough for now.
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const HloInstruction*> list_sequence,
+ ListScheduler::Run(computation, points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(
+ const int64 list_memory,
+ MinimumMemoryForComputation(computation, list_sequence,
+ points_to_analysis, size_function));
+ VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes";
+
+ TF_ASSIGN_OR_RETURN(
+ std::vector<const HloInstruction*> dfs_sequence,
+ RunDFSMemoryScheduler(computation, points_to_analysis, size_function));
+ TF_ASSIGN_OR_RETURN(
+ const int64 dfs_memory,
+ MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis,
+ size_function));
+ VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes";
+
+ if (list_memory <= dfs_memory) {
+ VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes";
+ return list_sequence;
+ } else {
+ VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes";
+ return dfs_sequence;
+ }
+}
+
+} // namespace
+
+StatusOr<SequentialHloOrdering::HloModuleSequence>
+CreateMemoryMinimizingSequence(
+ const HloModule& module, const LogicalBuffer::SizeFunction& size_function) {
+ SequentialHloOrdering::HloModuleSequence sequence;
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(&module));
+ for (const auto& computation : module.computations()) {
+ TF_ASSIGN_OR_RETURN(sequence[computation.get()],
+ CreateMemoryMinimizingSequence(
+ *computation, *points_to_analysis, size_function));
+ }
+ return sequence;
+}
+
+StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+ const HloComputation& computation,
+ const LogicalBuffer::SizeFunction& size_function) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
+ TuplePointsToAnalysis::Run(computation.parent()));
+ return CreateMemoryMinimizingSequence(computation, *points_to_analysis,
+ size_function);
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h
new file mode 100644
index 0000000000..ec92a56b96
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_scheduling.h
@@ -0,0 +1,50 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
+
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/service/logical_buffer.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+
+// Returns the minimum memory required to compute the given module sequence,
+// assuming no fragmentation.
+StatusOr<int64> MinimumMemoryForSequence(
+ const SequentialHloOrdering::HloModuleSequence& module_sequence,
+ const LogicalBuffer::SizeFunction& size_function);
+
+// Returns an HloModuleSequence which seeks to minimize the memory required for
+// the computation. size_function is the function returning the number of bytes
+// required for a LogicalBuffer.
+StatusOr<SequentialHloOrdering::HloModuleSequence>
+CreateMemoryMinimizingSequence(
+ const HloModule& module, const LogicalBuffer::SizeFunction& size_function);
+
+// Overload of above that computes the sequence for a single computation.
+StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence(
+ const HloComputation& computation,
+ const LogicalBuffer::SizeFunction& size_function);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_
diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
new file mode 100644
index 0000000000..d09d22ee40
--- /dev/null
+++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
+
+#include <memory>
+#include <string>
+
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_ordering.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+namespace {
+
+class MinimumMemoryForSequenceTest : public HloTestBase {};
+
+TEST_F(MinimumMemoryForSequenceTest, MultiComputation) {
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ const Shape tuple_shape =
+ ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape});
+
+ auto cond_builder = HloComputation::Builder("WhileCond");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "cond_param"));
+ HloInstruction* cond_iter = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0));
+ HloInstruction* cond_data = cond_builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1));
+ // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte)
+ HloInstruction* cond_lt = cond_builder.AddInstruction(
+ HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}),
+ HloOpcode::kLt, cond_iter, cond_data));
+ HloComputation* cond_computation =
+ module->AddEmbeddedComputation(cond_builder.Build());
+
+ auto body_builder = HloComputation::Builder("WhileBody");
+ // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element)
+ HloInstruction* body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "body_param"));
+ HloComputation* body_computation =
+ module->AddEmbeddedComputation(body_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ // Entry params: 8 bytes (4 bytes per param), TOTAL=8
+ HloInstruction* iter = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param_iter"));
+ HloInstruction* data = builder.AddInstruction(
+ HloInstruction::CreateParameter(1, scalar_shape, "param_data"));
+ // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({iter, data}));
+ // While: 8 bytes (4 bytes per element), TOTAL=32
+ // Both cond and body use a max of 24 bytes, TOTAL=56
+ HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile(
+ tuple_shape, cond_computation, body_computation, tuple));
+ HloComputation* entry_computation =
+ module->AddEntryComputation(builder.Build());
+
+ auto size_fn = [](const LogicalBuffer& buffer) {
+ return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8);
+ };
+
+ SequentialHloOrdering::HloModuleSequence module_sequence;
+ module_sequence[cond_computation] = {cond_param, cond_iter, cond_data,
+ cond_lt};
+ module_sequence[body_computation] = {body_param};
+ module_sequence[entry_computation] = {iter, data, tuple, while_op};
+ EXPECT_EQ(56,
+ MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie());
+}
+
+} // namespace
+} // namespace xla
+
+int main(int argc, char** argv) {
+ return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv);
+}