aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-29 21:24:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-29 21:27:20 -0700
commit9c509eedc3888d3846b2ab5ac2879268df9ff8cd (patch)
tree07a597f1409eaea8c38d7039e6580ff0f09e1b09 /tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
parent3f2ba2edf62dc394cfcb4b2606f1638389aa92e2 (diff)
Introduced kDomain HLO instruction set isolation to bound connected sets of instructions with similar attributes (i.e., sharding).
This CL simply adds the infrastructure, but leaves the wire-on to a separate CL. PiperOrigin-RevId: 198503625
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc78
1 files changed, 57 insertions, 21 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index b4cd3c730e..7d706b5fd0 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -87,6 +87,7 @@ Status HloModuleGroupMetadata::Build() {
<< "Peer instruction does not match the computation kind";
TF_RETURN_IF_ERROR(
AddCompanion(tracked->instruction(), peer_tracked->instruction()));
+ tracked_instructions_comms_[tracked->instruction()].push_back(hlo);
}
// Add the parents of companion instructions (they must be all of the same
@@ -116,23 +117,31 @@ Status HloModuleGroupMetadata::Build() {
}
Status HloModuleGroupMetadata::VerifyCompanionSets() const {
- // TODO(dlibenzi): Migrate this to use the device instead of module ID, once
- // the kDomain CL goes in.
for (const auto& companions : companion_sets_) {
// A companion set must be composed at most of an instruction per
// device/module.
std::unordered_set<int64> devices;
for (HloInstruction* instruction : *companions) {
- int64 device = GetModuleId(instruction->parent()->parent());
- if (!devices.insert(device).second) {
- std::stringstream ss;
- ss << "Companion set:" << std::endl;
- for (HloInstruction* hlo : *companions) {
- ss << " " << hlo->name() << " ("
- << GetModuleId(hlo->parent()->parent()) << ")" << std::endl;
+ // Go through all the communicating instructions (send, recv) of the given
+ // companion, and record their device.
+ std::unordered_set<int64> comm_devices;
+ for (HloInstruction* comm_instruction :
+ tracked_instructions_comms_.at(instruction)) {
+ auto device = GetInstructionDevice(*comm_instruction);
+ TF_RET_CHECK(device) << "Instruction " << comm_instruction->ToString()
+ << " does not have a device";
+ comm_devices.insert(*device);
+ }
+ for (int64 device : comm_devices) {
+ if (!devices.insert(device).second) {
+ std::stringstream ss;
+ ss << "Companion set:" << std::endl;
+ for (HloInstruction* hlo : *companions) {
+ ss << " " << hlo->name() << std::endl;
+ }
+ ss << "has multiple instructions on the same device";
+ return FailedPrecondition("%s", ss.str().c_str());
}
- ss << "has multiple instructions on the same device";
- return FailedPrecondition("%s", ss.str().c_str());
}
}
}
@@ -223,6 +232,21 @@ int64 HloModuleGroupMetadata::GetModuleId(const HloModule* module) const {
LOG(FATAL) << "unknown module";
}
+tensorflow::gtl::optional<int64> HloModuleGroupMetadata::GetInstructionDevice(
+ const HloInstruction& instruction) const {
+ // Module group metadata can be built either as "single module, multiple
+ // devices" or as "multiple modules, no explicit devices", so use the value of
+ // sharding_unique_device() when set and otherwise fall back to the ID of the
+ // owning module. An optional is returned, though this implementation always
+ // sets it, so that callers can report an error when no device is known.
+ tensorflow::gtl::optional<int64> device =
+ instruction.sharding_unique_device();
+ if (!device) {
+ device = GetModuleId(instruction.parent()->parent());
+ }
+ return device;
+}
+
Status HloModuleGroupMetadata::RecordInstructions() {
const auto visitor = [this](HloInstruction* hlo) -> Status {
if (hlo->opcode() == HloOpcode::kWhile) {
@@ -346,26 +370,38 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
if (!ShapeUtil::Compatible(send_shape, recv_shape)) {
return FailedPrecondition("send/recv shapes do not match");
}
- const HloModule* send_module = channel.send->parent()->parent();
- const HloModule* send_done_module = channel.send_done->parent()->parent();
- if (send_module != send_done_module) {
+ auto send_device = GetInstructionDevice(*channel.send);
+ auto send_done_device = GetInstructionDevice(*channel.send_done);
+ if (!send_device) {
+ return FailedPrecondition("send instruction must have a device: %s",
+ channel.send->ToString().c_str());
+ }
+ if (!send_done_device) {
+ return FailedPrecondition("send_done instruction must have a device: %s",
+ channel.send_done->ToString().c_str());
+ }
+ if (*send_device != *send_done_device) {
return FailedPrecondition(
"send and send-done (channel=%lld) must be on the same device: %lld "
"vs. %lld",
- channel.id, GetModuleId(send_module), GetModuleId(send_done_module));
+ channel.id, *send_device, *send_done_device);
+ }
+ auto recv_device = GetInstructionDevice(*channel.recv);
+ auto recv_done_device = GetInstructionDevice(*channel.recv_done);
+ if (!recv_done_device) {
+ return FailedPrecondition("recv_done instruction must have a device: %s",
+ channel.recv_done->ToString().c_str());
}
- const HloModule* recv_module = channel.recv->parent()->parent();
- const HloModule* recv_done_module = channel.recv_done->parent()->parent();
- if (recv_module != recv_done_module) {
+ if (*recv_device != *recv_done_device) {
return FailedPrecondition(
"recv and recv-done (channel=%lld) must be on the same device: %lld "
"vs. %lld",
- channel.id, GetModuleId(recv_module), GetModuleId(recv_done_module));
+ channel.id, *recv_device, *recv_done_device);
}
- if (send_module == recv_module) {
+ if (*send_device == *recv_device) {
return FailedPrecondition(
"send and recv (channel=%lld) must be on different devices: %lld",
- channel.id, GetModuleId(send_module));
+ channel.id, *send_device);
}
}