From 80c9eec9b2475630f83a596f77a906c8075f8e6c Mon Sep 17 00:00:00 2001
From: Mark Heffernan
Date: Thu, 4 Oct 2018 08:56:45 -0700
Subject: Remove CHECKs from HloInstruction constructors.

Move these checks to RET_CHECKs in the HloVerifier. Added a new visitor
class InstructionVerifier inside of hlo_verifier.cc for handling these
random non-result-shape verifications.

PiperOrigin-RevId: 215745043
---
 .../compiler/xla/service/hlo_instructions.cc       |  12 -
 tensorflow/compiler/xla/service/hlo_instructions.h |   1 -
 tensorflow/compiler/xla/service/hlo_verifier.cc    | 456 +++++++++++----------
 tensorflow/compiler/xla/service/hlo_verifier.h     |  11 -
 4 files changed, 248 insertions(+), 232 deletions(-)

(limited to 'tensorflow/compiler')

diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 68d0979f5c..152d8eacdb 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -643,14 +643,6 @@ HloTransposeInstruction::HloTransposeInstruction(
     absl::Span<const int64> dimensions)
     : HloInstruction(HloOpcode::kTranspose, shape),
       dimensions_(dimensions.begin(), dimensions.end()) {
-  CHECK_EQ(shape.dimensions().size(), dimensions.size());
-  CHECK_EQ(shape.dimensions().size(), operand->shape().dimensions().size());
-  CHECK(std::equal(operand->shape().dimensions().begin(),
-                   operand->shape().dimensions().end(),
-                   Permute(dimensions, shape.dimensions()).begin()))
-      << "shape: " << ShapeUtil::HumanString(shape)
-      << ", operand->shape(): " << ShapeUtil::HumanString(shape)
-      << ", dimensions: {" << StrJoin(dimensions, ", ") << "}";
   AppendOperand(operand);
 }
 
@@ -1491,7 +1483,6 @@ HloParameterInstruction::CloneWithNewOperandsImpl(
 HloGetTupleElementInstruction::HloGetTupleElementInstruction(
     const Shape& shape, HloInstruction* operand, int64 index)
     : HloInstruction(HloOpcode::kGetTupleElement, shape), tuple_index_(index) {
-  CHECK(ShapeUtil::IsTuple(operand->shape()));
   AppendOperand(operand);
 }
 
@@ -1613,9 +1604,6 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
     : HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeTokenShape()),
       outfeed_shape_(outfeed_shape),
       outfeed_config_(outfeed_config) {
-  CHECK(ShapeUtil::Compatible(operand->shape(), outfeed_shape))
-      << "Outfeed shape " << outfeed_shape
-      << " must be compatible with operand shape " << operand->shape();
   AppendOperand(operand);
   AppendOperand(token_operand);
 }
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index ab168800f6..e169604072 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -896,7 +896,6 @@ class HloOutfeedInstruction : public HloInstruction {
                          absl::string_view outfeed_config);
   // Returns the shape for the Outfeed instruction.
   const Shape& outfeed_shape() const {
-    TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape_));
    return outfeed_shape_;
   }
   // Returns the config for the Outfeed instruction.
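The constructor CHECKs deleted above do not simply disappear; they resurface in hlo_verifier.cc below as TF_RET_CHECKs, which report a Status to the caller instead of aborting the process. The following standalone sketch only illustrates that difference in failure mode: the simplified Status type and RET_CHECK macro are stand-ins for the real definitions in tensorflow/compiler/xla/status_macros.h, and VerifyTransposeRank is a hypothetical helper, not part of this change.

    #include <iostream>
    #include <string>

    // Simplified stand-ins for xla::Status and TF_RET_CHECK; the real
    // definitions live in tensorflow/compiler/xla/status_macros.h.
    struct Status {
      bool ok = true;
      std::string message;
      static Status OK() { return {}; }
      static Status Internal(std::string msg) { return {false, std::move(msg)}; }
    };

    // On failure, return a Status to the caller instead of aborting like CHECK.
    #define RET_CHECK(cond)                                                   \
      do {                                                                    \
        if (!(cond)) return Status::Internal("RET_CHECK failed: " #cond);     \
      } while (false)

    // Hypothetical verifier-style check mirroring the transpose rank invariant
    // that used to live in the HloTransposeInstruction constructor.
    Status VerifyTransposeRank(int shape_rank, int operand_rank) {
      RET_CHECK(shape_rank == operand_rank);
      return Status::OK();
    }

    int main() {
      Status s = VerifyTransposeRank(3, 2);
      std::cout << (s.ok ? "ok" : s.message) << "\n";  // Prints the error; no abort.
      return 0;
    }

Run on a malformed transpose, the check now propagates an error the caller can report, which is exactly the behavior HloVerifier::Run gains in the hunks that follow.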
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index a7727824fe..b5498bb936 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -763,7 +763,136 @@ Status VerifyHloStructure(HloModule* module) {
   return Status::OK();
 }
 
-Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
+namespace {
+
+// Returns true if the given Shape has a TOKEN shape as any subshape.
+bool ShapeContainsToken(const Shape& shape) {
+  bool contains_token = false;
+  ShapeUtil::ForEachSubshape(
+      shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
+        if (ShapeUtil::IsToken(subshape)) {
+          contains_token = true;
+        }
+      });
+  return contains_token;
+}
+
+// Verifies that all types entering and exiting the entry computation are
+// legal.
+Status VerifyEntryAndExitShapes(const HloModule& module) {
+  // Tokens cannot be passed as entry parameters.
+  // TODO(b/80000000): Remove this constraint.
+  for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
+    HloInstruction* param =
+        module.entry_computation()->parameter_instruction(i);
+    if (ShapeContainsToken(param->shape())) {
+      return InternalError(
+          "Entry parameter %d is or contains a token shape: %s", i,
+          ShapeUtil::HumanString(param->shape()));
+    }
+  }
+  return Status::OK();
+}
+
+// Checks if the given two instructions share the same channel id.
+Status CheckSameChannel(const HloInstruction* instr1,
+                        const HloInstruction* instr2) {
+  if (instr1->channel_id() != instr2->channel_id()) {
+    return InternalError(
+        "Expected to have the same channel id, actual channel ids are: %s "
+        "(%d), %s (%d)",
+        instr1->ToString(), instr1->channel_id(), instr2->ToString(),
+        instr2->channel_id());
+  }
+  return Status::OK();
+}
+
+// Checks if the given two instructions have the same is_host_transfer
+// attribute value. Instructions must be send/recv instructions or their
+// 'done' variant.
+Status CheckSameIsHostTransfer(const HloInstruction* instr1,
+                               const HloInstruction* instr2) {
+  const HloSendRecvInstruction* send_recv1 =
+      DynCast<const HloSendRecvInstruction>(instr1);
+  const HloSendRecvInstruction* send_recv2 =
+      DynCast<const HloSendRecvInstruction>(instr2);
+  TF_RET_CHECK(send_recv1 != nullptr);
+  TF_RET_CHECK(send_recv2 != nullptr);
+  if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
+    return InternalError(
+        "Expected instructions to have the same is-host-transfer property: "
+        "%s, "
+        "%s ",
+        instr1->ToString(), instr2->ToString());
+  }
+  return Status::OK();
+}
+
+// Checks various invariants of send and recv instructions.
+Status VerifySendsAndRecvs(const HloModule& module) {
+  absl::flat_hash_map<int64, const HloSendRecvInstruction*> host_channels;
+  // Host send/recv instructions must have their own unique channel.
+  auto check_unique_host_channel = [&](const HloInstruction* instruction) {
+    const HloSendRecvInstruction* sendrecv =
+        DynCast<const HloSendRecvInstruction>(instruction);
+    if (sendrecv->is_host_transfer()) {
+      auto it_inserted =
+          host_channels.insert({sendrecv->channel_id(), sendrecv});
+      if (!it_inserted.second) {
+        return FailedPrecondition(
+            "Channel %d is used for multiple host send/recv instructions: "
+            "%s "
+            "and "
+            "%s",
+            sendrecv->channel_id(), sendrecv->ToString(),
+            it_inserted.first->second->ToString());
+      }
+    }
+
+    return Status::OK();
+  };
+
+  // Send/Recv instruction must have a single user: the corresponding
+  // SendDone/RecvDone, with matching channel.
+  for (const HloComputation* computation : module.computations()) {
+    for (const HloInstruction* instruction : computation->instructions()) {
+      switch (instruction->opcode()) {
+        case HloOpcode::kSend: {
+          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+          TF_RET_CHECK(instruction->users().size() == 1);
+          const HloInstruction* send_done = instruction->users().front();
+          TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
+          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
+          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
+          break;
+        }
+        case HloOpcode::kRecv: {
+          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
+          TF_RET_CHECK(instruction->users().size() == 1);
+          const HloInstruction* recv_done = instruction->users().front();
+          TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
+          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
+          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
+          break;
+        }
+        case HloOpcode::kSendDone:
+          TF_RET_CHECK(instruction->operands().size() == 1);
+          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
+          break;
+        case HloOpcode::kRecvDone:
+          TF_RET_CHECK(instruction->operands().size() == 1);
+          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
+          break;
+        default:
+          break;
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// CHECKs various invariants of a fusion instruction.
+Status CheckFusionInstruction(HloInstruction* fusion) {
   // The parent fusion instruction of the fusion computation must be 'fusion'.
   HloComputation* fused_computation = fusion->fused_instructions_computation();
   if (fusion != fused_computation->FusionInstruction()) {
@@ -866,50 +995,32 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const {
     }
   }
 
+  TF_RET_CHECK(fusion->called_computations() ==
+               absl::Span<HloComputation* const>(
+                   {fusion->fused_instructions_computation()}))
+      << "Fusion HLO calls computations other than the "
+         "fused_instructions_computation: "
+      << fusion->ToString() << " fusion->fused_instructions_computation(): "
+      << fusion->fused_instructions_computation()->ToString()
+      << " fusion->called_computations(): "
+      << ComputationsToString(fusion->called_computations());
+
+  for (const auto& fused : fusion->fused_instructions()) {
+    TF_RET_CHECK(fused->parent() == fusion->fused_instructions_computation())
+        << "Fused HLO was missing a parent: " << fused->ToString()
+        << " parent: " << fused->parent()
+        << " computation: " << fusion->parent();
+  }
+
   // TODO(b/65423525): We'd like to check that all operands are distinct.
   // This is currently disabled due to the invariant being violated by
   // multi-output fusion.
   return Status::OK();
 }
 
-Status HloVerifier::CheckWhileInstruction(HloInstruction* instruction) {
-  auto* while_cond = instruction->while_condition();
-  auto* while_body = instruction->while_body();
-  if (while_cond->num_parameters() != 1) {
-    return FailedPrecondition(
-        "While condition must have exactly 1 parameter; had %d : %s",
-        while_cond->num_parameters(), while_cond->ToString());
-  }
-  if (while_body->num_parameters() != 1) {
-    return FailedPrecondition(
-        "While body must have exactly 1 parameter; had %d : %s",
-        while_body->num_parameters(), while_body->ToString());
-  }
-  if (instruction->operand_count() != 1) {
-    return FailedPrecondition(
-        "While loop must have exactly one operand; had %d : %s",
-        instruction->operand_count(), instruction->ToString());
-  }
-  return Status::OK();
-}
-
-Status HloVerifier::CheckConditionalInstruction(HloInstruction* instruction) {
-  if (instruction->true_computation()->num_parameters() != 1) {
-    return FailedPrecondition(
-        "True computation %s of %s must have 1 parameter insted of %d",
-        instruction->true_computation()->name(), instruction->ToString(),
-        instruction->true_computation()->num_parameters());
-  }
-  if (instruction->false_computation()->num_parameters() != 1) {
-    return FailedPrecondition(
-        "False computation %s of %s must have 1 parameter insted of %d",
-        instruction->false_computation()->name(), instruction->ToString(),
-        instruction->false_computation()->num_parameters());
-  }
-  return Status::OK();
-}
-
-Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
+// Checks that the non-scalar operand shapes are compatible to the output
+// shape, i.e., that there are no implicit broadcasts of size-one dimensions.
+Status CheckElementwiseInstruction(HloInstruction* instruction) {
   const Shape& out_shape = instruction->shape();
   for (HloInstruction* operand : instruction->operands()) {
     const Shape& operand_shape = operand->shape();
@@ -926,133 +1037,114 @@ Status HloVerifier::CheckElementwiseInstruction(HloInstruction* instruction) {
   return Status::OK();
 }
 
-namespace {
+// Visitor which verifies various fields on the HLO instruction. This class does
+// not check result shape as that is checked in the ShapeVerifier.
+class InstructionVerifier : public DfsHloVisitorWithDefault {
+ public:
+  InstructionVerifier() {}
 
-// Returns true if the given Shape has a TOKEN shape as any subshape.
-bool ShapeContainsToken(const Shape& shape) {
-  bool contains_token = false;
-  ShapeUtil::ForEachSubshape(
-      shape, [&contains_token](const Shape& subshape, const ShapeIndex&) {
-        if (ShapeUtil::IsToken(subshape)) {
-          contains_token = true;
-        }
-      });
-  return contains_token;
-}
+  Status DefaultAction(HloInstruction*) override { return Status::OK(); }
 
-// Verifies that all types entering and exiting the entry computation are
-// legal.
-Status VerifyEntryAndExitShapes(const HloModule& module) {
-  // Tokens cannot be passed as entry parameters.
-  // TODO(b/80000000): Remove this constraint.
-  for (int i = 0; i < module.entry_computation()->num_parameters(); ++i) {
-    HloInstruction* param =
-        module.entry_computation()->parameter_instruction(i);
-    if (ShapeContainsToken(param->shape())) {
-      return InternalError(
-          "Entry parameter %d is or contains a token shape: %s", i,
-          ShapeUtil::HumanString(param->shape()));
-    }
+  Status HandleFusion(HloInstruction* fusion) override {
+    return CheckFusionInstruction(fusion);
   }
-  return Status::OK();
-}
-
-// Checks if the given two instructions share the same channel id.
-Status CheckSameChannel(const HloInstruction* instr1,
-                        const HloInstruction* instr2) {
-  if (instr1->channel_id() != instr2->channel_id()) {
-    return InternalError(
-        "Expected to have the same channel id, actual channel ids are: %s "
-        "(%d), %s (%d)",
-        instr1->ToString(), instr1->channel_id(), instr2->ToString(),
-        instr2->channel_id());
+  Status HandleBroadcast(HloInstruction* broadcast) override {
+    // If you see this failure then someone has confused the difference
+    // between the HLO broadcast op, and the UserComputation broadcast
+    // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
+    // or ComputationLowerer::Visit()
+    TF_RET_CHECK(broadcast->dimensions().size() ==
+                 ShapeUtil::Rank(broadcast->operand(0)->shape()))
+        << "Broadcast HLO (" << broadcast->ToShortString()
+        << ") has invalid number of dimensions: "
+        << broadcast->dimensions().size()
+        << " != " << ShapeUtil::Rank(broadcast->operand(0)->shape());
+    return Status::OK();
   }
-  return Status::OK();
-}
 
-// Checks if the given two instructions have the same is_host_transfer
-// attribute value. Intsructions must be send/recv instructions or their
-// 'done' variant.
-Status CheckSameIsHostTransfer(const HloInstruction* instr1,
-                               const HloInstruction* instr2) {
-  const HloSendRecvInstruction* send_recv1 =
-      DynCast<const HloSendRecvInstruction>(instr1);
-  const HloSendRecvInstruction* send_recv2 =
-      DynCast<const HloSendRecvInstruction>(instr2);
-  TF_RET_CHECK(send_recv1 != nullptr);
-  TF_RET_CHECK(send_recv2 != nullptr);
-  if (send_recv1->is_host_transfer() != send_recv2->is_host_transfer()) {
-    return InternalError(
-        "Expected instructions to have the same is-host-transfer property: "
-        "%s, "
-        "%s ",
-        instr1->ToString(), instr2->ToString());
+  Status HandleWhile(HloInstruction* xla_while) override {
+    auto* while_cond = xla_while->while_condition();
+    auto* while_body = xla_while->while_body();
+    if (while_cond->num_parameters() != 1) {
+      return FailedPrecondition(
+          "While condition must have exactly 1 parameter; had %d : %s",
+          while_cond->num_parameters(), while_cond->ToString());
+    }
+    if (while_body->num_parameters() != 1) {
+      return FailedPrecondition(
+          "While body must have exactly 1 parameter; had %d : %s",
+          while_body->num_parameters(), while_body->ToString());
+    }
+    if (xla_while->operand_count() != 1) {
+      return FailedPrecondition(
+          "While loop must have exactly one operand; had %d : %s",
+          xla_while->operand_count(), xla_while->ToString());
+    }
+    return Status::OK();
   }
-  return Status::OK();
-}
 
-// Checks various invariants of send and recv instructions.
-Status VerifySendsAndRecvs(const HloModule& module) {
-  absl::flat_hash_map<int64, const HloSendRecvInstruction*> host_channels;
-  // Host send/recv instructions must have their own unique channel.
-  auto check_unique_host_channel = [&](const HloInstruction* instruction) {
-    const HloSendRecvInstruction* sendrecv =
-        DynCast<const HloSendRecvInstruction>(instruction);
-    if (sendrecv->is_host_transfer()) {
-      auto it_inserted =
-          host_channels.insert({sendrecv->channel_id(), sendrecv});
-      if (!it_inserted.second) {
-        return FailedPrecondition(
-            "Channel %d is used for multiple host send/recv instructions: "
-            "%s "
-            "and "
-            "%s",
-            sendrecv->channel_id(), sendrecv->ToString(),
-            it_inserted.first->second->ToString());
-      }
+  Status HandleConditional(HloInstruction* conditional) override {
+    if (conditional->true_computation()->num_parameters() != 1) {
+      return FailedPrecondition(
+          "True computation %s of %s must have 1 parameter instead of %d",
+          conditional->true_computation()->name(), conditional->ToString(),
+          conditional->true_computation()->num_parameters());
     }
+    if (conditional->false_computation()->num_parameters() != 1) {
+      return FailedPrecondition(
+          "False computation %s of %s must have 1 parameter instead of %d",
+          conditional->false_computation()->name(), conditional->ToString(),
+          conditional->false_computation()->num_parameters());
+    }
+    return Status::OK();
+  }
+
+  Status HandleElementwiseUnary(HloInstruction* instruction) override {
+    return CheckElementwiseInstruction(instruction);
+  }
+
+  Status HandleElementwiseBinary(HloInstruction* instruction) override {
+    return CheckElementwiseInstruction(instruction);
+  }
 
+  Status HandleGetTupleElement(HloInstruction* gte) override {
+    TF_RET_CHECK(ShapeUtil::IsTuple(gte->operand(0)->shape()));
     return Status::OK();
-  };
+  }
 
-  // Send/Recv instruction must have a single user: the corresponding
-  // SendDone/RecvDone. with matching channel.
-  for (const HloComputation* computation : module.computations()) {
-    for (const HloInstruction* instruction : computation->instructions()) {
-      switch (instruction->opcode()) {
-        case HloOpcode::kSend: {
-          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
-          TF_RET_CHECK(instruction->users().size() == 1);
-          const HloInstruction* send_done = instruction->users().front();
-          TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone);
-          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, send_done));
-          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, send_done));
-          break;
-        }
-        case HloOpcode::kRecv: {
-          TF_RETURN_IF_ERROR(check_unique_host_channel(instruction));
-          TF_RET_CHECK(instruction->users().size() == 1);
-          const HloInstruction* recv_done = instruction->users().front();
-          TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone);
-          TF_RETURN_IF_ERROR(CheckSameChannel(instruction, recv_done));
-          TF_RETURN_IF_ERROR(CheckSameIsHostTransfer(instruction, recv_done));
-          break;
-        }
-        case HloOpcode::kSendDone:
-          TF_RET_CHECK(instruction->operands().size() == 1);
-          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kSend);
-          break;
-        case HloOpcode::kRecvDone:
-          TF_RET_CHECK(instruction->operands().size() == 1);
-          TF_RET_CHECK(instruction->operand(0)->opcode() == HloOpcode::kRecv);
-          break;
-        default:
-          break;
-      }
-    }
+  Status HandleTranspose(HloInstruction* transpose) override {
+    const Shape& shape = transpose->shape();
+    const HloInstruction* operand = transpose->operand(0);
+    TF_RET_CHECK(shape.dimensions().size() == transpose->dimensions().size());
+    TF_RET_CHECK(shape.dimensions().size() ==
+                 transpose->operand(0)->shape().dimensions().size());
+    TF_RET_CHECK(std::equal(
+        operand->shape().dimensions().begin(),
+        operand->shape().dimensions().end(),
+        Permute(transpose->dimensions(), shape.dimensions()).begin()))
+        << "shape: " << shape << ", operand->shape(): " << operand->shape()
+        << ", dimensions: {" << absl::StrJoin(transpose->dimensions(), ", ")
+        << "}";
+    return Status::OK();
   }
-  return Status::OK();
-}
+
+  Status Preprocess(HloInstruction* instruction) override {
+    auto previous = instructions_by_name_.find(instruction->name());
+    TF_RET_CHECK(previous == instructions_by_name_.end())
+        << "HLO has name that is not unique within module:\n"
+        << instruction->ToString()
+        << " in computation: " << instruction->parent()->name()
+        << "\nPrevious HLO with same name:\n"
+        << previous->second->ToString()
+        << " in computation: " << previous->second->parent()->name();
+    instructions_by_name_[instruction->name()] = instruction;
+    return Status::OK();
+  }
+
+ private:
+  absl::flat_hash_map<string, const HloInstruction*> instructions_by_name_;
+};
 
 }  // namespace
 
@@ -1061,65 +1153,13 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
   TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
 
-  absl::flat_hash_map<string, const HloInstruction*> instructions;
   for (auto* computation : module->computations()) {
-    for (const auto& instruction : computation->instructions()) {
-      TF_RET_CHECK(instruction->parent() == computation);
-      if (instruction->opcode() == HloOpcode::kFusion) {
-        TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction));
-        TF_RET_CHECK(instruction->called_computations() ==
-                     absl::Span<HloComputation* const>(
-                         {instruction->fused_instructions_computation()}))
-            << "Fusion HLO calls computations other than the "
-               "fused_instructions_computation: "
-            << instruction->ToString()
-            << " instruction->fused_instructions_computation(): "
-            << instruction->fused_instructions_computation()->ToString()
-            << " instruction->called_computations(): "
-            << ComputationsToString(instruction->called_computations());
-
-        for (const auto& fused : instruction->fused_instructions()) {
-          TF_RET_CHECK(fused->parent() ==
-                       instruction->fused_instructions_computation())
-              << "Fused HLO was missing a parent: " << fused->ToString()
-              << " parent: " << fused->parent()
-              << " computation: " << computation;
-        }
-      } else if (instruction->opcode() == HloOpcode::kBroadcast) {
-        // If you see this failure then someone has confused the difference
-        // between the HLO broadcast op, and the UserComputation broadcast
-        // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I
-        // or ComputationLowerer::Visit()
-        TF_RET_CHECK(instruction->dimensions().size() ==
-                     ShapeUtil::Rank(instruction->operand(0)->shape()))
-            << "Broadcast HLO (" << instruction->ToShortString()
-            << ") has invalid number of dimensions: "
-            << instruction->dimensions().size()
-            << " != " << ShapeUtil::Rank(instruction->operand(0)->shape());
-      } else if (instruction->opcode() == HloOpcode::kWhile) {
-        TF_RETURN_IF_ERROR(CheckWhileInstruction(instruction));
-      } else if (instruction->opcode() == HloOpcode::kConditional) {
-        TF_RETURN_IF_ERROR(CheckConditionalInstruction(instruction));
-      } else if (instruction->opcode() !=
-                     HloOpcode::kRng /* Rng operands are always scalar. */
-                 && instruction->IsElementwise()) {
-        TF_RETURN_IF_ERROR(CheckElementwiseInstruction(instruction));
-      }
-
-      auto previous = instructions.find(instruction->name());
-      TF_RET_CHECK(previous == instructions.end())
-          << "HLO has name that is not unique within module:\n"
-          << instruction->ToString()
-          << " in computation: " << computation->name()
-          << "\nPrevious HLO with same name:\n"
-          << previous->second->ToString()
-          << " in computation: " << previous->second->parent()->name();
-      instructions[instruction->name()] = instruction;
-    }
-
     std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
     TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
+
+    InstructionVerifier instruction_verifier;
+    TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
   }
 
   TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 0cde4a31af..6d16586c2c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -172,17 +172,6 @@ class HloVerifier : public HloModulePass {
   StatusOr<bool> Run(HloModule* module) override;
 
  private:
-  // CHECKs various invariants of a fusion instruction.
-  Status CheckFusionInstruction(HloInstruction* fusion) const;
-
-  Status CheckWhileInstruction(HloInstruction* instruction);
-
-  Status CheckConditionalInstruction(HloInstruction* instruction);
-
-  // Checks that the non-scalar operand shapes are compatible to the output
-  // shape, i.e., that there are no implicit broadcasts of size-one dimensions.
-  Status CheckElementwiseInstruction(HloInstruction* instruction);
-
   // Creates a ShapeVerifier that checks that shapes match inferred
   // expectations. This is a factory function because ShapeVerifier,
   // being a DfsHloVisitor, is stateful. We want a clean object
--
cgit v1.2.3
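For completeness, here is a rough sketch of how a caller drives the reworked verifier; it is not part of the patch. The two constructor flags (layout_sensitive, allow_mixed_precision) are an assumption about the HloVerifier API around this revision, and VerifyModule is a hypothetical wrapper.

    #include "tensorflow/compiler/xla/service/hlo_verifier.h"

    // Hypothetical wrapper (not part of this patch): run the verifier and
    // propagate any failure as a Status instead of crashing the process.
    xla::Status VerifyModule(xla::HloModule* module) {
      // Constructor flags are an assumption about the API at this revision.
      xla::HloVerifier verifier(/*layout_sensitive=*/false,
                                /*allow_mixed_precision=*/false);
      // Run() returns StatusOr<bool>; .status() carries any TF_RET_CHECK
      // failure raised by the ShapeVerifier or the new InstructionVerifier.
      return verifier.Run(module).status();
    }

Because the checks are now Status-returning, a malformed module produced by a buggy pass surfaces through this wrapper as a reportable verifier error rather than a CHECK-induced crash.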