diff options
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 83 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.h | 8 |
2 files changed, 55 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 8c875698eb..80ed6d6832 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -731,6 +731,55 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { return tensorflow::Status::OK(); } +Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) { + auto* while_cond = instruction->while_condition(); + auto* while_body = instruction->while_body(); + if (while_cond->num_parameters() != 1) { + return FailedPrecondition( + "While condition must have exactly 1 parameter; had %lld : %s", + while_cond->num_parameters(), while_cond->ToString().c_str()); + } + if (while_body->num_parameters() != 1) { + return FailedPrecondition( + "While body must have exactly 1 parameter; had %lld : %s", + while_body->num_parameters(), while_body->ToString().c_str()); + } + if (instruction->operand_count() != 1) { + return FailedPrecondition( + "While loop must have exactly one operand; had %lld : %s", + instruction->operand_count(), instruction->ToString().c_str()); + } + auto* init = instruction->operand(0); + auto* cond_param = while_cond->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) { + return FailedPrecondition( + "While condition's parameter must have the same shape as the " + "loop's 'init'. init: %s, param: %s", + init->ToString().c_str(), cond_param->ToString().c_str()); + } + auto* cond_root = while_cond->root_instruction(); + if (!ShapeUtil::Compatible(cond_root->shape(), + ShapeUtil::MakeShape(PRED, {}))) { + return FailedPrecondition("While condition should have shape PRED: %s", + cond_root->ToString().c_str()); + } + auto* body_param = while_body->parameter_instruction(0); + if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) { + return FailedPrecondition( + "While body's parameter must have the same shape as the loop's" + " 'init'. init: %s, param: %s", + init->ToString().c_str(), body_param->ToString().c_str()); + } + auto* body_root = while_body->root_instruction(); + if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) { + return FailedPrecondition( + "While body should have same shape as the loop's 'init'." + "init: %s, body: %s", + init->ToString().c_str(), body_root->ToString().c_str()); + } + return tensorflow::Status::OK(); +} + StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(VerifyHloStructure(module)); @@ -771,39 +820,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { << instruction->dimensions().size() << " != " << ShapeUtil::Rank(instruction->operand(0)->shape()); } else if (instruction->opcode() == HloOpcode::kWhile) { - auto* while_cond = instruction->while_condition(); - auto* while_body = instruction->while_body(); - TF_RET_CHECK(while_cond->num_parameters() == 1) - << "While condition must have exactly 1 parameter; had " - << while_cond->num_parameters() << ": " << while_cond->ToString(); - TF_RET_CHECK(while_body->num_parameters() == 1) - << "While body must have exactly 1 parameter; had " - << while_body->num_parameters() << ": " << while_body->ToString(); - TF_RET_CHECK(instruction->operand_count() == 1) - << "While loop must have exactly one operand; had " - << instruction->operand_count() << ": " << instruction->ToString(); - - auto* init = instruction->operand(0); - auto* cond_param = while_cond->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape())) - << "While condition's parameter must have the same shape as the " - "loop's 'init'. init: " - << init->ToString() << ", param: " << cond_param->ToString(); - auto* cond_root = while_cond->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(), - ShapeUtil::MakeShape(PRED, {}))) - << "While condition should have shape PRED: " - << cond_root->ToString(); - - auto* body_param = while_body->parameter_instruction(0); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape())) - << "While body's parameter must have the same shape as the loop's " - "'init'. init: " - << init->ToString() << ", param: " << body_param->ToString(); - auto* body_root = while_body->root_instruction(); - TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape())) - << "While body should have same shape as the loop's 'init'. init: " - << init->ToString() << ", body: " << body_root->ToString(); + TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction)); } auto previous = instructions.find(instruction->name()); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 1dd7ec3c51..1ec55a9bdc 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -102,7 +102,7 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckTernaryShape(const HloInstruction* instruction); Status CheckVariadicShape(const HloInstruction* instruction); - // Checks if the given two instructions shares the same channel id. + // Checks if the given two instructions share the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); @@ -144,9 +144,11 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. Status CheckFusionInstruction(HloInstruction* fusion) const; + Status CheckWhileInstruction(HloInstruction* instruction); + // Creates a ShapeVerifier that checks that shapes match inferred - // expectations. This is a factory function because ShapeVerifier, Note that - // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object + // expectations. This is a factory function because ShapeVerifier, + // being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. ShapeVerifierFactory shape_verifier_factory_; }; |