diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.h')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_module_group_metadata.h | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index ffde3a332d..84f2d3f5fb 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -92,7 +92,7 @@ class HloModuleGroupMetadata { ComputationKind kind_ = ComputationKind::kInvalid; }; - // Represents a channel and the 4 instructions that form the channel. + // Represents a channel and the instructions that form the channel. struct Channel { int64 id = -1; HloInstruction* send = nullptr; @@ -118,13 +118,17 @@ class HloModuleGroupMetadata { // comment above on companion instructions. bool IsCompanionInstruction(HloInstruction* hlo) const; - // Returns true if the instruction is either a channel instruction or a - // companion instruction. + // Returns true if the instruction is either a channel instruction, a + // cross-module all-reduce instruction, or a companion instruction. bool InstructionCommunicates(HloInstruction* hlo) const; // Returns the Channel instance for the given channel id. const Channel& GetChannel(int64 channel_id) const; + // Returns the all-reduce instructions with the same all_reduce_id. + const std::vector<HloInstruction*>& GetAllReduceGroup( + int64 all_reduce_id) const; + // Returns the computation that contains the peer channel instructions for // the given instruction. // @@ -187,13 +191,14 @@ class HloModuleGroupMetadata { // Returns all channels in the module group. const std::vector<Channel>& channels() const { return channels_; } - // Returns the maximum channel id used in the module group. + // Returns the maximum channel id or all_reduce_id used in the module group. int64 max_channel_id() const { return max_channel_id_; } private: Status Build(); - // Record all channel instructions and While instructions. + // Record all channel instructions, cross-module AllReduce instructions, and + // While/Conditional/Call instructions. Status RecordInstructions(); // Verifies the given HloModules are well-formed and follow the specification, @@ -255,6 +260,9 @@ class HloModuleGroupMetadata { // Map from channel ids to the index in channels_. tensorflow::gtl::FlatMap<int64, int64> channel_id_map_; + // Map from all-reduce ids to the all reduce instructions. + tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_; + // The maximum channel id used in the module group. int64 max_channel_id_ = -1; |