diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-09 16:52:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 16:59:36 -0700 |
commit | d4526cf9d1d58cbe480e7d2b8199620e0e9f0572 (patch) | |
tree | 70fb212352f18cc5b0589fc9e9b20bdadf831c87 | |
parent | c770568935b85d506dc1a1f671822a7e122b5056 (diff) |
[XLA] Added xla::CreateModuleFromProto(...) combining loading module
from proto and verifying it with HloVerifier.
PiperOrigin-RevId: 216447947
-rw-r--r-- | tensorflow/compiler/xla/layout_util.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.cc | 12 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_proto_util.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_verifier.cc | 104 |
6 files changed, 132 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 3c8db9aa45..19667b7ed9 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -205,7 +205,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return Status::OK(); } - if (layout.format() == INVALID_FORMAT) { + if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) { return InvalidArgument( "Layout does not have a valid format: layout {%s}, shape {%s}", layout.ShortDebugString(), shape.ShortDebugString()); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 2b292ed053..f9f741aaee 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3127,6 +3127,7 @@ cc_library( ":buffer_assignment", ":hlo", ":hlo_proto", + ":hlo_verifier", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:util", ], diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 050d28b289..09bcf8a9e7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -305,6 +305,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.tuple_index()); break; case HloOpcode::kReducePrecision: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); @@ -312,12 +315,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), @@ -349,6 +356,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector<std::pair<int64, int64>> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index b9c0b0c4ee..026a0e8fba 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/hlo_proto_util.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" #include <string> @@ -36,6 +37,17 @@ HloProto MakeHloProto(const HloModule& module) { return proto; } +StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config) { + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, + HloModule::CreateFromProto(proto, module_config)); + TF_RETURN_IF_ERROR( + HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false) + .Run(module.get()) + .status()); + return std::move(module); +} + StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( const HloProto& hlo_proto) { if (!hlo_proto.has_hlo_module()) { diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 3d9c375cd5..1db82dd6fc 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -35,6 +35,12 @@ HloProto MakeHloProto(const HloModule& module, // will not be included in the output. HloProto MakeHloProto(const HloModule& module); +// Create an HLO state from serialized representation. In addition to +// creating the proto with HloModule::CreateFromProto(...) it also +// uses HloVerifier to ensure basic invariants are held. +StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config); + // Returns the shapes of the parameters of the entry computation. Shape pointers // refer to shapes inside of the given HloProto. StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes( diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index be3bee5975..620458855f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -27,6 +27,15 @@ limitations under the License. namespace xla { +static Status CheckOperandCount(const HloInstruction* hlo, int expected) { + if (hlo->operand_count() != expected) { + return InternalError("Expected %d operands for %s instruction: %s", + expected, HloOpcodeString(hlo->opcode()), + hlo->ToString()); + } + return Status::OK(); +} + Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { return CheckUnaryShape(hlo); } @@ -58,12 +67,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { } Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); } Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1)); return CheckShape(convert, ShapeInference::InferBitcastConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -74,6 +85,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) { } Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2)); TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferDotOpShape( dot->operand(0)->shape(), dot->operand(1)->shape(), @@ -82,6 +94,7 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) { } Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferConvolveShape( @@ -92,6 +105,7 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { } Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1)); TF_ASSIGN_OR_RETURN( const Shape expected, ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), @@ -118,11 +132,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) { } Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape( hlo->operand(0)->shape())); } Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1)); return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( reduce_precision->operand(0)->shape(), reduce_precision->exponent_bits(), @@ -156,6 +172,7 @@ Status ShapeVerifier::CheckOperandAndParameter( } Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0)); @@ -166,6 +183,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) { } Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction); TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1)); @@ -192,10 +210,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0, } Status ShapeVerifier::HandleRng(HloInstruction* instruction) { - if (instruction->operand_count() != 2) { - return InternalError("Expected two operands for Rng instruction: %s", - instruction->ToString()); - } + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); const Shape& shape_0 = instruction->operand(0)->shape(); const Shape& shape_1 = instruction->operand(1)->shape(); @@ -244,12 +259,17 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) { } Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1)); return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } Status ShapeVerifier::HandleSort(HloInstruction* sort) { + if (sort->operand_count() < 1 || sort->operand_count() > 2) { + return InternalError("Expected 1 or 2 operands for %s instruction: %s", + HloOpcodeString(sort->opcode()), sort->ToString()); + } if (sort->operand_count() == 2 && !ShapeUtil::SameDimensions(sort->operand(0)->shape(), sort->operand(1)->shape())) { @@ -263,10 +283,12 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) { } Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0)); return CheckShape(constant, constant->literal().shape()); } Status ShapeVerifier::HandleIota(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0)); auto* iota = Cast<HloIotaInstruction>(instruction); const int64 rank = ShapeUtil::Rank(iota->shape()); if (rank == 0) { @@ -281,6 +303,7 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) { } Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1)); return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), @@ -288,6 +311,12 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { } Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + if (reduce->operand_count() % 2 != 0) { + return InternalError( + "Expected an even number of operands for %s instruction: %s", + HloOpcodeString(reduce->opcode()), reduce->ToString()); + } + std::vector<const Shape*> operand_shapes; for (const HloInstruction* operand : reduce->operands()) { operand_shapes.push_back(&operand->shape()); @@ -298,10 +327,12 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { } Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1)); return Status::OK(); } Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1)); // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); @@ -322,6 +353,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1)); // Check for mixed precision. TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == @@ -330,12 +362,14 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { } Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1)); return CheckShape( transpose, ShapeInference::InferTransposeShape( transpose->operand(0)->shape(), transpose->dimensions())); } Status ShapeVerifier::HandleParameter(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0)); return Status::OK(); } @@ -383,6 +417,7 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { } Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1)); return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), @@ -390,6 +425,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) { } Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2)); return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), @@ -398,6 +434,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { Status ShapeVerifier::HandleDynamicUpdateSlice( HloInstruction* dynamic_update_slice) { + TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3)); return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -427,6 +464,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) { } Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2)); return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -436,6 +474,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { } Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape( instruction, ShapeInference::InferSelectAndScatterShape( @@ -446,6 +485,7 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { } Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1)); TF_RETURN_IF_ERROR( CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0)); TF_RETURN_IF_ERROR( @@ -465,6 +505,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { } Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( conditional, 1, conditional->true_computation(), 0)); TF_RETURN_IF_ERROR(CheckOperandAndParameter( @@ -479,12 +520,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { } Status ShapeVerifier::HandlePad(HloInstruction* pad) { + TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2)); return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), pad->operand(1)->shape(), pad->padding_config())); } Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RETURN_IF_ERROR(CheckOperandCount(send, 2)); return CheckShape(send, ShapeUtil::MakeTupleShape({send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {}), @@ -492,10 +535,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) { } Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1)); return CheckShape(send_done, ShapeUtil::MakeTokenShape()); } Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1)); return CheckShape( recv, ShapeUtil::MakeTupleShape( {ShapeUtil::GetTupleElementShape(recv->shape(), 0), @@ -503,6 +548,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) { } Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1)); return CheckShape( recv_done, ShapeUtil::MakeTupleShape( @@ -512,6 +558,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { Status ShapeVerifier::HandleBatchNormTraining( HloInstruction* batch_norm_training) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3)); return CheckShape(batch_norm_training, ShapeInference::InferBatchNormTrainingShape( batch_norm_training->operand(0)->shape(), @@ -522,6 +569,7 @@ Status ShapeVerifier::HandleBatchNormTraining( Status ShapeVerifier::HandleBatchNormInference( HloInstruction* batch_norm_inference) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5)); return CheckShape(batch_norm_inference, ShapeInference::InferBatchNormInferenceShape( batch_norm_inference->operand(0)->shape(), @@ -533,6 +581,7 @@ Status ShapeVerifier::HandleBatchNormInference( } Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5)); return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( batch_norm_grad->operand(0)->shape(), batch_norm_grad->operand(1)->shape(), @@ -601,6 +650,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { } // namespace Status ShapeVerifier::HandleGather(HloInstruction* gather) { + TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2)); return CheckShape( gather, ShapeInference::InferGatherShape( @@ -609,6 +659,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) { } Status ShapeVerifier::HandleScatter(HloInstruction* scatter) { + TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3)); return CheckShape( scatter, ShapeInference::InferScatterShape( scatter->operand(0)->shape(), scatter->operand(1)->shape(), @@ -696,12 +747,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1)); return CheckShape(instruction, ShapeInference::InferUnaryOpShape(instruction->opcode(), instruction->operand(0))); } Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2)); return CheckShape( instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), instruction->operand(0), @@ -709,6 +762,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { } Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3)); return CheckShape(instruction, ShapeInference::InferTernaryOpShape( instruction->opcode(), instruction->operand(0), @@ -816,6 +870,47 @@ Status VerifyEntryAndExitShapes(const HloModule& module) { return Status::OK(); } +// Verifies that entry computation layout matches characteristics of +// entry computation. +Status CheckEntryComputationLayout(const HloModule& module) { + const HloComputation* computation = module.entry_computation(); + const auto& layout = module.entry_computation_layout(); + + // TODO(117498192): Change into a call to Compatible(...). + if (!ShapeUtil::CompatibleIgnoringFpPrecision( + computation->root_instruction()->shape(), + layout.result_layout().shape())) { + return InternalError( + "Shape of the root instruction of entry computation (%s) should be " + "compatible to one specified in module's entry computation layout (%s)", + ShapeUtil::HumanString(computation->root_instruction()->shape()), + ShapeUtil::HumanString(layout.result_layout().shape())); + } + + if (computation->num_parameters() != layout.parameter_count()) { + return InternalError( + "Number of parameters in entry computation layout (%d) must be same " + "as number of parameters of entry computation computation (%d)", + layout.parameter_count(), computation->num_parameters()); + } + + for (int i = 0; i < computation->num_parameters(); ++i) { + if (!ShapeUtil::Compatible(computation->parameter_instruction(i)->shape(), + layout.parameter_shape(i))) { + return InternalError( + "Shape of the entry computation parameter %d is %s should be " + "compatible to the one specified in module's entry computation " + "layout %s", + i, + ShapeUtil::HumanString( + computation->parameter_instruction(i)->shape()), + ShapeUtil::HumanString(layout.parameter_shape(i))); + } + } + + return Status::OK(); +} + // Checks if the given two instructions share the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2) { @@ -1213,6 +1308,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) { TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier)); } + TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module)); TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module)); // If the module has a schedule, it must be valid. |