diff options
3 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc index eed0112f62..fa5dcb0b36 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc @@ -216,6 +216,7 @@ Status HloModuleGroupMetadata::RecordInstructions() { channels_.emplace_back(); channels_.back().id = hlo->channel_id(); channel_id_map_[hlo->channel_id()] = channels_.size() - 1; + max_channel_id_ = std::max(max_channel_id_, hlo->channel_id()); } Channel& channel = channels_[channel_id_map_[hlo->channel_id()]]; diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h index 15cdbdaade..c48a7ab0b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h +++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h @@ -173,6 +173,12 @@ class HloModuleGroupMetadata { return companion_sets_; } + // Returns all channels in the module group. + const std::vector<Channel>& channels() const { return channels_; } + + // Returns the maximum channel id used in the module group. + int64 max_channel_id() const { return max_channel_id_; } + private: Status Build(); @@ -221,6 +227,9 @@ class HloModuleGroupMetadata { // Map from channel ids to the index in channels_. tensorflow::gtl::FlatMap<int64, int64> channel_id_map_; + // The maximum channel id used in the module group. + int64 max_channel_id_ = -1; + // The modules that this metadata was built from. const std::vector<HloModule*>& modules_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 7126cb50cf..680f88048a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -403,7 +403,6 @@ class LayoutAssignment : public HloPassInterface { Status CheckLayouts(HloModule* module); ComputationLayout* entry_computation_layout_; - ChannelLayoutConstraints* channel_layout_constraints_; protected: // Map containing the layouts of all computations assigned so @@ -411,6 +410,7 @@ class LayoutAssignment : public HloPassInterface { // handled before their caller instructions so the layouts of caller // instructions can be set to match the computation. std::map<HloComputation*, ComputationLayout> computation_layouts_; + ChannelLayoutConstraints* channel_layout_constraints_; }; } // namespace xla |