aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc80
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation_test.cc23
3 files changed, 108 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index bae78c94bd..70b18ff356 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -321,6 +321,7 @@ void ComputeComputationPostOrder(
enum State { kVisiting, kVisited };
void ComputeInstructionPostOrder(
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
tensorflow::gtl::FlatMap<HloInstruction*, State>* visited) {
std::vector<HloInstruction*> dfs_stack;
@@ -355,12 +356,67 @@ void ComputeInstructionPostOrder(
for (HloInstruction* op : current->control_predecessors()) {
dfs_stack.emplace_back(op);
}
+
+ // Add inputs for send->recv_done dependencies and cross-replica-sum
+ // dependencies.
+ switch (current->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies =
+ channel_dependency_map[current->channel_id()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = current->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ for (HloInstruction* op : dependencies) {
+ dfs_stack.emplace_back(op);
+ }
+ }
+ break;
+ }
+ default:
+ break;
+ }
}
}
} // namespace
+std::map<int64, std::vector<HloInstruction*>>
+HloComputation::ComputeChannelDependencies() const {
+ std::map<int64, std::vector<HloInstruction*>> channel_dependency_map;
+ for (const auto& instruction : instructions_) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ channel_dependency_map[instruction->channel_id()].push_back(
+ instruction.get());
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = instruction->all_reduce_id();
+ if (all_reduce_id) {
+ auto& dependencies = channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(instruction->operands(),
+ std::back_inserter(dependencies));
+ absl::c_copy(instruction->control_predecessors(),
+ std::back_inserter(dependencies));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ return channel_dependency_map;
+}
+
std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
@@ -372,7 +428,8 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
// users).
trace_instructions.push_back(instruction.get());
} else if (instruction->users().empty()) {
- ComputeInstructionPostOrder(&post_order, instruction.get(), &visited);
+ ComputeInstructionPostOrder(channel_dependency_map, &post_order,
+ instruction.get(), &visited);
}
}
post_order.insert(post_order.end(), trace_instructions.begin(),
@@ -676,12 +733,33 @@ std::unique_ptr<HloReachabilityMap> HloComputation::ComputeReachability()
const {
const auto& all = MakeInstructionPostOrder();
auto result = absl::make_unique<HloReachabilityMap>(all);
+ auto channel_dependency_map = ComputeChannelDependencies();
std::vector<HloInstruction*> inputs;
for (const HloInstruction* hlo : all) {
inputs.assign(hlo->operands().begin(), hlo->operands().end());
inputs.insert(inputs.end(), hlo->control_predecessors().begin(),
hlo->control_predecessors().end());
+
+ switch (hlo->opcode()) {
+ case HloOpcode::kRecvDone: {
+ const auto& dependencies = channel_dependency_map[hlo->channel_id()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ break;
+ }
+ case HloOpcode::kCrossReplicaSum: {
+ auto all_reduce_id = hlo->all_reduce_id();
+ if (all_reduce_id) {
+ const auto& dependencies =
+ channel_dependency_map[all_reduce_id.value()];
+ absl::c_copy(dependencies, std::back_inserter(inputs));
+ }
+ break;
+ }
+ default:
+ break;
+ }
+
result->FastSetReachabilityToUnion(inputs, hlo);
}
return result;
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 49ed65910f..faa33f0f90 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -399,6 +399,13 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
+ // Returns a map from channel-id to directed dependencies of the channel
+ // instructions. For send&recv pairs it means the send instruction and for
+ // cross-replica-sum the union of the dependencies for all participating
+ // instructions.
+ std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
+ const;
+
string name_;
int64 unique_id_;
HloInstruction* root_instruction_;
diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc
index e4c5470331..f7ed1b0316 100644
--- a/tensorflow/compiler/xla/service/hlo_computation_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc
@@ -691,6 +691,27 @@ TEST_F(HloComputationTest, StringificationCanonical) {
EXPECT_EQ(computation->ToString(options), expected_computation2);
}
-} // namespace
+TEST_F(HloComputationTest, ChannelReachability) {
+ const Shape shape = ShapeUtil::MakeShape(F32, {5, 7});
+ HloComputation::Builder builder("ChannelReachability");
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ auto token0 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto send =
+ builder.AddInstruction(HloInstruction::CreateSend(param, token0, 1));
+ auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send));
+ auto token1 = builder.AddInstruction(HloInstruction::CreateToken());
+ auto recv =
+ builder.AddInstruction(HloInstruction::CreateRecv(shape, token1, 1));
+ auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv));
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build(recv_done));
+ auto reachability = computation->ComputeReachability();
+ EXPECT_TRUE(reachability->IsReachable(param, recv_done));
+ EXPECT_FALSE(reachability->IsReachable(send, recv));
+ EXPECT_FALSE(reachability->IsReachable(send_done, recv));
+}
+
+} // namespace
} // namespace xla