aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-11 11:04:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 11:07:41 -0700
commite1562e72c197ec830547a051ddfe0f720acb9f67 (patch)
tree18abff05955efb8a329028fd15beffc6b638594a /tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
parent8480a96e1fb43edd26846a6c6d986f9408f8a2db (diff)
Allow communicating instructions within a kCall computation.
PiperOrigin-RevId: 196278635
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_module_group_metadata.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.cc38
1 files changed, 23 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
index 54c34ce116..67f4c37413 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.cc
@@ -47,6 +47,9 @@ string HloModuleGroupMetadata::TrackedInstruction::ToString() const {
case ComputationKind::kConditionalFalse:
repr += ":CONDITIONAL_FALSE";
break;
+ case ComputationKind::kCallFunction:
+ repr += ":CALL";
+ break;
}
return repr;
}
@@ -206,6 +209,9 @@ Status HloModuleGroupMetadata::RecordInstructions() {
TrackedInstruction(hlo, ComputationKind::kConditionalTrue);
tracked_instructions_[hlo->false_computation()] =
TrackedInstruction(hlo, ComputationKind::kConditionalFalse);
+ } else if (hlo->opcode() == HloOpcode::kCall) {
+ tracked_instructions_[hlo->to_apply()] =
+ TrackedInstruction(hlo, ComputationKind::kCallFunction);
}
if (!IsChannelInstruction(hlo)) {
return Status::OK();
@@ -258,7 +264,8 @@ Status HloModuleGroupMetadata::RecordInstructions() {
Status HloModuleGroupMetadata::AddCompanion(HloInstruction* instruction1,
HloInstruction* instruction2) {
TF_RET_CHECK(instruction1->opcode() == HloOpcode::kWhile ||
- instruction1->opcode() == HloOpcode::kConditional);
+ instruction1->opcode() == HloOpcode::kConditional ||
+ instruction1->opcode() == HloOpcode::kCall);
VLOG(2) << "adding as companions:" << instruction1->ToString() << " and "
<< instruction2->ToString();
@@ -336,21 +343,11 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
}
}
- // Check if channel instructions are used only in allowed computations.
- const auto allowed = [this](HloInstruction* hlo) {
- HloComputation* computation = hlo->parent();
- const HloModule* module = computation->parent();
- if (module->entry_computation() == computation ||
- tracked_instructions_.count(computation) > 0) {
- return true;
- }
- return false;
- };
for (const Channel& channel : channels_) {
- if (!allowed(channel.send) || !allowed(channel.send_done) ||
- !allowed(channel.recv) || !allowed(channel.recv_done)) {
- return FailedPrecondition("channel is used in disallowed computation");
- }
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.send_done));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv));
+ TF_RETURN_IF_ERROR(CheckCommunicatingInstruction(channel.recv_done));
}
// Check if the nest levels match for each channel.
for (const Channel& channel : channels_) {
@@ -368,4 +365,15 @@ Status HloModuleGroupMetadata::VerifyChannelInstructions() {
return Status::OK();
}
+Status HloModuleGroupMetadata::CheckCommunicatingInstruction(
+ HloInstruction* instruction) const {
+ HloComputation* computation = instruction->parent();
+ const HloModule* module = computation->parent();
+ if (module->entry_computation() == computation ||
+ tracked_instructions_.count(computation) > 0) {
+ return Status::OK();
+ }
+ return FailedPrecondition("channel is used in disallowed computation");
+}
+
} // namespace xla