aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc26
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h5
2 files changed, 31 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 67f4c37413..a41cfa7591 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module_group_metadata.h"
+#include <sstream>
#include <string>
#include <utility>
@@ -110,6 +111,31 @@ Status HloModuleGroupMetadata::Build() {
TF_RETURN_IF_ERROR(computation->Accept(visitor));
}
}
+ TF_RETURN_IF_ERROR(VerifyCompanionSets());
+ return Status::OK();
+}
+
+Status HloModuleGroupMetadata::VerifyCompanionSets() const {
+ // TODO(dlibenzi): Migrate this to use the device instead of module ID, once
+ // the kDomain CL goes in.
+ for (const auto& companions : companion_sets_) {
+ // A companion set may contain at most one instruction per device/module;
+ // the module ID currently stands in for the device (see the TODO above).
+ std::unordered_set<int64> devices;
+ for (HloInstruction* instruction : *companions) {
+ int64 device = GetModuleId(instruction->parent()->parent());
+ if (!devices.insert(device).second) {
+ std::stringstream ss;
+ ss << "Companion set:" << std::endl;
+ for (HloInstruction* hlo : *companions) {
+ ss << " " << hlo->name() << " ("
+ << GetModuleId(hlo->parent()->parent()) << ")" << std::endl;
+ }
+ ss << "has multiple instructions on the same device";
+ return FailedPrecondition("%s", ss.str().c_str());
+ }
+ }
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 88ed9a2ecc..3ef4542f91 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -207,6 +207,11 @@ class HloModuleGroupMetadata {
// within the graph.
Status CheckCommunicatingInstruction(HloInstruction* instruction) const;
+ // Performs a consistency check on the companion sets built for the input
+ // modules. Checks that no companion set includes more than one instruction
+ // from the same module/device. Returns a FailedPrecondition error otherwise.
+ Status VerifyCompanionSets() const;
+
// Retrieves a pointer to the stored TrackedInstruction associated with a
// tracked computation, or nullptr in case such computation is not tracked.
const TrackedInstruction* GetTrackedInstruction(