aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_ordering.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_ordering.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc86
1 files changed, 39 insertions, 47 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index 0581d5c404..2105f7a349 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -252,6 +253,12 @@ bool HloOrdering::LiveRangeStrictlyBefore(
VLOG(4) << a << " not defined before " << b;
return false;
}
+
+ if (a.live_out_of_module()) {
+ VLOG(4) << a << " is live out of module and defined before " << b;
+ return false;
+ }
+
// All uses of 'a' must be before 'b' is defined.
for (const HloUse& use : a.uses()) {
if (dataflow.DoesNotUseOperandBuffer(a.instruction(), a.index(),
@@ -264,6 +271,18 @@ bool HloOrdering::LiveRangeStrictlyBefore(
return false;
}
}
+
+ if (a.instruction()->parent() == b.instruction()->parent()) {
+ for (const HloPosition& position : a.positions()) {
+ if (position.instruction ==
+ a.instruction()->parent()->root_instruction()) {
+ VLOG(4) << a << " is live out of computation and defined before " << b
+ << " which is in same computation";
+ return false;
+ }
+ }
+ }
+
return true;
}
@@ -336,15 +355,24 @@ string DependencyHloOrdering::ToString() const {
return ToStringHelper("DependencyHloOrdering");
}
-SequentialHloOrdering::SequentialHloOrdering(
- const HloModule* module, const HloModuleSequence& module_sequence)
- : HloOrdering(module), module_sequence_(module_sequence) {
+SequentialHloOrdering::SequentialHloOrdering(const HloSchedule& schedule)
+ : HloOrdering(schedule.module()), schedule_(schedule) {
+ Initialize();
+}
+
+SequentialHloOrdering::SequentialHloOrdering(HloSchedule&& schedule)
+ : HloOrdering(schedule.module()), schedule_(std::move(schedule)) {
+ Initialize();
+}
+
+void SequentialHloOrdering::Initialize() {
// Create a map from instruction to its order position.
- for (auto computation_order : module_sequence_) {
- const std::vector<const HloInstruction*>& order = computation_order.second;
+ TF_DCHECK_OK(schedule_.Verify());
+ for (const auto& computation_sequence : schedule_.sequences()) {
+ const std::vector<const HloInstruction*>& order =
+ computation_sequence.second.instructions();
for (int i = 0; i < order.size(); ++i) {
- DCHECK_EQ(0, order_position_.count(order[i]));
- order_position_.emplace(order[i], i);
+ InsertOrDie(&order_position_, order[i], i);
}
}
}
@@ -362,49 +390,13 @@ bool SequentialHloOrdering::ExecutesBeforeInSameComputation(
const std::vector<const HloInstruction*>*
SequentialHloOrdering::SequentialOrder(
const HloComputation& computation) const {
- auto find_it = module_sequence_.find(&computation);
- return find_it == module_sequence_.end() ? nullptr : &find_it->second;
+ return schedule_.is_computation_scheduled(&computation)
+ ? &schedule_.sequence(&computation).instructions()
+ : nullptr;
}
string SequentialHloOrdering::ToString() const {
- std::vector<string> pieces;
- pieces.push_back("SequentialHloOrdering");
- for (auto* computation : module_->computations()) {
- pieces.push_back(
- absl::StrFormat("computation %s order:", computation->name()));
- // Gather all instructions in the module sequence for this computation and
- // sort them by their position.
- std::vector<const HloInstruction*> instructions;
- for (auto& instruction_position : order_position_) {
- const HloInstruction* instruction = instruction_position.first;
- if (instruction->parent() == computation) {
- instructions.push_back(instruction);
- }
- }
- std::sort(instructions.begin(), instructions.end(),
- [this](const HloInstruction* a, const HloInstruction* b) {
- return order_position_.at(a) < order_position_.at(b);
- });
- for (auto instruction : instructions) {
- pieces.push_back(absl::StrFormat(" %s", instruction->name()));
- }
- }
- return absl::StrJoin(pieces, "\n");
-}
-
-std::ostream& operator<<(
- std::ostream& out,
- const SequentialHloOrdering::HloModuleSequence& module_sequence) {
- for (auto computation_pair : module_sequence) {
- const HloComputation* computation = computation_pair.first;
- const std::vector<const HloInstruction*>& computation_sequence =
- computation_pair.second;
- out << "Computation " << computation->name() << ":\n";
- for (auto* instruction : computation_sequence) {
- out << " " << instruction->name() << "\n";
- }
- }
- return out;
+ return absl::StrCat("SequentialHloOrdering\n", schedule_.ToString());
}
} // namespace xla