aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 07:46:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 07:50:58 -0700
commitaeab291563b0b4cc75c0f5fc73610a6595780570 (patch)
tree8b5e9cc861551cec0f4ba0cfcfa1b0e788109ae5 /tensorflow/compiler/xla/service/hlo_computation.cc
parentad4018ebb6b7f8ee4a38a7a0059bdbad28557732 (diff)
Handle communicating instructions in HloComputation::ComputeReachability
Send/recv instructions and cross-replica-sum instructions impose extra dependencies via the channel id or all-reduce id. This CL teaches the reachability calculation logic in HloComputation to correctly account for these "invisible" dependencies. The main purpose is to stop multi-output fusion from generating dependency cycles via communicating instructions. PiperOrigin-RevId: 209593997
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc80
1 files changed, 79 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index bae78c94bd..70b18ff356 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -321,6 +321,7 @@ void ComputeComputationPostOrder(
enum State { kVisiting, kVisited };
void ComputeInstructionPostOrder(
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
std::vector<HloInstruction*> dfs_stack;
@@ -355,12 +356,67 @@ void ComputeInstructionPostOrder(
for (HloInstruction* op : current->control_predecessors()) {
dfs_stack.emplace_back(op);
}
+
+ // Add inputs for send->recv_done dependencies and cross-replica-sum
+ // dependencies.
+ switch (current->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies =
+ channel_dependency_map[current->channel_id()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = current->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
}
}
} // namespace
+std::map<int64, std::vector<HloInstruction*>>
+HloComputation::ComputeChannelDependencies() const {
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
+ for (const auto& instruction : instructions_) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ channel_dependency_map[instruction->channel_id()].push_back(
+ instruction.get());
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = instruction->all_reduce_id();
+ if (all_reduce_id) {
+ auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(instruction->operands(),
+ std::back_inserter(dependencies));
+ absl::c_copy(instruction->control_predecessors(),
+ std::back_inserter(dependencies));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ return channel_dependency_map;
+}
+
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
@@ -372,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+ ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+ instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -676,12 +733,33 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
const auto& all = MakeInstructionPostOrder();
auto result = absl::make_unique<HloReachabilityMap>(all);
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies = channel_dependency_map[hlo->channel_id()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = hlo->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;