diff options
author | 2017-06-21 13:06:24 -0700 | |
---|---|---|
committer | 2017-06-21 13:10:23 -0700 | |
commit | 9e8005d7771e3f98b0a2ce74e4b0bc3765410a27 (patch) | |
tree | c8e24aa23cbc8aa644047fab7fbbcce96a3a3d82 | |
parent | a24366fa00e5ac0b70c8871d459f5569459329d5 (diff) |
[XLA:HLO] Move sequence functions from hlo_ordering.h to hlo_scheduling.h.
This is required for upcoming changes to convert the sequence creation functions
(and HeapSimulator and BufferAssignment) over to using the new
Hlo{Dataflow,Alias}Analysis.
It's required because otherwise there would be a dependency cycle:
  * Hlo{Dataflow,Alias}Analysis depends on HloOrdering
  * CreateMemoryMinimizingSequence will depend on Hlo{Dataflow,Alias}Analysis
This already forms a cycle if both HloOrdering and
CreateMemoryMinimizingSequence are in the same file. Also note that:
  * MinimumMemoryForSequence depends on HeapSimulator
  * HeapSimulator will depend on Hlo{Dataflow,Alias}Analysis
  * Hlo{Dataflow,Alias}Analysis depends on HloOrdering
Splitting out the sequence functions resolves the cycle.
Refactoring only; no functional changes.
PiperOrigin-RevId: 159731836
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 81 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment_test.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/hlo_schedule.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/hlo_schedule.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.cc | 355 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.h | 22 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering_test.cc | 61 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_rematerialization.cc | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_scheduling.cc | 388 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_scheduling.h | 50 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_scheduling_test.cc | 97 |
15 files changed, 611 insertions, 454 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 778c740b1d..150cd8a678 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -712,9 +712,11 @@ cc_library( ], deps = [ ":buffer_liveness", + ":heap_simulator", ":hlo", ":hlo_ordering", ":hlo_proto", + ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -741,6 +743,7 @@ cc_test( ":flatten_call_graph", ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -753,13 +756,67 @@ cc_test( ], ) +cc_library( + name = "hlo_ordering", + srcs = ["hlo_ordering.cc"], + hdrs = ["hlo_ordering.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_proto", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_ordering_test", + size = "small", + srcs = ["hlo_ordering_test.cc"], + deps = [ + ":hlo", + ":hlo_ordering", + ":hlo_scheduling", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + +cc_library( + name = "heap_simulator", + srcs = ["heap_simulator.cc"], + hdrs = ["heap_simulator.h"], + deps = [ + ":hlo", + ":hlo_ordering", + ":hlo_proto", + ":liveness_util", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + cc_test( name = "heap_simulator_test", size = "small", srcs = ["heap_simulator_test.cc"], deps = [ + 
":heap_simulator", ":hlo", ":hlo_ordering", + ":hlo_scheduling", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:literal_util", @@ -770,23 +827,15 @@ cc_test( ], ) -# The hlo_ordering library contains both hlo_ordering and heap_simulator because -# they are mutually dependent. cc_library( - name = "hlo_ordering", - srcs = [ - "heap_simulator.cc", - "hlo_ordering.cc", - ], - hdrs = [ - "heap_simulator.h", - "hlo_ordering.h", - ], + name = "hlo_scheduling", + srcs = ["hlo_scheduling.cc"], + hdrs = ["hlo_scheduling.h"], deps = [ - ":call_graph", + ":heap_simulator", ":hlo", + ":hlo_ordering", ":hlo_proto", - ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -799,12 +848,13 @@ cc_library( ) cc_test( - name = "hlo_ordering_test", + name = "hlo_scheduling_test", size = "small", - srcs = ["hlo_ordering_test.cc"], + srcs = ["hlo_scheduling_test.cc"], deps = [ ":hlo", ":hlo_ordering", + ":hlo_scheduling", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", @@ -1426,6 +1476,7 @@ cc_library( ":hlo", ":hlo_dce", ":hlo_ordering", + ":hlo_scheduling", ":liveness_util", ":logical_buffer", ":tuple_points_to_analysis", diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 44b4f4e3d8..3ba010ac43 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 10021b2513..c498b86dd4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index de6660e3b5..68cd545695 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -68,6 +68,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", "//tensorflow/compiler/xla/service:hlo_proto_util", + "//tensorflow/compiler/xla/service:hlo_scheduling", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:inliner", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 4786e75fa7..0905855ec2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ 
-69,6 +69,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/inliner.h" diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 52b4a13296..1e15ce32ee 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -498,8 +498,9 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_ordering", + "//tensorflow/compiler/xla/service:hlo_scheduling", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index d16a1d4ee5..f76f8ca668 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h index 773973010a..1ce7a48ac8 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.h @@ -19,9 +19,9 @@ limitations under the License. 
#include <memory> #include <vector> -#include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/statusor.h" namespace xla { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 61e5efa5b6..32a2abed92 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -15,13 +15,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_ordering.h" -#include <set> #include <utility> #include <vector> -#include "tensorflow/compiler/xla/service/heap_simulator.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -252,358 +249,6 @@ string SequentialHloOrdering::ToString() const { return tensorflow::str_util::Join(pieces, "\n"); } -StatusOr<int64> MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function) { - if (module_sequence.empty()) { - return 0; - } - - const HloModule* module = module_sequence.begin()->first->parent(); - TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, - TuplePointsToAnalysis::Run(module)); - - // The absolute minimum memory required for a given sequence of instructions - // is determined by the sequence of Alloc and Free calls on a simulated heap, - // ignoring fragmentation. We run the heap simulation on the whole module, - // rather than summing each computation, since it gives us a better lower - // bound, by minimizing the liveness of sub-computations. 
- TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module, - module_sequence, *points_to_analysis, size_function)); - return result.heap_size; -} - -namespace { - -// Class implementing a list scheduler of HLO instructions which produces a -// sequence which minimizes memory usage. -class ListScheduler { - public: - // Construct and return a memory-minimizing sequence of HLO instructions - // containing the given HLO computation. - static StatusOr<std::vector<const HloInstruction*>> Run( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - ListScheduler scheduler(computation, points_to_analysis, size_function); - return scheduler.CreateSchedule(); - } - - private: - // The scheduling priority of an instruction is first the number of bytes - // freed by scheduling the instruction, and second (tie-breaker) by the number - // of users. This is represented as a std::pair containing these two values - // (first element is the bytes freed). std::pair provides the necessary - // comparison operators. - using Priority = std::pair<int64, int64>; - - ListScheduler(const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) - : computation_(computation), - points_to_analysis_(points_to_analysis), - size_function_(size_function) { - // Create a map containing the LogicalBuffer uses for each HLO - // instruction. An HLO instruction "uses" a LogicalBuffer if the - // LogicalBuffer is in an operand of the instruction as indicated by - // points-to analysis. 
- for (auto& instruction : computation.instructions()) { - buffer_uses_.insert( - {instruction.get(), std::unordered_set<const LogicalBuffer*>()}); - for (auto* operand : instruction->operands()) { - for (const LogicalBuffer* buffer : - points_to_analysis.GetBuffersDefinedByInstruction(operand)) { - buffer_uses_[instruction.get()].insert(buffer); - } - } - } - - // Create map containing the number of unscheduled uses (hlo instructions) - // of each logical buffer. - for (auto& instruction : computation.instructions()) { - for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( - instruction.get())) { - unscheduled_use_count_[buffer] = 0; - } - } - for (auto& instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { - ++unscheduled_use_count_[buffer]; - } - } - - // Buffers live out of the computation have an implicit use at the end of - // the computation. - for (const LogicalBuffer* live_out_buffer : - points_to_analysis.GetPointsToSet(computation.root_instruction()) - .CreateFlattenedSet()) { - ++unscheduled_use_count_[live_out_buffer]; - } - } - - // Returns whether the memory used by the given buffer should be ignored by - // the scheduling heuristic. - bool IgnoreBuffer(const LogicalBuffer& buffer) { - return buffer.instruction()->opcode() == HloOpcode::kParameter || - buffer.instruction()->opcode() == HloOpcode::kConstant; - } - - // Return the number of bytes freed if the HLO instruction is scheduled. - int64 BytesFreedIfScheduled(const HloInstruction* instruction) { - int64 freed_bytes = 0; - // Sum the total size of the values last used by this instruction. - for (auto* buffer : buffer_uses_.at(instruction)) { - if (IgnoreBuffer(*buffer)) { - continue; - } - CHECK_GE(unscheduled_use_count_.at(buffer), 1); - if (unscheduled_use_count_.at(buffer) == 1) { - // This is the last use of the logical buffer. 
- freed_bytes += size_function_(*buffer); - } - } - // Then subtract the size of the value(s) defined by this instruction. - for (auto* buffer : - points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { - if (!IgnoreBuffer(*buffer)) { - freed_bytes -= size_function_(*buffer); - } - } - return freed_bytes; - } - - // Construct the scheduling priority of the given instruction. - Priority GetPriority(const HloInstruction* instruction) { - return {BytesFreedIfScheduled(instruction), instruction->user_count()}; - } - - std::vector<const HloInstruction*> CreateSchedule() { - std::vector<const HloInstruction*> schedule; - - // Populate the ready list with instructions which have no operands or - // control predecessors. - std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count; - std::list<const HloInstruction*> ready_list; - for (auto& instruction : computation_.instructions()) { - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (const HloInstruction* user : instruction->users()) { - unscheduled_pred_count[user]++; - } - for (const HloInstruction* succ : instruction->control_successors()) { - unscheduled_pred_count[succ]++; - } - } - for (auto& instruction : computation_.instructions()) { - // Instruction with no operands or control predecessors will - // not be in the map. - if (unscheduled_pred_count.count(instruction.get()) == 0) { - ready_list.push_back(instruction.get()); - } - } - - while (!ready_list.empty()) { - // Select the highest priority HLO instruction from the ready list. 
- auto best_it = ready_list.begin(); - Priority best_priority = GetPriority(*best_it); - for (auto ready_it = std::next(ready_list.begin()); - ready_it != ready_list.end(); ++ready_it) { - Priority priority = GetPriority(*ready_it); - if (priority > best_priority) { - best_it = ready_it; - best_priority = priority; - } - } - - // Remove the selected instruction from the ready list and add it to the - // schedule. - const HloInstruction* best = *best_it; - ready_list.erase(best_it); - schedule.push_back(best); - scheduled_instructions_.insert(best); - - // Update the unscheduled uses of the logical buffers. - for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; - } - - // Add new instructions to ready list. - auto update_pred_count = [&unscheduled_pred_count, - &ready_list](HloInstruction* inst) { - int64 pred_count = --unscheduled_pred_count.at(inst); - CHECK_GE(pred_count, 0); - if (pred_count == 0) { - ready_list.push_back(inst); - } - }; - // TODO(b/34466113): Replace this and above with successors() or - // predecessors() when these methods are added to HloInstruction. - for (HloInstruction* user : best->users()) { - update_pred_count(user); - } - for (HloInstruction* succ : best->control_successors()) { - update_pred_count(succ); - } - } - CHECK_EQ(schedule.size(), computation_.instructions().size()); - CHECK_EQ(scheduled_instructions_.size(), - computation_.instructions().size()); - - return schedule; - } - - const HloComputation& computation_; - const TuplePointsToAnalysis& points_to_analysis_; - const LogicalBuffer::SizeFunction& size_function_; - - // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map<const HloInstruction*, - std::unordered_set<const LogicalBuffer*>> - buffer_uses_; - - // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. 
- std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; - - // Set of instructions which have been scheduled. - std::unordered_set<const HloInstruction*> scheduled_instructions_; -}; - -int64 SumLogicalBufferSizes(const std::vector<const LogicalBuffer*>& buffers, - const LogicalBuffer::SizeFunction& size_function) { - int64 size = 0; - for (const LogicalBuffer* buffer : buffers) { - size += size_function(*buffer); - } - return size; -} - -StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - // This ordering is based on DFS post-order, with a heuristic to decide which - // operand to visit first. The heuristic is based on 'extra_users', which is - // simply users-1 for each instruction. By subtracting 1, we're saying that - // instructions with no users or a single user don't count; instructions with - // lots of fan-out will be visited earlier. - tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users; - tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes; - for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { - extra_users[hlo] = hlo->users().empty() ? 
0 : hlo->users().size() - 1; - total_sizes[hlo] = SumLogicalBufferSizes( - points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); - tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands( - hlo->operands().begin(), hlo->operands().end()); - for (const HloInstruction* operand : unique_operands) { - extra_users[hlo] += extra_users[operand]; - total_sizes[hlo] += total_sizes[operand]; - } - } - CHECK_EQ(extra_users.size(), computation.instructions().size()); - CHECK_EQ(total_sizes.size(), computation.instructions().size()); - - // Construct a total order based on DFS post-order, visiting operands in - // decreasing cumulative extra user order, and next by cumulative size, with a - // tiebreaker by name for determinism. - std::vector<const HloInstruction*> sequence; - FunctionVisitor visitor([&sequence](HloInstruction* hlo) { - sequence.push_back(hlo); - return Status::OK(); - }); - TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( - &visitor, [&extra_users, &total_sizes](const HloInstruction* a, - const HloInstruction* b) { - if (extra_users[a] != extra_users[b]) { - return extra_users[a] > extra_users[b]; - } - if (total_sizes[a] != total_sizes[b]) { - return total_sizes[a] > total_sizes[b]; - } - return a->name() < b->name(); - })); - CHECK_EQ(sequence.size(), computation.instructions().size()); - return sequence; -} - -StatusOr<int64> MinimumMemoryForComputation( - const HloComputation& computation, - const std::vector<const HloInstruction*>& sequence, - const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation, - sequence, points_to_analysis, size_function)); - return result.heap_size; -} - -StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const TuplePointsToAnalysis& points_to_analysis, - const 
LogicalBuffer::SizeFunction& size_function) { - // We try both a list-scheduler based ordering and a DFS based ordering, and - // choose whichever returns a lower min-memory, not accounting for - // fragmentation. - // - // Note that this is just a heuristic. One obvious inaccuracy is that the - // memory required for sub-computations might be different when considered - // within the caller's context. But it's good enough for now. - TF_ASSIGN_OR_RETURN( - std::vector<const HloInstruction*> list_sequence, - ListScheduler::Run(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 list_memory, - MinimumMemoryForComputation(computation, list_sequence, - points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; - - TF_ASSIGN_OR_RETURN( - std::vector<const HloInstruction*> dfs_sequence, - RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); - TF_ASSIGN_OR_RETURN( - const int64 dfs_memory, - MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, - size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; - - if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; - return list_sequence; - } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; - return dfs_sequence; - } -} - -} // namespace - -StatusOr<SequentialHloOrdering::HloModuleSequence> -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { - SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, - TuplePointsToAnalysis::Run(&module)); - for (const auto& computation : module.computations()) { - TF_ASSIGN_OR_RETURN(sequence[computation.get()], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); - } - return sequence; -} - 
-StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, - TuplePointsToAnalysis::Run(computation.parent())); - return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); -} - std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence) { diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h index b59e1ea5eb..ff84f887f7 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.h +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -24,12 +24,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" -#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/flatmap.h" -#include "tensorflow/core/lib/gtl/flatset.h" namespace xla { @@ -191,24 +187,6 @@ std::ostream& operator<<( std::ostream& out, const SequentialHloOrdering::HloModuleSequence& module_sequence); -// Returns the minimum memory required to compute the given module sequence, -// assuming no fragmentation. -StatusOr<int64> MinimumMemoryForSequence( - const SequentialHloOrdering::HloModuleSequence& module_sequence, - const LogicalBuffer::SizeFunction& size_function); - -// Returns an HloModuleSequence which seeks to minimize the memory required for -// the computation. size_function is the function returning the number of bytes -// required for a LogicalBuffer. 
-StatusOr<SequentialHloOrdering::HloModuleSequence> -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); - -// Overload of above that computes the sequence for a single computation. -StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( - const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 56e36bd705..a1e38803c4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/types.h" @@ -217,67 +218,7 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) { EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); } -class MinimumMemoryForSequenceTest : public HloTestBase {}; - -TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { - auto module = CreateNewModule(); - const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); - const Shape tuple_shape = - ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); - - auto cond_builder = HloComputation::Builder("WhileCond"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* cond_param = cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); - HloInstruction* cond_iter = cond_builder.AddInstruction( - 
HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); - HloInstruction* cond_data = cond_builder.AddInstruction( - HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); - // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) - HloInstruction* cond_lt = cond_builder.AddInstruction( - HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), - HloOpcode::kLt, cond_iter, cond_data)); - HloComputation* cond_computation = - module->AddEmbeddedComputation(cond_builder.Build()); - - auto body_builder = HloComputation::Builder("WhileBody"); - // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) - HloInstruction* body_param = body_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "body_param")); - HloComputation* body_computation = - module->AddEmbeddedComputation(body_builder.Build()); - - auto builder = HloComputation::Builder(TestName()); - // Entry params: 8 bytes (4 bytes per param), TOTAL=8 - HloInstruction* iter = builder.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); - HloInstruction* data = builder.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "param_data")); - // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 - HloInstruction* tuple = - builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); - // While: 8 bytes (4 bytes per element), TOTAL=32 - // Both cond and body use a max of 24 bytes, TOTAL=56 - HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, cond_computation, body_computation, tuple)); - HloComputation* entry_computation = - module->AddEntryComputation(builder.Build()); - - auto size_fn = [](const LogicalBuffer& buffer) { - return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); - }; - - SequentialHloOrdering::HloModuleSequence module_sequence; - module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, - cond_lt}; - 
module_sequence[body_computation] = {body_param}; - module_sequence[entry_computation] = {iter, data, tuple, while_op}; - EXPECT_EQ(56, - MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); -} - } // namespace - } // namespace xla int main(int argc, char** argv) { diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index fb6d8674b6..d19e8034ac 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/status_macros.h" diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc new file mode 100644 index 0000000000..f8e05448da --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -0,0 +1,388 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include <utility> +#include <vector> + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr<int64> MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function) { + if (module_sequence.empty()) { + return 0; + } + + const HloModule* module = module_sequence.begin()->first->parent(); + TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, + TuplePointsToAnalysis::Run(module)); + + // The absolute minimum memory required for a given sequence of instructions + // is determined by the sequence of Alloc and Free calls on a simulated heap, + // ignoring fragmentation. We run the heap simulation on the whole module, + // rather than summing each computation, since it gives us a better lower + // bound, by minimizing the liveness of sub-computations. + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), *module, + module_sequence, *points_to_analysis, size_function)); + return result.heap_size; +} + +namespace { + +// Class implementing a list scheduler of HLO instructions which produces a +// sequence which minimizes memory usage. 
+class ListScheduler { + public: + // Construct and return a memory-minimizing sequence of HLO instructions + // containing the given HLO computation. + static StatusOr<std::vector<const HloInstruction*>> Run( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); + return scheduler.CreateSchedule(); + } + + private: + // The scheduling priority of an instruction is first the number of bytes + // freed by scheduling the instruction, and second (tie-breaker) by the number + // of users. This is represented as a std::pair containing these two values + // (first element is the bytes freed). std::pair provides the necessary + // comparison operators. + using Priority = std::pair<int64, int64>; + + ListScheduler(const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) + : computation_(computation), + points_to_analysis_(points_to_analysis), + size_function_(size_function) { + // Create a map containing the LogicalBuffer uses for each HLO + // instruction. An HLO instruction "uses" a LogicalBuffer if the + // LogicalBuffer is in an operand of the instruction as indicated by + // points-to analysis. + for (auto& instruction : computation.instructions()) { + buffer_uses_.insert( + {instruction.get(), std::unordered_set<const LogicalBuffer*>()}); + for (auto* operand : instruction->operands()) { + for (const LogicalBuffer* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(operand)) { + buffer_uses_[instruction.get()].insert(buffer); + } + } + } + + // Create map containing the number of unscheduled uses (hlo instructions) + // of each logical buffer. 
+ for (auto& instruction : computation.instructions()) { + for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( + instruction.get())) { + unscheduled_use_count_[buffer] = 0; + } + } + for (auto& instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + ++unscheduled_use_count_[buffer]; + } + } + + // Buffers live out of the computation have an implicit use at the end of + // the computation. + for (const LogicalBuffer* live_out_buffer : + points_to_analysis.GetPointsToSet(computation.root_instruction()) + .CreateFlattenedSet()) { + ++unscheduled_use_count_[live_out_buffer]; + } + } + + // Returns whether the memory used by the given buffer should be ignored by + // the scheduling heuristic. + bool IgnoreBuffer(const LogicalBuffer& buffer) { + return buffer.instruction()->opcode() == HloOpcode::kParameter || + buffer.instruction()->opcode() == HloOpcode::kConstant; + } + + // Return the number of bytes freed if the HLO instruction is scheduled. + int64 BytesFreedIfScheduled(const HloInstruction* instruction) { + int64 freed_bytes = 0; + // Sum the total size of the values last used by this instruction. + for (auto* buffer : buffer_uses_.at(instruction)) { + if (IgnoreBuffer(*buffer)) { + continue; + } + CHECK_GE(unscheduled_use_count_.at(buffer), 1); + if (unscheduled_use_count_.at(buffer) == 1) { + // This is the last use of the logical buffer. + freed_bytes += size_function_(*buffer); + } + } + // Then subtract the size of the value(s) defined by this instruction. + for (auto* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (!IgnoreBuffer(*buffer)) { + freed_bytes -= size_function_(*buffer); + } + } + return freed_bytes; + } + + // Construct the scheduling priority of the given instruction. 
+ Priority GetPriority(const HloInstruction* instruction) { + return {BytesFreedIfScheduled(instruction), instruction->user_count()}; + } + + std::vector<const HloInstruction*> CreateSchedule() { + std::vector<const HloInstruction*> schedule; + + // Populate the ready list with instructions which have no operands or + // control predecessors. + std::unordered_map<const HloInstruction*, int64> unscheduled_pred_count; + std::list<const HloInstruction*> ready_list; + for (auto& instruction : computation_.instructions()) { + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (const HloInstruction* user : instruction->users()) { + unscheduled_pred_count[user]++; + } + for (const HloInstruction* succ : instruction->control_successors()) { + unscheduled_pred_count[succ]++; + } + } + for (auto& instruction : computation_.instructions()) { + // Instruction with no operands or control predecessors will + // not be in the map. + if (unscheduled_pred_count.count(instruction.get()) == 0) { + ready_list.push_back(instruction.get()); + } + } + + while (!ready_list.empty()) { + // Select the highest priority HLO instruction from the ready list. + auto best_it = ready_list.begin(); + Priority best_priority = GetPriority(*best_it); + for (auto ready_it = std::next(ready_list.begin()); + ready_it != ready_list.end(); ++ready_it) { + Priority priority = GetPriority(*ready_it); + if (priority > best_priority) { + best_it = ready_it; + best_priority = priority; + } + } + + // Remove the selected instruction from the ready list and add it to the + // schedule. + const HloInstruction* best = *best_it; + ready_list.erase(best_it); + schedule.push_back(best); + scheduled_instructions_.insert(best); + + // Update the unscheduled uses of the logical buffers. 
+ for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { + CHECK_GT(unscheduled_use_count_.at(buffer), 0); + --unscheduled_use_count_[buffer]; + } + + // Add new instructions to ready list. + auto update_pred_count = [&unscheduled_pred_count, + &ready_list](HloInstruction* inst) { + int64 pred_count = --unscheduled_pred_count.at(inst); + CHECK_GE(pred_count, 0); + if (pred_count == 0) { + ready_list.push_back(inst); + } + }; + // TODO(b/34466113): Replace this and above with successors() or + // predecessors() when these methods are added to HloInstruction. + for (HloInstruction* user : best->users()) { + update_pred_count(user); + } + for (HloInstruction* succ : best->control_successors()) { + update_pred_count(succ); + } + } + CHECK_EQ(schedule.size(), computation_.instructions().size()); + CHECK_EQ(scheduled_instructions_.size(), + computation_.instructions().size()); + + return schedule; + } + + const HloComputation& computation_; + const TuplePointsToAnalysis& points_to_analysis_; + const LogicalBuffer::SizeFunction& size_function_; + + // A map containing the LogicalBuffers that each instruction uses. + std::unordered_map<const HloInstruction*, + std::unordered_set<const LogicalBuffer*>> + buffer_uses_; + + // A map containing the count of unscheduled HLOs which use a particular + // LogicalBuffer. + std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; + + // Set of instructions which have been scheduled. 
+ std::unordered_set<const HloInstruction*> scheduled_instructions_; +}; + +int64 SumLogicalBufferSizes(const std::vector<const LogicalBuffer*>& buffers, + const LogicalBuffer::SizeFunction& size_function) { + int64 size = 0; + for (const LogicalBuffer* buffer : buffers) { + size += size_function(*buffer); + } + return size; +} + +StatusOr<std::vector<const HloInstruction*>> RunDFSMemoryScheduler( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // This ordering is based on DFS post-order, with a heuristic to decide which + // operand to visit first. The heuristic is based on 'extra_users', which is + // simply users-1 for each instruction. By subtracting 1, we're saying that + // instructions with no users or a single user don't count; instructions with + // lots of fan-out will be visited earlier. + tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users; + tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes; + for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) { + extra_users[hlo] = hlo->users().empty() ? 0 : hlo->users().size() - 1; + total_sizes[hlo] = SumLogicalBufferSizes( + points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function); + tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands( + hlo->operands().begin(), hlo->operands().end()); + for (const HloInstruction* operand : unique_operands) { + extra_users[hlo] += extra_users[operand]; + total_sizes[hlo] += total_sizes[operand]; + } + } + CHECK_EQ(extra_users.size(), computation.instructions().size()); + CHECK_EQ(total_sizes.size(), computation.instructions().size()); + + // Construct a total order based on DFS post-order, visiting operands in + // decreasing cumulative extra user order, and next by cumulative size, with a + // tiebreaker by name for determinism. 
+ std::vector<const HloInstruction*> sequence; + FunctionVisitor visitor([&sequence](HloInstruction* hlo) { + sequence.push_back(hlo); + return Status::OK(); + }); + TF_RETURN_IF_ERROR(computation.AcceptWithOperandOrder( + &visitor, [&extra_users, &total_sizes](const HloInstruction* a, + const HloInstruction* b) { + if (extra_users[a] != extra_users[b]) { + return extra_users[a] > extra_users[b]; + } + if (total_sizes[a] != total_sizes[b]) { + return total_sizes[a] > total_sizes[b]; + } + return a->name() < b->name(); + })); + CHECK_EQ(sequence.size(), computation.instructions().size()); + return sequence; +} + +StatusOr<int64> MinimumMemoryForComputation( + const HloComputation& computation, + const std::vector<const HloInstruction*>& sequence, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN( + HeapSimulator::Result result, + HeapSimulator::Run(MakeUnique<NoFragmentationStatsHeap>(), computation, + sequence, points_to_analysis, size_function)); + return result.heap_size; +} + +StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + // We try both a list-scheduler based ordering and a DFS based ordering, and + // choose whichever returns a lower min-memory, not accounting for + // fragmentation. + // + // Note that this is just a heuristic. One obvious inaccuracy is that the + // memory required for sub-computations might be different when considered + // within the caller's context. But it's good enough for now. 
+ TF_ASSIGN_OR_RETURN( + std::vector<const HloInstruction*> list_sequence, + ListScheduler::Run(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 list_memory, + MinimumMemoryForComputation(computation, list_sequence, + points_to_analysis, size_function)); + VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + + TF_ASSIGN_OR_RETURN( + std::vector<const HloInstruction*> dfs_sequence, + RunDFSMemoryScheduler(computation, points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + const int64 dfs_memory, + MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, + size_function)); + VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + + if (list_memory <= dfs_memory) { + VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + return list_sequence; + } else { + VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + return dfs_sequence; + } +} + +} // namespace + +StatusOr<SequentialHloOrdering::HloModuleSequence> +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { + SequentialHloOrdering::HloModuleSequence sequence; + TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, + TuplePointsToAnalysis::Run(&module)); + for (const auto& computation : module.computations()) { + TF_ASSIGN_OR_RETURN(sequence[computation.get()], + CreateMemoryMinimizingSequence( + *computation, *points_to_analysis, size_function)); + } + return sequence; +} + +StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, + TuplePointsToAnalysis::Run(computation.parent())); + return CreateMemoryMinimizingSequence(computation, *points_to_analysis, + size_function); +} + +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h new file mode 100644 index 0000000000..ec92a56b96 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -0,0 +1,50 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ + +#include <vector> + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" + +namespace xla { + +// Returns the minimum memory required to compute the given module sequence, +// assuming no fragmentation. +StatusOr<int64> MinimumMemoryForSequence( + const SequentialHloOrdering::HloModuleSequence& module_sequence, + const LogicalBuffer::SizeFunction& size_function); + +// Returns an HloModuleSequence which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. 
+StatusOr<SequentialHloOrdering::HloModuleSequence> +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + +// Overload of above that computes the sequence for a single computation. +StatusOr<std::vector<const HloInstruction*>> CreateMemoryMinimizingSequence( + const HloComputation& computation, + const LogicalBuffer::SizeFunction& size_function); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SCHEDULING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling_test.cc b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc new file mode 100644 index 0000000000..d09d22ee40 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_scheduling_test.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" + +#include <memory> +#include <string> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class MinimumMemoryForSequenceTest : public HloTestBase {}; + +TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({scalar_shape, scalar_shape}); + + auto cond_builder = HloComputation::Builder("WhileCond"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "cond_param")); + HloInstruction* cond_iter = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 0)); + HloInstruction* cond_data = cond_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape, cond_param, 1)); + // Free cond_param[] (16 bytes), Alloc PRED[] (1 byte) + HloInstruction* cond_lt = cond_builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), + HloOpcode::kLt, cond_iter, cond_data)); + HloComputation* cond_computation = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("WhileBody"); + // Tuple param: 24 bytes (each elem has 8 byte pointer, 4 byte element) + HloInstruction* body_param = body_builder.AddInstruction( 
+ HloInstruction::CreateParameter(0, tuple_shape, "body_param")); + HloComputation* body_computation = + module->AddEmbeddedComputation(body_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + // Entry params: 8 bytes (4 bytes per param), TOTAL=8 + HloInstruction* iter = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param_iter")); + HloInstruction* data = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param_data")); + // Tuple: 16 bytes (8 bytes per pointer), TOTAL=24 + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({iter, data})); + // While: 8 bytes (4 bytes per element), TOTAL=32 + // Both cond and body use a max of 24 bytes, TOTAL=56 + HloInstruction* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + tuple_shape, cond_computation, body_computation, tuple)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto size_fn = [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape(), /*pointer_size=*/8); + }; + + SequentialHloOrdering::HloModuleSequence module_sequence; + module_sequence[cond_computation] = {cond_param, cond_iter, cond_data, + cond_lt}; + module_sequence[body_computation] = {body_param}; + module_sequence[entry_computation] = {iter, data, tuple, while_op}; + EXPECT_EQ(56, + MinimumMemoryForSequence(module_sequence, size_fn).ValueOrDie()); +} + +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + return xla::ParseDebugOptionsFlagsAndRunTests(argc, argv); +} |