aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc83
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.h8
2 files changed, 55 insertions, 36 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 8c875698eb..80ed6d6832 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -731,6 +731,55 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
return tensorflow::Status::OK();
}
+Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
+ auto* while_cond = instruction->while_condition();
+ auto* while_body = instruction->while_body();
+ if (while_cond->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While condition must have exactly 1 parameter; had %lld : %s",
+ while_cond->num_parameters(), while_cond->ToString().c_str());
+ }
+ if (while_body->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While body must have exactly 1 parameter; had %lld : %s",
+ while_body->num_parameters(), while_body->ToString().c_str());
+ }
+ if (instruction->operand_count() != 1) {
+ return FailedPrecondition(
+ "While loop must have exactly one operand; had %lld : %s",
+ instruction->operand_count(), instruction->ToString().c_str());
+ }
+ auto* init = instruction->operand(0);
+ auto* cond_param = while_cond->parameter_instruction(0);
+ if (!ShapeUtil::Compatible(init->shape(), cond_param->shape())) {
+ return FailedPrecondition(
+ "While condition's parameter must have the same shape as the "
+ "loop's 'init'. init: %s, param: %s",
+ init->ToString().c_str(), cond_param->ToString().c_str());
+ }
+ auto* cond_root = while_cond->root_instruction();
+ if (!ShapeUtil::Compatible(cond_root->shape(),
+ ShapeUtil::MakeShape(PRED, {}))) {
+ return FailedPrecondition("While condition should have shape PRED: %s",
+ cond_root->ToString().c_str());
+ }
+ auto* body_param = while_body->parameter_instruction(0);
+ if (!ShapeUtil::Compatible(init->shape(), body_param->shape())) {
+ return FailedPrecondition(
+ "While body's parameter must have the same shape as the loop's"
+ " 'init'. init: %s, param: %s",
+ init->ToString().c_str(), body_param->ToString().c_str());
+ }
+ auto* body_root = while_body->root_instruction();
+ if (!ShapeUtil::Compatible(init->shape(), body_root->shape())) {
+ return FailedPrecondition(
+ "While body should have same shape as the loop's 'init'."
+ "init: %s, body: %s",
+ init->ToString().c_str(), body_root->ToString().c_str());
+ }
+ return tensorflow::Status::OK();
+}
+
StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
@@ -771,39 +820,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
<< instruction->dimensions().size()
<< " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
} else if (instruction->opcode() == HloOpcode::kWhile) {
- auto* while_cond = instruction->while_condition();
- auto* while_body = instruction->while_body();
- TF_RET_CHECK(while_cond->num_parameters() == 1)
- << "While condition must have exactly 1 parameter; had "
- << while_cond->num_parameters() << ": " << while_cond->ToString();
- TF_RET_CHECK(while_body->num_parameters() == 1)
- << "While body must have exactly 1 parameter; had "
- << while_body->num_parameters() << ": " << while_body->ToString();
- TF_RET_CHECK(instruction->operand_count() == 1)
- << "While loop must have exactly one operand; had "
- << instruction->operand_count() << ": " << instruction->ToString();
-
- auto* init = instruction->operand(0);
- auto* cond_param = while_cond->parameter_instruction(0);
- TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape()))
- << "While condition's parameter must have the same shape as the "
- "loop's 'init'. init: "
- << init->ToString() << ", param: " << cond_param->ToString();
- auto* cond_root = while_cond->root_instruction();
- TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(),
- ShapeUtil::MakeShape(PRED, {})))
- << "While condition should have shape PRED: "
- << cond_root->ToString();
-
- auto* body_param = while_body->parameter_instruction(0);
- TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape()))
- << "While body's parameter must have the same shape as the loop's "
- "'init'. init: "
- << init->ToString() << ", param: " << body_param->ToString();
- auto* body_root = while_body->root_instruction();
- TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape()))
- << "While body should have same shape as the loop's 'init'. init: "
- << init->ToString() << ", body: " << body_root->ToString();
+ TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
}
auto previous = instructions.find(instruction->name());
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 1dd7ec3c51..1ec55a9bdc 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -102,7 +102,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status CheckTernaryShape(const HloInstruction* instruction);
Status CheckVariadicShape(const HloInstruction* instruction);
- // Checks if the given two instructions shares the same channel id.
+ // Checks if the given two instructions share the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2);
@@ -144,9 +144,11 @@ class HloVerifier : public HloPassInterface {
// CHECKs various invariants of a fusion instruction.
Status CheckFusionInstruction(HloInstruction* fusion) const;
+ Status CheckWhileInstruction(HloInstruction* instruction);
+
// Creates a ShapeVerifier that checks that shapes match inferred
- // expectations. This is a factory function because ShapeVerifier, Note that
- // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object
+ // expectations. This is a factory function because ShapeVerifier,
+ // being a DfsHloVisitor, is stateful. We want a clean object
// for each run of the verifier.
ShapeVerifierFactory shape_verifier_factory_;
};