author    Mark Heffernan <meheff@google.com>    2018-10-04 08:56:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-10-04 09:01:03 -0700
commit 80c9eec9b2475630f83a596f77a906c8075f8e6c (patch)
tree   23685affd41b566df670435fef6425de4bac5569 /tensorflow/compiler
parent dcd7dd2d2e1ed7d8c26dd22dbbd2bac269c42e1e (diff)
Remove CHECKs from HloInstruction constructors.
Move these checks to RET_CHECKs in the HloVerifier. Added a new visitor class, InstructionVerifier, inside hlo_verifier.cc to handle these miscellaneous non-result-shape verifications.

PiperOrigin-RevId: 215745043
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc   12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h     1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc       456
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h         11
4 files changed, 248 insertions, 232 deletions
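The change follows one pattern throughout: invariants that previously aborted the process via CHECK in an HloInstruction constructor are restated as TF_RET_CHECKs that return a Status from a verifier visitor. A minimal standalone sketch of that pattern — the MiniInstruction/VerifyGetTupleElement names and the RET_CHECK macro below are hypothetical stand-ins, not the XLA API:

```cpp
#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-ins for illustration; the real types live in
// hlo_instruction.h and hlo_verifier.cc.
struct Status {
  bool ok = true;
  std::string message;
  static Status OK() { return {}; }
  static Status Internal(std::string msg) { return {false, std::move(msg)}; }
};

// Poor man's TF_RET_CHECK: report the violation as a Status instead of
// aborting the process the way CHECK does.
#define RET_CHECK(cond) \
  if (!(cond)) return Status::Internal("RET_CHECK failed: " #cond)

struct MiniInstruction {
  std::string opcode;
  bool operand_is_tuple = false;
  // Before this change the constructor would CHECK the tuple invariant and
  // crash on malformed input; now construction always succeeds ...
};

// ... and a separate verification pass reports violations recoverably.
Status VerifyGetTupleElement(const MiniInstruction& instr) {
  RET_CHECK(instr.opcode == "get-tuple-element");
  RET_CHECK(instr.operand_is_tuple);
  return Status::OK();
}

int main() {
  MiniInstruction bad{"get-tuple-element", /*operand_is_tuple=*/false};
  Status s = VerifyGetTupleElement(bad);  // error Status, no process abort
  std::cout << (s.ok ? "ok" : s.message) << "\n";
}
```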
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 68d0979f5c..152d8eacdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -643,14 +643,6 @@ HloTransposeInstruction::HloTransposeInstruction(
absl::Span<const int64> dimensions)
: HloInstruction(HloOpcode::kTranspose, shape),
dimensions_(dimensions.begin(), dimensions.end()) {
- CHECK_EQ(shape.dimensions().size(), dimensions.size());
- CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
- CHECK(std::equal(operand->shape().dimensions().begin(),
- operand->shape().dimensions().end(),
- Permute(dimensions, shape.dimensions()).begin()))
- << "shape: " << ShapeUtil::HumanString(shape)
- << ", operand->shape(): " << ShapeUtil::HumanString(shape)
- << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
AppendOperand(operand);
}
@@ -1491,7 +1483,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
HloGetTupleElementInstruction::HloGetTupleElementInstruction(
const Shape& shape, HloInstruction* operand, int64 index)
: HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
- CHECK(ShapeUtil::IsTuple(operand->shape()));
AppendOperand(operand);
}
@@ -1613,9 +1604,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
: HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
outfeed_shape_(outfeed_shape),
outfeed_config_(outfeed_config) {
- CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
- << "Outfeed shape " << outfeed_shape
- << " must be compatible with operand shape " << operand->shape();
AppendOperand(operand);
AppendOperand(token_operand);
}
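For reference, the transpose invariant deleted above can be stated with standard-library types alone: the permutation must have one entry per output dimension, and each output dimension must match the operand dimension it reads from. A hedged sketch, with plain int64_t vectors standing in for XLA shapes and assuming the convention output_dims[i] == operand_dims[perm[i]] (the real code phrases this via Permute):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the transpose shape invariant now enforced in the verifier.
// Plain vectors stand in for XLA shapes; this is not the ShapeUtil API.
bool TransposeShapesConsistent(const std::vector<int64_t>& operand_dims,
                               const std::vector<int64_t>& output_dims,
                               const std::vector<int64_t>& perm) {
  // Operand rank, output rank, and permutation length must all agree.
  if (perm.size() != output_dims.size() ||
      operand_dims.size() != output_dims.size()) {
    return false;
  }
  for (std::size_t i = 0; i < perm.size(); ++i) {
    // Each permutation entry must name a valid operand dimension, and output
    // dimension i must have the size of the operand dimension it permutes.
    if (perm[i] < 0 ||
        static_cast<std::size_t>(perm[i]) >= operand_dims.size() ||
        output_dims[i] != operand_dims[perm[i]]) {
      return false;
    }
  }
  return true;
}

// Example: transposing a [2, 3] operand with perm {1, 0} yields [3, 2], so
// TransposeShapesConsistent({2, 3}, {3, 2}, {1, 0}) == true.
```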
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ab168800f6..e169604072 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -896,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction {
absl::string_view outfeed_config);
// Returns the shape for the Outfeed instruction.
const Shape& outfeed_shape() const {
- TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
return outfeed_shape_;
}
// Returns the config for the Outfeed instruction.
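The hunk above also drops a TF_DCHECK_OK that re-validated outfeed_shape_ on every call to the accessor; under the new scheme the shape is checked once, when the verifier runs over the module. A sketch of that trade-off, using illustrative MiniShape/MiniOutfeed stand-ins rather than the XLA API:

```cpp
// Illustrative stand-ins, not the XLA API.
struct MiniShape {
  int rank = 0;
};
bool ValidateShape(const MiniShape& s) { return s.rank >= 0; }

class MiniOutfeed {
 public:
  explicit MiniOutfeed(MiniShape shape) : outfeed_shape_(shape) {}

  // Before: the accessor re-ran validation (a DCHECK) on every call, paying
  // the cost at each use site. After: it is a plain getter ...
  const MiniShape& outfeed_shape() const { return outfeed_shape_; }

 private:
  MiniShape outfeed_shape_;
};

// ... and validation happens once, in the verification pass.
bool VerifyOutfeed(const MiniOutfeed& outfeed) {
  return ValidateShape(outfeed.outfeed_shape());
}
```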
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index a7727824fe..b5498bb936 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -763,7 +763,136 @@ Status VerifyHloStructure(HloModule* module) {
return Status::OK();
}
-Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+ bool contains_token = false;
+ ShapeUtil::ForEachSubshape(
+ shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+ if (ShapeUtil::IsToken(subshape)) {
+ contains_token = true;
+ }
+ });
+ return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal.
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+ // Tokens cannot be passed as entry parameters.
+ // TODO(b/80000000): Remove this constraint.
+ for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+ HloInstruction* param =
+ module.entry_computation()->parameter_instruction(i);
+ if (ShapeContainsToken(param->shape())) {
+ return InternalError(
+ "Entry parameter %d is or contains a token shape: %s", i,
+ ShapeUtil::HumanString(param->shape()));
+ }
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions share the same channel id.
+Status CheckSameChannel(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ if (instr1->channel_id() != instr2->channel_id()) {
+ return InternalError(
+ "Expected to have the same channel id, actual channel ids are: %s "
+ "(%d), %s (%d)",
+ instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+ instr2->channel_id());
+ }
+ return Status::OK();
+}
+
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Instructions must be send/recv instructions or their
+// 'done' variant.
+Status CheckSameIsHostTransfer(const HloInstruction* instr1,
+ const HloInstruction* instr2) {
+ const HloSendRecvInstruction* send_recv1 =
+ DynCast<const HloSendRecvInstruction>(instr1);
+ const HloSendRecvInstruction* send_recv2 =
+ DynCast<const HloSendRecvInstruction>(instr2);
+ TF_RET_CHECK(send_recv1 != nullptr);
+ TF_RET_CHECK(send_recv2 != nullptr);
+ if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
+ return InternalError(
+ "Expected instructions to have the same is-host-transfer property: "
+ "%s, "
+ "%s ",
+ instr1->ToString(), instr2->ToString());
+ }
+ return Status::OK();
+}
+
+// Checks various invariants of send and recv instructions.
+Status VerifySendsAndRecvs(const HloModule& module) {
+ absl::flat_hash_map<int64, const HloInstruction*> host_channels;
+ // Host send/recv instructions must have their own unique channel.
+ auto check_unique_host_channel = [&](const HloInstruction* instruction) {
+ const HloSendRecvInstruction* sendrecv =
+ DynCast<const HloSendRecvInstruction>(instruction);
+ if (sendrecv->is_host_transfer()) {
+ auto it_inserted =
+ host_channels.insert({sendrecv->channel_id(), sendrecv});
+ if (!it_inserted.second) {
+ return FailedPrecondition(
+ "Channel %d is used for multiple host send/recv instructions: "
+ "%s "
+ "and "
+ "%s",
+ sendrecv->channel_id(), sendrecv->ToString(),
+ it_inserted.first->second->ToString());
+ }
+ }
+
+ return Status::OK();
+ };
+
+ // Send/Recv instructions must have a single user: the corresponding
+ // SendDone/RecvDone with a matching channel.
+ for (const HloComputation* computation : module.computations()) {
+ for (const HloInstruction* instruction : computation->instructions()) {
+ switch (instruction->opcode()) {
+ case HloOpcode::kSend: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* send_done = instruction->users().front();
+ TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
+ break;
+ }
+ case HloOpcode::kRecv: {
+ TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+ TF_RET_CHECK(instruction->users().size() == 1);
+ const HloInstruction* recv_done = instruction->users().front();
+ TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+ TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
+ TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
+ break;
+ }
+ case HloOpcode::kSendDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
+ break;
+ case HloOpcode::kRecvDone:
+ TF_RET_CHECK(instruction->operands().size() == 1);
+ TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
+ break;
+ default:
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+// Checks various invariants of a fusion instruction.
+Status CheckFusionInstruction(HloInstruction* fusion) {
// The parent fusion instruction of the fusion computation must be 'fusion'.
HloComputation* fused_computation = fusion->fused_instructions_computation();
if (fusion != fused_computation->FusionInstruction()) {
@@ -866,50 +995,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
}
}
+ TF_RET_CHECK(fusion->called_computations() ==
+ absl::Span<HloComputation* const>(
+ {fusion->fused_instructions_computation()}))
+ << "Fusion HLO calls computations other than the "
+ "fused_instructions_computation: "
+ << fusion->ToString() << " fusion->fused_instructions_computation(): "
+ << fusion->fused_instructions_computation()->ToString()
+ << " fusion->called_computations(): "
+ << ComputationsToString(fusion->called_computations());
+
+ for (const auto& fused : fusion->fused_instructions()) {
+ TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
+ << "Fused HLO was missing a parent: " << fused->ToString()
+ << " parent: " << fused->parent()
+ << " computation: " << fusion->parent();
+ }
+
// TODO(b/65423525): We'd like to check that all operands are distinct.
// This is currently disabled due to the invariant being violated by
// multi-output fusion.
return Status::OK();
}
-Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
- auto* while_cond = instruction->while_condition();
- auto* while_body = instruction->while_body();
- if (while_cond->num_parameters() != 1) {
- return FailedPrecondition(
- "While condition must have exactly 1 parameter; had %d : %s",
- while_cond->num_parameters(), while_cond->ToString());
- }
- if (while_body->num_parameters() != 1) {
- return FailedPrecondition(
- "While body must have exactly 1 parameter; had %d : %s",
- while_body->num_parameters(), while_body->ToString());
- }
- if (instruction->operand_count() != 1) {
- return FailedPrecondition(
- "While loop must have exactly one operand; had %d : %s",
- instruction->operand_count(), instruction->ToString());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
- if (instruction->true_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "True computation %s of %s must have 1 parameter insted of %d",
- instruction->true_computation()->name(), instruction->ToString(),
- instruction->true_computation()->num_parameters());
- }
- if (instruction->false_computation()->num_parameters() != 1) {
- return FailedPrecondition(
- "False computation %s of %s must have 1 parameter insted of %d",
- instruction->false_computation()->name(), instruction->ToString(),
- instruction->false_computation()->num_parameters());
- }
- return Status::OK();
-}
-
-Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
+// Checks that the non-scalar operand shapes are compatible to the output
+// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
+Status CheckElementwiseInstruction(HloInstruction* instruction) {
const Shape& out_shape = instruction->shape();
for (HloInstruction* operand : instruction->operands()) {
const Shape& operand_shape = operand->shape();
@@ -926,133 +1037,114 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
return Status::OK();
}
-namespace {
+// Visitor which verifies various fields on the HLO instruction. This class does
+// not check result shape as that is checked in the ShapeVerifier.
+class InstructionVerifier : public DfsHloVisitorWithDefault {
+ public:
+ InstructionVerifier() {}
-// Returns true if the given Shape has a TOKEN shape as any subshape.
-bool ShapeContainsToken(const Shape& shape) {
- bool contains_token = false;
- ShapeUtil::ForEachSubshape(
- shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
- if (ShapeUtil::IsToken(subshape)) {
- contains_token = true;
- }
- });
- return contains_token;
-}
+ Status DefaultAction(HloInstruction*) override { return Status::OK(); }
-// Verifies that all types entering and exiting the entry computation are
-// legal.
-Status VerifyEntryAndExitShapes(const HloModule& module) {
- // Tokens cannot be passed as entry parameters.
- // TODO(b/80000000): Remove this constraint.
- for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
- HloInstruction* param =
- module.entry_computation()->parameter_instruction(i);
- if (ShapeContainsToken(param->shape())) {
- return InternalError(
- "Entry parameter %d is or contains a token shape: %s", i,
- ShapeUtil::HumanString(param->shape()));
- }
+ Status HandleFusion(HloInstruction* fusion) override {
+ return CheckFusionInstruction(fusion);
}
- return Status::OK();
-}
-// Checks if the given two instructions share the same channel id.
-Status CheckSameChannel(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- if (instr1->channel_id() != instr2->channel_id()) {
- return InternalError(
- "Expected to have the same channel id, actual channel ids are: %s "
- "(%d), %s (%d)",
- instr1->ToString(), instr1->channel_id(), instr2->ToString(),
- instr2->channel_id());
+ Status HandleBroadcast(HloInstruction* broadcast) override {
+ // If you see this failure then someone has confused the difference
+ // between the HLO broadcast op, and the UserComputation broadcast
+ // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
+ // or ComputationLowerer::Visit()
+ TF_RET_CHECK(broadcast->dimensions().size() ==
+ ShapeUtil::Rank(broadcast->operand(0)->shape()))
+ << "Broadcast HLO (" << broadcast->ToShortString()
+ << ") has invalid number of dimensions: "
+ << broadcast->dimensions().size()
+ << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks if the given two instructions have the same is_host_transfer
-// attribute value. Intsructions must be send/recv instructions or their
-// 'done' variant.
-Status CheckSameIsHostTransfer(const HloInstruction* instr1,
- const HloInstruction* instr2) {
- const HloSendRecvInstruction* send_recv1 =
- DynCast<const HloSendRecvInstruction>(instr1);
- const HloSendRecvInstruction* send_recv2 =
- DynCast<const HloSendRecvInstruction>(instr2);
- TF_RET_CHECK(send_recv1 != nullptr);
- TF_RET_CHECK(send_recv2 != nullptr);
- if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
- return InternalError(
- "Expected instructions to have the same is-host-transfer property: "
- "%s, "
- "%s ",
- instr1->ToString(), instr2->ToString());
+ Status HandleWhile(HloInstruction* xla_while) override {
+ auto* while_cond = xla_while->while_condition();
+ auto* while_body = xla_while->while_body();
+ if (while_cond->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While condition must have exactly 1 parameter; had %d : %s",
+ while_cond->num_parameters(), while_cond->ToString());
+ }
+ if (while_body->num_parameters() != 1) {
+ return FailedPrecondition(
+ "While body must have exactly 1 parameter; had %d : %s",
+ while_body->num_parameters(), while_body->ToString());
+ }
+ if (xla_while->operand_count() != 1) {
+ return FailedPrecondition(
+ "While loop must have exactly one operand; had %d : %s",
+ xla_while->operand_count(), xla_while->ToString());
+ }
+ return Status::OK();
}
- return Status::OK();
-}
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
- absl::flat_hash_map<int64, const HloInstruction*> host_channels;
- // Host send/recv instructions must have their own unique channel.
- auto check_unique_host_channel = [&](const HloInstruction* instruction) {
- const HloSendRecvInstruction* sendrecv =
- DynCast<const HloSendRecvInstruction>(instruction);
- if (sendrecv->is_host_transfer()) {
- auto it_inserted =
- host_channels.insert({sendrecv->channel_id(), sendrecv});
- if (!it_inserted.second) {
- return FailedPrecondition(
- "Channel %d is used for multiple host send/recv instructions: "
- "%s "
- "and "
- "%s",
- sendrecv->channel_id(), sendrecv->ToString(),
- it_inserted.first->second->ToString());
- }
+ Status HandleConditional(HloInstruction* conditional) override {
+ if (conditional->true_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "True computation %s of %s must have 1 parameter insted of %d",
+ conditional->true_computation()->name(), conditional->ToString(),
+ conditional->true_computation()->num_parameters());
}
+ if (conditional->false_computation()->num_parameters() != 1) {
+ return FailedPrecondition(
+ "False computation %s of %s must have 1 parameter insted of %d",
+ conditional->false_computation()->name(), conditional->ToString(),
+ conditional->false_computation()->num_parameters());
+ }
+ return Status::OK();
+ }
+
+ Status HandleElementwiseUnary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+
+ Status HandleElementwiseBinary(HloInstruction* instruction) override {
+ return CheckElementwiseInstruction(instruction);
+ }
+ Status HandleGetTupleElement(HloInstruction* gte) override {
+ TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape()));
return Status::OK();
- };
+ }
- // Send/Recv instruction must have a single user: the corresponding
- // SendDone/RecvDone. with matching channel.
- for (const HloComputation* computation : module.computations()) {
- for (const HloInstruction* instruction : computation->instructions()) {
- switch (instruction->opcode()) {
- case HloOpcode::kSend: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* send_done = instruction->users().front();
- TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
- break;
- }
- case HloOpcode::kRecv: {
- TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
- TF_RET_CHECK(instruction->users().size() == 1);
- const HloInstruction* recv_done = instruction->users().front();
- TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
- TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
- TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
- break;
- }
- case HloOpcode::kSendDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
- break;
- case HloOpcode::kRecvDone:
- TF_RET_CHECK(instruction->operands().size() == 1);
- TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
- break;
- default:
- break;
- }
- }
+ Status HandleTranspose(HloInstruction* transpose) override {
+ const Shape& shape = transpose->shape();
+ const HloInstruction* operand = transpose->operand(0);
+ TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
+ TF_RET_CHECK(shape.dimensions().size() ==
+ transpose->operand(0)->shape().dimensions().size());
+ TF_RET_CHECK(std::equal(
+ operand->shape().dimensions().begin(),
+ operand->shape().dimensions().end(),
+ Permute(transpose->dimensions(), shape.dimensions()).begin()))
+ << "shape: " << shape << ", operand->shape(): " << shape
+ << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
+ << "}";
+ return Status::OK();
}
- return Status::OK();
-}
+
+ Status Preprocess(HloInstruction* instruction) override {
+ auto previous = instructions_by_name_.find(instruction->name());
+ TF_RET_CHECK(previous == instructions_by_name_.end())
+ << "HLO has name that is not unique within module:\n"
+ << instruction->ToString()
+ << " in computation: " << instruction->parent()->name()
+ << "\nPrevious HLO with same name:\n"
+ << previous->second->ToString()
+ << " in computation: " << previous->second->parent()->name();
+ instructions_by_name_[instruction->name()] = instruction;
+ return Status::OK();
+ }
+
+ private:
+ absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+};
} // namespace
@@ -1061,65 +1153,13 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
- absl::flat_hash_map<string, const HloInstruction*> instructions;
for (auto* computation : module->computations()) {
- for (const auto& instruction : computation->instructions()) {
- TF_RET_CHECK(instruction->parent() == computation);
- if (instruction->opcode() == HloOpcode::kFusion) {
- TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
- TF_RET_CHECK(instruction->called_computations() ==
- absl::Span<HloComputation* const>(
- {instruction->fused_instructions_computation()}))
- << "Fusion HLO calls computations other than the "
- "fused_instructions_computation: "
- << instruction->ToString()
- << " instruction->fused_instructions_computation(): "
- << instruction->fused_instructions_computation()->ToString()
- << " instruction->called_computations(): "
- << ComputationsToString(instruction->called_computations());
-
- for (const auto& fused : instruction->fused_instructions()) {
- TF_RET_CHECK(fused->parent() ==
- instruction->fused_instructions_computation())
- << "Fused HLO was missing a parent: " << fused->ToString()
- << " parent: " << fused->parent()
- << " computation: " << computation;
- }
- } else if (instruction->opcode() == HloOpcode::kBroadcast) {
- // If you see this failure then someone has confused the difference
- // between the HLO broadcast op, and the UserComputation broadcast
- // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
- // or ComputationLowerer::Visit()
- TF_RET_CHECK(instruction->dimensions().size() ==
- ShapeUtil::Rank(instruction->operand(0)->shape()))
- << "Broadcast HLO (" << instruction->ToShortString()
- << ") has invalid number of dimensions: "
- << instruction->dimensions().size()
- << " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
- } else if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
- } else if (instruction->opcode() == HloOpcode::kConditional) {
- TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
- } else if (instruction->opcode() !=
- HloOpcode::kRng /* Rng operands are always scalar. */
- && instruction->IsElementwise()) {
- TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
- }
-
- auto previous = instructions.find(instruction->name());
- TF_RET_CHECK(previous == instructions.end())
- << "HLO has name that is not unique within module:\n"
- << instruction->ToString()
- << " in computation: " << computation->name()
- << "\nPrevious HLO with same name:\n"
- << previous->second->ToString()
- << " in computation: " << previous->second->parent()->name();
- instructions[instruction->name()] = instruction;
- }
-
std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
+
+ InstructionVerifier instruction_verifier;
+ TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
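Taken together, the new InstructionVerifier is a DfsHloVisitorWithDefault: opcode-specific Handle* overrides carry the moved constructor checks, DefaultAction accepts every other opcode, and the Preprocess hook runs for each instruction to enforce module-wide name uniqueness. A self-contained sketch of that visitor shape, with hypothetical Node/MiniInstructionVerifier names in place of the HLO types:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for HloInstruction, for illustration only.
struct Node {
  std::string name;
  std::string opcode;
};

class MiniInstructionVerifier {
 public:
  // Mirrors InstructionVerifier::Preprocess: runs for every node, before
  // any opcode-specific handler, and enforces unique names.
  bool Preprocess(const Node& node) {
    if (!seen_.emplace(node.name, &node).second) {
      std::cerr << "duplicate HLO name: " << node.name << "\n";
      return false;
    }
    return true;
  }

  bool Visit(const std::vector<Node>& nodes) {
    for (const Node& node : nodes) {
      if (!Preprocess(node)) return false;
      // Opcode-specific checks (HandleFusion, HandleWhile, ...) would
      // dispatch here; a DefaultAction accepts everything else.
    }
    return true;
  }

 private:
  std::unordered_map<std::string, const Node*> seen_;
};

int main() {
  MiniInstructionVerifier verifier;
  std::vector<Node> nodes = {{"add.1", "add"}, {"add.1", "multiply"}};
  std::cout << (verifier.Visit(nodes) ? "ok" : "failed") << "\n";  // failed
}
```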
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 0cde4a31af..6d16586c2c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -172,17 +172,6 @@ class HloVerifier : public HloModulePass {
StatusOr<bool> Run(HloModule* module) override;
private:
- // CHECKs various invariants of a fusion instruction.
- Status CheckFusionInstruction(HloInstruction* fusion) const;
-
- Status CheckWhileInstruction(HloInstruction* instruction);
-
- Status CheckConditionalInstruction(HloInstruction* instruction);
-
- // Checks that the non-scalar operand shapes are compatible to the output
- // shape, i.e., that there are no implicit broadcasts of size-one dimensions.
- Status CheckElementwiseInstruction(HloInstruction* instruction);
-
// Creates a ShapeVerifier that checks that shapes match inferred
// expectations. This is a factory function because ShapeVerifier,
// being a DfsHloVisitor, is stateful. We want a clean object