diff options
author | 2018-09-05 17:17:23 -0700 | |
---|---|---|
committer | 2018-09-05 17:22:22 -0700 | |
commit | 6bd9f8fa0c17c55fc0c11ba0d9281cab1688b115 (patch) | |
tree | 1afd3dff710c4f63bae267807435abdcec784edb /tensorflow/compiler/xla/service/hlo_ordering.cc | |
parent | 017599d0a1fa7a7227a43649db67e96311033a4e (diff) |
Rollforward of cl/211656888 after fixing failing unit test.
*** Original change description ***
Add HloSchedule class representing a sequential order of an HloModule.
Currently we represent a sequential schedule of a module using a SequentialHloOrdering::HloModuleSequence which is a type alias of a bare map from HloComputation* to std::vector<HloInstruction*>. This CL replaces this with a proper class which results in better encap...
***
PiperOrigin-RevId: 211726890
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_ordering.cc | 86 |
1 files changed, 39 insertions, 47 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 0581d5c404..2105f7a349 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -18,6 +18,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore( VLOG(4) << a << " not defined before " << b; return false; } + + if (a.live_out_of_module()) { + VLOG(4) << a << " is live out of module and defined before " << b; + return false; + } + // All uses of 'a' must be before 'b' is defined. for (const HloUse& use : a.uses()) { if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(), @@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore( return false; } } + + if (a.instruction()->parent() == b.instruction()->parent()) { + for (const HloPosition& position : a.positions()) { + if (position.instruction == + a.instruction()->parent()->root_instruction()) { + VLOG(4) << a << " is live out of computation and defined before " << b + << " which is in same computation"; + return false; + } + } + } + return true; } @@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const { return ToStringHelper("DependencyHloOrdering"); } -SequentialHloOrdering::SequentialHloOrdering( - const HloModule* module, const HloModuleSequence& module_sequence) - : HloOrdering(module), module_sequence_(module_sequence) { +SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule) + : HloOrdering(schedule.module()), schedule_(schedule) { + Initialize(); +} + +SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule) + : HloOrdering(schedule.module()), schedule_(std::move(schedule)) { + Initialize(); +} + +void SequentialHloOrdering::Initialize() { // Create a map from instruction to its order position. - for (auto computation_order : module_sequence_) { - const std::vector<const HloInstruction*>& order = computation_order.second; + TF_DCHECK_OK(schedule_.Verify()); + for (const auto& computation_sequence : schedule_.sequences()) { + const std::vector<const HloInstruction*>& order = + computation_sequence.second.instructions(); for (int i = 0; i < order.size(); ++i) { - DCHECK_EQ(0, order_position_.count(order[i])); - order_position_.emplace(order[i], i); + InsertOrDie(&order_position_, order[i], i); } } } @@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation( const std::vector<const HloInstruction*>* SequentialHloOrdering::SequentialOrder( const HloComputation& computation) const { - auto find_it = module_sequence_.find(&computation); - return find_it == module_sequence_.end() ? nullptr : &find_it->second; + return schedule_.is_computation_scheduled(&computation) + ? &schedule_.sequence(&computation).instructions() + : nullptr; } string SequentialHloOrdering::ToString() const { - std::vector<string> pieces; - pieces.push_back("SequentialHloOrdering"); - for (auto* computation : module_->computations()) { - pieces.push_back( - absl::StrFormat("computation %s order:", computation->name())); - // Gather all instructions in the module sequence for this computation and - // sort them by their position. - std::vector<const HloInstruction*> instructions; - for (auto& instruction_position : order_position_) { - const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation) { - instructions.push_back(instruction); - } - } - std::sort(instructions.begin(), instructions.end(), - [this](const HloInstruction* a, const HloInstruction* b) { - return order_position_.at(a) < order_position_.at(b); - }); - for (auto instruction : instructions) { - pieces.push_back(absl::StrFormat(" %s", instruction->name())); - } - } - return absl::StrJoin(pieces, "\n"); -} - -std::ostream& operator<<( - std::ostream& out, - const SequentialHloOrdering::HloModuleSequence& module_sequence) { - for (auto computation_pair : module_sequence) { - const HloComputation* computation = computation_pair.first; - const std::vector<const HloInstruction*>& computation_sequence = - computation_pair.second; - out << "Computation " << computation->name() << ":\n"; - for (auto* instruction : computation_sequence) { - out << " " << instruction->name() << "\n"; - } - } - return out; + return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString()); } } // namespace xla |