aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc1
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h9
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h2
3 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index eed0112f62..fa5dcb0b36 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -216,6 +216,7 @@ Status HloModuleGroupMetadata::RecordInstructions() {
channels_.emplace_back();
channels_.back().id = hlo->channel_id();
channel_id_map_[hlo->channel_id()] = channels_.size() - 1;
+ max_channel_id_ = std::max(max_channel_id_, hlo->channel_id());
}
Channel& channel = channels_[channel_id_map_[hlo->channel_id()]];
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 15cdbdaade..c48a7ab0b5 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -173,6 +173,12 @@ class HloModuleGroupMetadata {
return companion_sets_;
}
+ // Returns all channels in the module group.
+ const std::vector<Channel>& channels() const { return channels_; }
+
+ // Returns the maximum channel id used in the module group.
+ int64 max_channel_id() const { return max_channel_id_; }
+
private:
Status Build();
@@ -221,6 +227,9 @@ class HloModuleGroupMetadata {
// Map from channel ids to the index in channels_.
tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+ // The maximum channel id used in the module group.
+ int64 max_channel_id_ = -1;
+
// The modules that this metadata was built from.
const std::vector<HloModule*>& modules_;
};
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 7126cb50cf..680f88048a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -403,7 +403,6 @@ class LayoutAssignment : public HloPassInterface {
Status CheckLayouts(HloModule* module);
ComputationLayout* entry_computation_layout_;
- ChannelLayoutConstraints* channel_layout_constraints_;
protected:
// Map containing the layouts of all computations assigned so
@@ -411,6 +410,7 @@ class LayoutAssignment : public HloPassInterface {
// handled before their caller instructions so the layouts of caller
// instructions can be set to match the computation.
std::map<HloComputation*, ComputationLayout> computation_layouts_;
+ ChannelLayoutConstraints* channel_layout_constraints_;
};
} // namespace xla