aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h18
1 files changed, 13 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index ffde3a332d..84f2d3f5fb 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -92,7 +92,7 @@ class HloModuleGroupMetadata {
ComputationKind kind_ = ComputationKind::kInvalid;
};
- // Represents a channel and the 4 instructions that form the channel.
+ // Represents a channel and the instructions that form the channel.
struct Channel {
int64 id = -1;
HloInstruction* send = nullptr;
@@ -118,13 +118,17 @@ class HloModuleGroupMetadata {
// comment above on companion instructions.
bool IsCompanionInstruction(HloInstruction* hlo) const;
- // Returns true if the instruction is either a channel instruction or a
- // companion instruction.
+ // Returns true if the instruction is either a channel instruction, a
+ // cross-module all-reduce instruction, or a companion instruction.
bool InstructionCommunicates(HloInstruction* hlo) const;
// Returns the Channel instance for the given channel id.
const Channel& GetChannel(int64 channel_id) const;
+ // Returns the all-reduce instructions with the same all_reduce_id.
+ const std::vector<HloInstruction*>& GetAllReduceGroup(
+ int64 all_reduce_id) const;
+
// Returns the computation that contains the peer channel instructions for
// the given instruction.
//
@@ -187,13 +191,14 @@ class HloModuleGroupMetadata {
// Returns all channels in the module group.
const std::vector<Channel>& channels() const { return channels_; }
- // Returns the maximum channel id used in the module group.
+ // Returns the maximum channel id or all_reduce_id used in the module group.
int64 max_channel_id() const { return max_channel_id_; }
private:
Status Build();
- // Record all channel instructions and While instructions.
+ // Record all channel instructions, cross-module AllReduce instructions, and
+ // While/Conditional/Call instructions.
Status RecordInstructions();
// Verifies the given HloModules are well-formed and follow the specification,
@@ -255,6 +260,9 @@ class HloModuleGroupMetadata {
// Map from channel ids to the index in channels_.
tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+ // Map from all-reduce ids to the all reduce instructions.
+ tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_;
+
// The maximum channel id used in the module group.
int64 max_channel_id_ = -1;