diff options
author | Mark Heffernan <meheff@google.com> | 2017-01-25 14:49:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-25 15:03:50 -0800 |
commit | 239493a6825f33c96d64b6a36be6616fbb41e42b (patch) | |
tree | 4d1dac37af8ebd79b79741a44dc144ae7b1afc72 | |
parent | 59757b7afd5dd08d5651ca966f03511bb2aad7bd (diff) |
Break out HloOrdering classes into separate files.
Add CreateMemoryMinimizingSequence which constructs a sequence of the
instructions in an HLO module that heuristically minimizes the
total size of live buffers containing HLO outputs.
Change: 145599747
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 37 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.cc | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_assignment.h | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_liveness.cc | 103 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/buffer_liveness.h | 126 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 39 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.cc | 24 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_computation.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 72 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 15 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.cc | 363 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.h | 172 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering_test.cc | 83 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/logical_buffer.h | 3 |
15 files changed, 754 insertions, 300 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4526431f59..41add6f7d2 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -445,6 +445,7 @@ cc_library( ], deps = [ ":hlo", + ":hlo_ordering", ":logical_buffer", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:shape_util", @@ -516,6 +517,42 @@ cc_test( ) cc_library( + name = "hlo_ordering", + srcs = [ + "hlo_ordering.cc", + ], + hdrs = [ + "hlo_ordering.h", + ], + deps = [ + ":hlo", + ":logical_buffer", + ":tuple_points_to_analysis", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "hlo_ordering_test", + srcs = ["hlo_ordering_test.cc"], + deps = [ + ":cpu_plugin", + ":hlo", + ":hlo_ordering", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) + +cc_library( name = "hlo_query", srcs = ["hlo_query.cc"], hdrs = ["hlo_query.h"], diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 33365ffa53..07ff323a3d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -316,7 +316,7 @@ tensorflow::Status GatherComputationsByAllocationType( /* static */ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run( const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, - BufferSizeFunction buffer_size, bool colocate_related_buffers, + LogicalBuffer::SizeFunction buffer_size, bool colocate_related_buffers, const std::vector<const HloInstruction*>* hlos_to_allocate) { BufferAssigner assigner(std::move(buffer_size), colocate_related_buffers); return assigner.CreateAssignment(module, std::move(hlo_ordering), @@ -779,7 +779,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment( root->shape(), [this, &output_size, root, &assignment]( const Shape& /*subshape*/, const ShapeIndex& index) { const auto& allocations = assignment->GetAllocations(root, index); - if (allocations.size() > 0) { + if (!allocations.empty()) { output_size += allocations.begin()->size(); } return tensorflow::Status::OK(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 46c2253e13..b484ea51b1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -286,10 +286,9 @@ class BufferAssigner { // will be colocated in the same allocation (i.e buffers for while result // will share an allocation with buffers related to that same while // instruction: init operand, condition/body parameter and body result). - using BufferSizeFunction = std::function<int64(const LogicalBuffer&)>; static StatusOr<std::unique_ptr<BufferAssignment>> Run( const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering, - BufferSizeFunction buffer_size, bool colocate_related_buffers, + LogicalBuffer::SizeFunction buffer_size, bool colocate_related_buffers, const std::vector<const HloInstruction*>* hlos_to_allocate = nullptr); // Overload of Run which uses ShapeUtil::ByteSizeOf to determine buffer size @@ -299,7 +298,7 @@ class BufferAssigner { int64 pointer_size); private: - explicit BufferAssigner(BufferSizeFunction buffer_size, + explicit BufferAssigner(LogicalBuffer::SizeFunction buffer_size, bool colocate_related_buffers) : buffer_size_(std::move(buffer_size)), colocate_related_buffers_(colocate_related_buffers) {} @@ -356,7 +355,7 @@ class BufferAssigner { const HloModule* module_; // Function which returns the buffer size for a given shape. - BufferSizeFunction buffer_size_; + LogicalBuffer::SizeFunction buffer_size_; // Indicates whether related buffers should share the same buffer allocation. const bool colocate_related_buffers_; diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index d4faca7cd8..c788c64306 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -35,109 +35,6 @@ limitations under the License. namespace xla { -PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) - : module_(module) {} - -bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { - return false; - } - // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. - return strict_predecessors_.at(b->parent())->IsReachable(b, a); -} - -string PredecessorHloOrdering::ToStringHelper(const string& name) const { - std::vector<string> pieces; - pieces.push_back(name); - for (auto& computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s:", - computation->name().c_str())); - const auto all = computation->MakeInstructionPostOrder(); - for (auto instruction : all) { - pieces.push_back(tensorflow::strings::Printf( - " %s strict predecessors:", instruction->name().c_str())); - for (auto predecessor : all) { - if (strict_predecessors_.at(computation.get()) - ->IsReachable(instruction, predecessor)) { - pieces.push_back( - tensorflow::strings::Printf(" %s", predecessor->name().c_str())); - } - } - } - } - return tensorflow::str_util::Join(pieces, "\n"); -} - -DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) - : PredecessorHloOrdering(module) { - // Compute predecessor relationships between all instructions to determine - // ordering based on dependencies. ExecutesBefore will return true iff there - // exists a path in the HLO computation graph from 'a' to 'b'. - for (auto& computation : module->computations()) { - strict_predecessors_.emplace(computation.get(), - computation->ComputeTransitiveOperands()); - } -} - -string DependencyHloOrdering::ToString() const { - return ToStringHelper("DependencyHloOrdering"); -} - -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : module_(module) { - // Create a map from instruction to its order position. - for (auto computation_order : module_sequence) { - const std::vector<const HloInstruction*>& order = computation_order.second; - for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); - } - } -} - -bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const { - // Instructions in different computations are unordered. - if (a->parent() != b->parent()) { - return false; - } - // If either instruction is not in the order, then 'a' and 'b' are unordered. - if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { - return false; - } - return order_position_.at(a) < order_position_.at(b); -} - -string SequentialHloOrdering::ToString() const { - std::vector<string> pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto& computation : module_->computations()) { - pieces.push_back(tensorflow::strings::Printf("computation %s order:", - computation->name().c_str())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector<const HloInstruction*> instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation.get()) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back( - tensorflow::strings::Printf(" %s", instruction->name().c_str())); - } - } - return tensorflow::str_util::Join(pieces, "\n"); -} - /* static */ StatusOr<std::unique_ptr<BufferLiveness>> BufferLiveness::Run( const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering) { diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h index 964f558c8c..b9e7a2a28d 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.h +++ b/tensorflow/compiler/xla/service/buffer_liveness.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -31,131 +32,6 @@ limitations under the License. namespace xla { -// Abstract base class for describing a partial ordering of HLO -// instructions. Used to determine live range overlap of HLO instruction output -// buffers. -class HloOrdering { - public: - HloOrdering() = default; - virtual ~HloOrdering() = default; - - // Returns true if instruction 'a' executes before instruction 'b'. This is - // not reflexive, that is, an instruction does not execute before itself. - virtual bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const = 0; - virtual string ToString() const = 0; -}; - -// Base class for partial orderings implemented by a map of strict predecessors -// for each instruction. Subclasses should fill in strict_predecessors_. -class PredecessorHloOrdering : public HloOrdering { - public: - ~PredecessorHloOrdering() override = default; - - // Returns true if instruction 'a' executes before instruction 'b'. - // Instructions in different computations are not ordered. - bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const override; - - protected: - explicit PredecessorHloOrdering(const HloModule* module); - string ToStringHelper(const string& name) const; - - const HloModule* module_; - - // For each each computation in the module, this is the set of the - // instruction's strict predecessors. An instruction is not an element of its - // own strict predecessor set. - // - // Subclasses should fill this in to define the desired ordering. - tensorflow::gtl::FlatMap<const HloComputation*, - std::unique_ptr<HloComputation::ReachabilityMap>> - strict_predecessors_; -}; - -// An HLO ordering based on data dependencies in the HLO graph. In this partial -// order, instruction A executes before instruction B only if there is a path -// from A to B in the HLO graph. For example, given the following graph: -// -// param -// / \ -// negate exp -// \ / -// add -// -// DependencyHloOrdering gives the following executes-before relations: -// param executes before negate, exp, and add -// negate executes before add -// exp executes before add -// add executes before nothing -// negate and exp are not ordered because the dependencies allow either to -// execute before the other (or in parallel). DependencyHloOrdering ordering -// allows maximum parallelism and enables any execution order which satisfies -// data dependencies. This requires pessimistic assumptions about buffer live -// ranges and can result in more memory used than more constrained orderings. -class DependencyHloOrdering : public PredecessorHloOrdering { - public: - explicit DependencyHloOrdering(const HloModule* module); - ~DependencyHloOrdering() override = default; - - string ToString() const override; -}; - -// An HLO ordering based on a total order of instructions in each computation. -// The computation total order is a sequencing of all of its instructions in -// the computation (eg, {inst0, inst1, inst2,...}) as in single-threaded -// execution. For example, given the following HLO graph: -// -// param -// / \ -// negate exp -// \ / -// add -// -// and the following sequence: -// -// {param, negate, exp, add} -// -// SequentialHloOrdering gives the following executes-before relations: -// param executes before negate, exp, and add -// negate executes before exp and add -// exp executes before add -// add executes before nothing -// This is more constrained than DependencyHloOrdering in this example because -// negate and exp are ordered (negate before exp). This enables param to share -// the same buffer as exp (param buffer is dead after exp). Generally, this -// ordering enables more buffer sharing (reduced memory usage) because buffer -// interference is reduced relative to DependencyHloOrdering. -class SequentialHloOrdering : public HloOrdering { - public: - // A sequence of instructions for each computation in the module. - using HloModuleSequence = - tensorflow::gtl::FlatMap<const HloComputation*, - std::vector<const HloInstruction*>>; - - SequentialHloOrdering(const HloModule* module, - const HloModuleSequence& module_sequence); - ~SequentialHloOrdering() override = default; - - // Instruction 'a' executes before 'b' if 'a' appears before 'b' in the - // instruction sequence for the computation. Instructions in different - // computations are unordered. - bool ExecutesBefore(const HloInstruction* a, - const HloInstruction* b) const override; - string ToString() const override; - - protected: - const HloModule* module_; - - // The position of every instruction in the HLO module in its respective - // computation sequence (a value of zero indicates the instruction is first in - // the sequence, etc). Instructions from all computations are contained in - // this map so more than one instruction may have the same position - // value. This is not a problem because ExecutesBefore also verifies - // instructions are in the same computation. - tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; -}; - // Class which computes liveness of the output buffers of HLOs and their // interference. class BufferLiveness { diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 8af54b11bb..ad78cb3a52 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -63,6 +63,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_subcomputation_unification", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 74df5429e6..4a3934cff7 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -62,6 +62,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" #include "tensorflow/compiler/xla/service/hlo_subcomputation_unification.h" @@ -270,27 +271,6 @@ llvm::CodeGenOpt::Level CodeGenOptLevel() { } } -// Constructs and returns a sequence for the HLO instructions in each -// computation in the given module. The sequence can be used to determine the -// order of HLO instruction emission and for buffer liveness analysis. -SequentialHloOrdering::HloModuleSequence CreateModuleSequence( - const HloModule* module) { - SequentialHloOrdering::HloModuleSequence sequence; - for (auto& computation : module->computations()) { - // Do a DFS traversal from the root to construct a sequence for each - // computation. - // TODO(b/32006145): Construct a sequence to minimize memory pressure. - std::vector<const HloInstruction*> order; - TF_CHECK_OK(computation->root_instruction()->Accept( - [&order](HloInstruction* instruction) { - order.push_back(instruction); - return Status::OK(); - })); - sequence.emplace(computation.get(), std::move(order)); - } - return sequence; -} - } // namespace StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( @@ -412,8 +392,12 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile( // Select an order for emitting the HLO instructions for each // computation. Using this sequence enables tighter buffer liveness analysis // and reduced memory usage (as compared to using DependencyHloOrdering). - SequentialHloOrdering::HloModuleSequence module_sequence = - CreateModuleSequence(hlo_module.get()); + TF_ASSIGN_OR_RETURN( + SequentialHloOrdering::HloModuleSequence module_sequence, + CreateMemoryMinimizingSequence( + *hlo_module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -559,8 +543,13 @@ CpuCompiler::CompileAheadOfTime( TF_RETURN_IF_ERROR(RunHloPasses(hlo_module, module_config, dump_hlo)); - SequentialHloOrdering::HloModuleSequence module_sequence = - CreateModuleSequence(hlo_module); + TF_ASSIGN_OR_RETURN( + SequentialHloOrdering::HloModuleSequence module_sequence, + CreateMemoryMinimizingSequence( + *hlo_module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. TF_ASSIGN_OR_RETURN( diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 40fec80f88..5c6ca80acf 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -524,6 +524,30 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); } +Status HloComputation::AcceptOrdered( + DfsHloVisitor* visitor, + const std::vector<const HloInstruction*>& order) const { + TF_RET_CHECK(order.size() == instruction_count()); + std::unordered_set<const HloInstruction*> visited; + for (const HloInstruction* instruction : order) { + TF_RET_CHECK(instruction_iterators_.count(instruction) == 1) + << "Instruction " << instruction->name() << " is not in computation " + << name(); + TF_RET_CHECK(visited.count(instruction) == 0) + << "Instruction " << instruction->name() + << " appears more than once in order"; + HloInstruction* mutable_instruction = + const_cast<HloInstruction*>(instruction); + TF_RETURN_IF_ERROR(visitor->Preprocess(mutable_instruction)); + TF_RETURN_IF_ERROR(mutable_instruction->Visit(visitor)); + visitor->SetVisited(*mutable_instruction); + TF_RETURN_IF_ERROR(visitor->Postprocess(mutable_instruction)); + visited.insert(instruction); + } + TF_RETURN_IF_ERROR(visitor->FinishVisit(root_instruction())); + return Status::OK(); +} + Status HloComputation::Accept( const FunctionVisitor::VisitorFunction& visitor_func) const { FunctionVisitor visitor(visitor_func); diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 2c3cceddc8..e78e86b91f 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -214,6 +214,11 @@ class HloComputation { // root instruction as the argument). Status Accept(DfsHloVisitor* visitor) const; + // Visit every node in the computation in the given order. 'order' must + // be a topological sort of all instructions in the computation. + Status AcceptOrdered(DfsHloVisitor* visitor, + const std::vector<const HloInstruction*>& order) const; + // Same as Accept() above, but the visitor is given as a function. Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1880c01e5c..d173f66a97 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1158,7 +1158,8 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user, << ShapeUtil::HumanString(new_producer->shape()); auto user_it = std::find(users_.begin(), users_.end(), user); TF_RET_CHECK(user_it != users_.end()) - << "Instruction " << user << " not a use of instruction " << this; + << "Instruction " << user->name() << " not a use of instruction " + << name(); users_.erase(user_it); VLOG(3) << "Replacing uses of " << name() << " in " << user->name() @@ -1360,6 +1361,15 @@ string HloInstruction::ToString(bool compact_operands) const { tensorflow::strings::StrAppend(&extra, ", padding=", padding_config_->ShortDebugString()); } + if (!slice_starts_.empty() && !slice_limits_.empty()) { + std::vector<string> bounds; + for (int i = 0; i < slice_starts_.size(); ++i) { + bounds.push_back(tensorflow::strings::StrCat("[", slice_starts_[i], ":", + slice_limits_[i], "]")); + } + tensorflow::strings::StrAppend( + &extra, ", slice={", tensorflow::str_util::Join(bounds, ", "), "}"); + } if (convolution_dimension_numbers_ != nullptr) { tensorflow::strings::StrAppend( &extra, @@ -1470,7 +1480,7 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } -Status HloInstruction::AcceptInternalVisit(DfsHloVisitor* visitor) { +Status HloInstruction::Visit(DfsHloVisitor* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this, operands_[0]); @@ -1624,19 +1634,20 @@ Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor) { } for (auto control_predecessor : control_predecessors_) { - VLOG(3) << "Going to visit HLO " << control_predecessor - << " as a control predecessor of HLO " << this; + VLOG(3) << "Going to visit HLO " << control_predecessor->name() + << " as a control predecessor of HLO " << name(); TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal(visitor)); } TF_RETURN_IF_ERROR(visitor->Preprocess(this)); - VLOG(3) << "Visiting HLO " << name(); - TF_RETURN_IF_ERROR(AcceptInternalVisit(visitor)); + VLOG(2) << "Visiting HLO " << name(); + TF_RETURN_IF_ERROR(Visit(visitor)); visitor->SetVisited(*this); return visitor->Postprocess(this); } Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) { + VLOG(2) << "HloInstruction::Accept(" << name() << ")"; auto status = AcceptInternal(visitor); if (!status.ok()) { return status; @@ -1651,14 +1662,13 @@ Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit) { namespace { -// Returns true if the given order is a topological sort of exactly those -// instructions rooted at 'root'. -bool OrderIsTopologicalSort(HloInstruction* root, - const std::vector<const HloInstruction*>& order) { +// Returns true if the given order is a topological sort of the instructions it +// contains. +bool OrderIsTopologicalSort(const std::vector<const HloInstruction*>& order) { // Create a map from instruction to its position in 'order'. std::unordered_map<const HloInstruction*, int> order_position; for (int i = 0; i < order.size(); i++) { - if (!order_position.insert(std::make_pair(order[i], i)).second) { + if (!order_position.insert({order[i], i}).second) { // Instruction order[i] is duplicated in the order. return false; } @@ -1675,26 +1685,6 @@ bool OrderIsTopologicalSort(HloInstruction* root, } } - // Create a vector of all instructions in a DFS search starting at - // root. 'order' should contain exactly these instructions. - std::vector<const HloInstruction*> visited; - TF_CHECK_OK(root->Accept([&visited](HloInstruction* instruction) { - visited.push_back(instruction); - return Status::OK(); - })); - - if (order_position.size() != visited.size()) { - return false; - } - for (auto* instruction : visited) { - if (order_position.count(instruction) == 0) { - return false; - } - } - // Given the conditions above, the last element of order should always be the - // root. - CHECK_EQ(root, order[order.size() - 1]); - return true; } @@ -1707,8 +1697,22 @@ Status HloInstruction::Accept(FunctionVisitor::VisitorFunction visitor_func) { Status HloInstruction::AcceptOrdered( DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order) { - DCHECK(OrderIsTopologicalSort(this, order)); + VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")"; + TF_RET_CHECK(OrderIsTopologicalSort(order)); + + // Compute the predecessors of this instruction. + std::unordered_set<const HloInstruction*> predecessors; + TF_RETURN_IF_ERROR(this->Accept([&predecessors](HloInstruction* instruction) { + predecessors.insert(instruction); + return Status::OK(); + })); + for (auto* const_instruction : order) { + if (predecessors.count(const_instruction) == 0) { + // Instruction is not a predecessors of 'this'. + continue; + } + // The visitor can mark instructions as visited to skip particular // instructions. if (visitor->DidVisit(*const_instruction)) { @@ -1721,8 +1725,8 @@ Status HloInstruction::AcceptOrdered( const_cast<HloInstruction*>(const_instruction); TF_RETURN_IF_ERROR(visitor->Preprocess(instruction)); - VLOG(3) << "Visiting HLO " << instruction->name(); - TF_RETURN_IF_ERROR(instruction->AcceptInternalVisit(visitor)); + VLOG(2) << "Visiting HLO " << instruction->name(); + TF_RETURN_IF_ERROR(instruction->Visit(visitor)); visitor->SetVisited(*instruction); TF_RETURN_IF_ERROR(visitor->Postprocess(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index ba1192cf7e..808dfeb246 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -320,7 +320,6 @@ class HloInstruction { // Adds the given instruction to the set of control successors. void AddControlSuccessor(HloInstruction* instruction); - // Returns the set of control successors of this instruction. // Returns true if "other" performs the same computation as this instruction. // Layout of the instructions' output array is not considered. bool Identical( @@ -364,13 +363,18 @@ class HloInstruction { Status Accept(FunctionVisitor::VisitorFunction visitor_func); // Visits all instructions rooted at this instruction using the given visitor - // in the given order. 'order' must contain exactly the set of instructions + // in the given order. 'order' must contain at least the set of instructions // rooted at this node (ie, those accessible from a DFS traversal from this - // instruction). 'order' must also be a valid topological sort of these - // instructions (defs appear before uses). + // instruction). Instructions contained in 'order' which are not in the set of + // instructions rooted at this node are ignored. 'order' must also be a valid + // topological sort of these instructions (defs appear before uses) though + // need not be a DFS post-order. Status AcceptOrdered(DfsHloVisitor* visitor, const std::vector<const HloInstruction*>& order); + // Visit this instruction and only this instruction with the given visitor. + Status Visit(DfsHloVisitor* visitor); + // Returns the literal associated with this instruction. // // Note: only constant and parameter opcodes have an associated literal. @@ -693,9 +697,6 @@ class HloInstruction { // Accept above) allows us to distinguish the root of the traversal. Status AcceptInternal(DfsHloVisitor* visitor); - // Inner DFS traversal function called when visiting this HloInstruction. - Status AcceptInternalVisit(DfsHloVisitor* visitor); - // CHECKs various invariants of a fusion instruction. void CheckFusionInstruction() const; diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc new file mode 100644 index 0000000000..38106dbbb1 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -0,0 +1,363 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_ordering.h" + +#include <set> +#include <utility> +#include <vector> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/logical_buffer.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +PredecessorHloOrdering::PredecessorHloOrdering(const HloModule* module) + : module_(module) {} + +bool PredecessorHloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // Instructions in different computations are unordered. + if (a->parent() != b->parent()) { + return false; + } + // 'a' executes before 'b' if 'a' is in the strict predecessor set of 'b'. + return strict_predecessors_.at(b->parent())->IsReachable(b, a); +} + +string PredecessorHloOrdering::ToStringHelper(const string& name) const { + std::vector<string> pieces; + pieces.push_back(name); + for (auto& computation : module_->computations()) { + pieces.push_back(tensorflow::strings::Printf("computation %s:", + computation->name().c_str())); + const auto all = computation->MakeInstructionPostOrder(); + for (auto instruction : all) { + pieces.push_back(tensorflow::strings::Printf( + " %s strict predecessors:", instruction->name().c_str())); + for (auto predecessor : all) { + if (strict_predecessors_.at(computation.get()) + ->IsReachable(instruction, predecessor)) { + pieces.push_back( + tensorflow::strings::Printf(" %s", predecessor->name().c_str())); + } + } + } + } + return tensorflow::str_util::Join(pieces, "\n"); +} + +DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) + : PredecessorHloOrdering(module) { + // Compute predecessor relationships between all instructions to determine + // ordering based on dependencies. ExecutesBefore will return true iff there + // exists a path in the HLO computation graph from 'a' to 'b'. + for (auto& computation : module->computations()) { + strict_predecessors_.emplace(computation.get(), + computation->ComputeTransitiveOperands()); + } +} + +string DependencyHloOrdering::ToString() const { + return ToStringHelper("DependencyHloOrdering"); +} + +SequentialHloOrdering::SequentialHloOrdering( + const HloModule* module, const HloModuleSequence& module_sequence) + : module_(module) { + // Create a map from instruction to its order position. + for (auto computation_order : module_sequence) { + const std::vector<const HloInstruction*>& order = computation_order.second; + for (int i = 0; i < order.size(); ++i) { + DCHECK_EQ(0, order_position_.count(order[i])); + order_position_.emplace(order[i], i); + } + } +} + +bool SequentialHloOrdering::ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const { + // Instructions in different computations are unordered. + if (a->parent() != b->parent()) { + return false; + } + // If either instruction is not in the order, then 'a' and 'b' are unordered. + if (order_position_.count(a) == 0 || order_position_.count(b) == 0) { + return false; + } + return order_position_.at(a) < order_position_.at(b); +} + +string SequentialHloOrdering::ToString() const { + std::vector<string> pieces; + pieces.push_back("SequentialHloOrdering"); + for (auto& computation : module_->computations()) { + pieces.push_back(tensorflow::strings::Printf("computation %s order:", + computation->name().c_str())); + // Gather all instructions in the module sequence for this computation and + // sort them by their position. + std::vector<const HloInstruction*> instructions; + for (auto& instruction_position : order_position_) { + const HloInstruction* instruction = instruction_position.first; + if (instruction->parent() == computation.get()) { + instructions.push_back(instruction); + } + } + std::sort(instructions.begin(), instructions.end(), + [this](const HloInstruction* a, const HloInstruction* b) { + return order_position_.at(a) < order_position_.at(b); + }); + for (auto instruction : instructions) { + pieces.push_back( + tensorflow::strings::Printf(" %s", instruction->name().c_str())); + } + } + return tensorflow::str_util::Join(pieces, "\n"); +} + +namespace { + +// Class implementing a list scheduler of HLO instructions which produces a +// sequence which minimizes memory usage. +class ListScheduler { + public: + // Construct and return a memory-minimizing sequence of HLO instructions + // containing the given HLO computation. + static StatusOr<std::vector<const HloInstruction*>> Run( + const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) { + ListScheduler scheduler(computation, points_to_analysis, size_function); + return scheduler.CreateSchedule(); + } + + private: + // The scheduling priority of an instruction is first the number of bytes + // freed by scheduling the instruction, and second (tie-breaker) by the number + // of users. This is represented as a std::pair containing these two values + // (first element is the bytes freed). std::pair provides the necessary + // comparison operators. + using Priority = std::pair<int64, int64>; + + ListScheduler(const HloComputation& computation, + const TuplePointsToAnalysis& points_to_analysis, + const LogicalBuffer::SizeFunction& size_function) + : computation_(computation), + points_to_analysis_(points_to_analysis), + size_function_(size_function) { + // Create a map containing the LogicalBuffer uses for each HLO + // instruction. An HLO instruction "uses" a LogicalBuffer if the + // LogicalBuffer is in an operand of the instruction as indicated by + // points-to analysis. + for (auto& instruction : computation.instructions()) { + buffer_uses_.insert( + {instruction.get(), std::unordered_set<const LogicalBuffer*>()}); + for (auto* operand : instruction->operands()) { + for (const LogicalBuffer* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(operand)) { + buffer_uses_[instruction.get()].insert(buffer); + } + } + } + + // Create map containing the number of unscheduled uses (hlo instructions) + // of each logical buffer. + for (auto& instruction : computation.instructions()) { + for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( + instruction.get())) { + unscheduled_use_count_[buffer] = 0; + } + } + for (auto& instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + ++unscheduled_use_count_[buffer]; + } + } + + // Buffers live out of the computation have an implicit use at the end of + // the computation. + for (const LogicalBuffer* live_out_buffer : + points_to_analysis.GetPointsToSet(computation.root_instruction()) + .CreateFlattenedSet()) { + ++unscheduled_use_count_[live_out_buffer]; + } + } + + // Returns whether the memory used by the given buffer should be ignored by + // the scheduling heuristic. + bool IgnoreBuffer(const LogicalBuffer& buffer) { + return buffer.instruction()->opcode() == HloOpcode::kParameter || + buffer.instruction()->opcode() == HloOpcode::kConstant; + } + + // Return the number of bytes freed if the HLO instruction is scheduled. + int64 BytesFreedIfScheduled(const HloInstruction* instruction) { + int64 freed_bytes = 0; + // Sum the total size of the values last used by this instruction. + for (auto* buffer : buffer_uses_.at(instruction)) { + if (IgnoreBuffer(*buffer)) { + continue; + } + CHECK_GE(unscheduled_use_count_.at(buffer), 1); + if (unscheduled_use_count_.at(buffer) == 1) { + // This is the last use of the logical buffer. + freed_bytes += size_function_(*buffer); + } + } + // Then subtract the size of the value(s) defined by this instruction. + for (auto* buffer : + points_to_analysis_.GetBuffersDefinedByInstruction(instruction)) { + if (!IgnoreBuffer(*buffer)) { + freed_bytes -= size_function_(*buffer); + } + } + return freed_bytes; + } + + // Construct the scheduling priority of the given instruciton. + Priority GetPriority(const HloInstruction* instruction) { + return {BytesFreedIfScheduled(instruction), instruction->user_count()}; + } + + std::vector<const HloInstruction*> CreateSchedule() { + std::vector<const HloInstruction*> schedule; + + // Populate the ready list with instructions which have no operands. + std::list<const HloInstruction*> ready_list; + for (auto& instruction : computation_.instructions()) { + if (instruction->operand_count() == 0 && + instruction->control_predecessors().empty()) { + ready_list.push_back(instruction.get()); + } + } + + while (!ready_list.empty()) { + // Select the highest priority HLO instruction from the ready list. + auto best_it = ready_list.begin(); + Priority best_priority = GetPriority(*best_it); + for (auto ready_it = std::next(ready_list.begin()); + ready_it != ready_list.end(); ++ready_it) { + Priority priority = GetPriority(*ready_it); + if (priority > best_priority) { + best_it = ready_it; + best_priority = priority; + } + } + + // Remove the selected instruction from the ready list and add it to the + // schedule. + const HloInstruction* best = *best_it; + ready_list.erase(best_it); + schedule.push_back(best); + scheduled_instructions_.insert(best); + + // Update the unscheduled uses of the logical buffers. + for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { + CHECK_GT(unscheduled_use_count_.at(buffer), 0); + --unscheduled_use_count_[buffer]; + } + + // Add new instructions to ready list. + // TODO(b/34466113): Replace this with successors()/predecessors() when + // predecessor/successor methods are added to HloInstruction. This also + // will resolve the nondeterminism of using a set here assuming + // predecessors/successors is a vector. + std::set<HloInstruction*> successors = best->users(); + successors.insert(best->control_successors().begin(), + best->control_successors().end()); + for (auto* successor : successors) { + std::set<HloInstruction*> predecessors(successor->operands().begin(), + successor->operands().end()); + predecessors.insert(successor->control_predecessors().begin(), + successor->control_predecessors().end()); + bool is_ready = true; + for (auto* predecessor : predecessors) { + if (scheduled_instructions_.count(predecessor) == 0) { + is_ready = false; + break; + } + } + if (is_ready) { + ready_list.push_back(successor); + } + } + } + CHECK_EQ(schedule.size(), computation_.instructions().size()); + CHECK_EQ(scheduled_instructions_.size(), + computation_.instructions().size()); + + return schedule; + } + + const HloComputation& computation_; + const TuplePointsToAnalysis& points_to_analysis_; + const LogicalBuffer::SizeFunction& size_function_; + + // A map containing the LogicalBuffers that each instruction uses. + std::unordered_map<const HloInstruction*, + std::unordered_set<const LogicalBuffer*>> + buffer_uses_; + + // A map containing the count of unscheduled HLOs which using a particular + // LogicalBuffer. + std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_; + + // Set of instructions which have been scheduled. + std::unordered_set<const HloInstruction*> scheduled_instructions_; +}; + +} // namespace + +StatusOr<SequentialHloOrdering::HloModuleSequence> +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { + SequentialHloOrdering::HloModuleSequence sequence; + TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis, + TuplePointsToAnalysis::Run(&module)); + + for (auto& computation : module.computations()) { + TF_ASSIGN_OR_RETURN( + sequence[computation.get()], + ListScheduler::Run(*computation, *points_to_analysis, size_function)); + } + + return sequence; +} + +std::ostream& operator<<( + std::ostream& out, + const SequentialHloOrdering::HloModuleSequence& module_sequence) { + for (auto computation_pair : module_sequence) { + const HloComputation* computation = computation_pair.first; + const std::vector<const HloInstruction*>& computation_sequence = + computation_pair.second; + out << "Computation " << computation->name() << ":\n"; + for (auto* instruction : computation_sequence) { + out << " " << instruction->name() << "\n"; + } + } + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h new file mode 100644 index 0000000000..97f7c6060b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_ordering.h @@ -0,0 +1,172 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ + +#include <memory> +#include <string> +#include <utility> + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { + +// Abstract base class for describing a partial ordering of HLO +// instructions. Used to determine live range overlap of HLO instruction output +// buffers. +class HloOrdering { + public: + HloOrdering() = default; + virtual ~HloOrdering() = default; + + // Returns true if instruction 'a' executes before instruction 'b'. This is + // not reflexive, that is, an instruction does not execute before itself. + virtual bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const = 0; + virtual string ToString() const = 0; +}; + +// Base class for partial orderings implemented by a map of strict predecessors +// for each instruction. Subclasses should fill in strict_predecessors_. +class PredecessorHloOrdering : public HloOrdering { + public: + ~PredecessorHloOrdering() override = default; + + // Returns true if instruction 'a' executes before instruction 'b'. + // Instructions in different computations are not ordered. + bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const override; + + protected: + explicit PredecessorHloOrdering(const HloModule* module); + string ToStringHelper(const string& name) const; + + const HloModule* module_; + + // For each each computation in the module, this is the set of the + // instruction's strict predecessors. An instruction is not an element of its + // own strict predecessor set. + // + // Subclasses should fill this in to define the desired ordering. + tensorflow::gtl::FlatMap<const HloComputation*, + std::unique_ptr<HloComputation::ReachabilityMap>> + strict_predecessors_; +}; + +// An HLO ordering based on data dependencies in the HLO graph. In this partial +// order, instruction A executes before instruction B only if there is a path +// from A to B in the HLO graph. For example, given the following graph: +// +// param +// / \ +// negate exp +// \ / +// add +// +// DependencyHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before add +// exp executes before add +// add executes before nothing +// negate and exp are not ordered because the dependencies allow either to +// execute before the other (or in parallel). DependencyHloOrdering ordering +// allows maximum parallelism and enables any execution order which satisfies +// data dependencies. This requires pessimistic assumptions about buffer live +// ranges and can result in more memory used than more constrained orderings. +class DependencyHloOrdering : public PredecessorHloOrdering { + public: + explicit DependencyHloOrdering(const HloModule* module); + ~DependencyHloOrdering() override = default; + + string ToString() const override; +}; + +// An HLO ordering based on a total order of instructions in each computation. +// The computation total order is a sequencing of all of its instructions in +// the computation (eg, {inst0, inst1, inst2,...}) as in single-threaded +// execution. For example, given the following HLO graph: +// +// param +// / \ +// negate exp +// \ / +// add +// +// and the following sequence: +// +// {param, negate, exp, add} +// +// SequentialHloOrdering gives the following executes-before relations: +// param executes before negate, exp, and add +// negate executes before exp and add +// exp executes before add +// add executes before nothing +// This is more constrained than DependencyHloOrdering in this example because +// negate and exp are ordered (negate before exp). This enables param to share +// the same buffer as exp (param buffer is dead after exp). Generally, this +// ordering enables more buffer sharing (reduced memory usage) because buffer +// interference is reduced relative to DependencyHloOrdering. +class SequentialHloOrdering : public HloOrdering { + public: + // A sequence of instructions for each computation in the module. + using HloModuleSequence = + tensorflow::gtl::FlatMap<const HloComputation*, + std::vector<const HloInstruction*>>; + + SequentialHloOrdering(const HloModule* module, + const HloModuleSequence& module_sequence); + ~SequentialHloOrdering() override = default; + + // Instruction 'a' executes before 'b' if 'a' appears before 'b' in the + // instruction sequence for the computation. Instructions in different + // computations are unordered. + bool ExecutesBefore(const HloInstruction* a, + const HloInstruction* b) const override; + string ToString() const override; + + protected: + const HloModule* module_; + + // The position of every instruction in the HLO module in its respective + // computation sequence (a value of zero indicates the instruction is first in + // the sequence, etc). Instructions from all computations are contained in + // this map so more than one instruction may have the same position + // value. This is not a problem because ExecutesBefore also verifies + // instructions are in the same computation. + tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_; +}; + +// Returns an HloModuleSequence which seeks to minimize the memory required for +// the computation. size_function is the function returning the number of bytes +// required for a LogicalBuffer. +StatusOr<SequentialHloOrdering::HloModuleSequence> +CreateMemoryMinimizingSequence( + const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + +std::ostream& operator<<( + std::ostream& out, + const SequentialHloOrdering::HloModuleSequence& module_sequence); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ORDERING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc new file mode 100644 index 0000000000..425bee601a --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_ordering.h" + +#include <memory> +#include <string> + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +class HloOrderingTest : public HloTestBase {}; + +TEST_F(HloOrderingTest, LastUseScheduledFirst) { + // Tests scheduling of the following HLO code: + // + // %ab = abs(%param) + // %exp = exp(%param) + // %add = add(%ab, %exp) + // %negate = negate(%exp) + // %sub = subtract(%add, %negate) + // + // %add should be scheduled before %negate because %add is the last (and only) + // use of %ab. Scheduling %add first then frees up %ab's buffer. + const Shape vec = ShapeUtil::MakeShape(xla::F32, {42}); + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec, "param")); + auto ab = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kAbs, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kExp, param)); + + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kAdd, ab, exp)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vec, HloOpcode::kNegate, exp)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(vec, HloOpcode::kSubtract, add, negate)); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + TF_ASSIGN_OR_ASSERT_OK( + SequentialHloOrdering::HloModuleSequence sequence, + CreateMemoryMinimizingSequence(module, [](const LogicalBuffer& buffer) { + return ShapeUtil::ByteSizeOf(buffer.shape()); + })); + // Verify that all instructions are in the sequence. + EXPECT_EQ(module.entry_computation()->instruction_count(), + sequence.at(module.entry_computation()).size()); + + // The first instruction should be the parameter and the last the root "sub". + EXPECT_EQ(param, sequence.at(module.entry_computation()).front()); + EXPECT_EQ(sub, sequence.at(module.entry_computation()).back()); + + SequentialHloOrdering ordering(&module, sequence); + EXPECT_TRUE(ordering.ExecutesBefore(add, negate)); +} + +} // namespace + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/logical_buffer.h b/tensorflow/compiler/xla/service/logical_buffer.h index 21af9dcf66..35a9935f44 100644 --- a/tensorflow/compiler/xla/service/logical_buffer.h +++ b/tensorflow/compiler/xla/service/logical_buffer.h @@ -90,6 +90,9 @@ class LogicalBuffer { // unique value. using Id = int64; + // Function which returns the size of a logical buffer in bytes. + using SizeFunction = std::function<int64(const LogicalBuffer&)>; + LogicalBuffer(HloInstruction* instruction, const ShapeIndex& index, Id id) : instruction_(instruction), index_(index), id_(id) {} |