aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 07:46:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 07:50:58 -0700
commitaeab291563b0b4cc75c0f5fc73610a6595780570 (patch)
tree8b5e9cc861551cec0f4ba0cfcfa1b0e788109ae5 /tensorflow/compiler/xla/service/hlo_computation.h
parentad4018ebb6b7f8ee4a38a7a0059bdbad28557732 (diff)
Handle communicating instructions in HloComputation::ComputeReachability
Send&recv instructions and cross-replica-sum instructions are imposing extra dependencies via the channel id or all reduce id. This CL teaches the reachability calculation logic in hlo computation to correctly account for these "invisible" dependencies. The main purpose is to stop multi output fusion from generating dependency cyclies via communicating instructions. PiperOrigin-RevId: 209593997
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 49ed65910f..faa33f0f90 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -399,6 +399,13 @@ class HloComputation {
// Internal helper to collect unreachable roots.
std::vector<HloInstruction*> CollectUnreachableRoots() const;
+ // Returns a map from channel-id to directed dependencies of the channel
+ // instructions. For send&recv pairs it means the send instruction and for
+ // cross-replica-sum the union of the dependencies for all participating
+ // instructions.
+ std::map<int64, std::vector<HloInstruction*>> ComputeChannelDependencies()
+ const;
+
string name_;
int64 unique_id_;
HloInstruction* root_instruction_;