author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-09 16:52:15 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-09 16:59:36 -0700
commit    d4526cf9d1d58cbe480e7d2b8199620e0e9f0572 (patch)
tree      70fb212352f18cc5b0589fc9e9b20bdadf831c87
parent    c770568935b85d506dc1a1f671822a7e122b5056 (diff)
[XLA] Added xla::CreateModuleFromProto(...), combining loading a module from a proto and verifying it with HloVerifier.

PiperOrigin-RevId: 216447947
-rw-r--r--  tensorflow/compiler/xla/layout_util.cc              |   2
-rw-r--r--  tensorflow/compiler/xla/service/BUILD               |   1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc  |  14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util.cc   |  12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_proto_util.h    |   6
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc     | 104
6 files changed, 132 insertions(+), 7 deletions(-)
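
Before the diff, a minimal usage sketch of the new entry point. This is not part of the commit; LoadVerified is a hypothetical caller, and the proto and config are assumed to come from surrounding code:

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_proto_util.h"

// Sketch only: deserialize an HloModuleProto and get back a module that
// has already passed HloVerifier (layout-sensitive, mixed precision
// disallowed), instead of calling HloModule::CreateFromProto directly.
xla::StatusOr<std::unique_ptr<xla::HloModule>> LoadVerified(
    const xla::HloModuleProto& proto, const xla::HloModuleConfig& config) {
  return xla::CreateModuleFromProto(proto, config);
}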
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index 3c8db9aa45..19667b7ed9 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -205,7 +205,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return Status::OK();
}
- if (layout.format() == INVALID_FORMAT) {
+ if (layout.format() == INVALID_FORMAT || !Format_IsValid(layout.format())) {
return InvalidArgument(
"Layout does not have a valid format: layout {%s}, shape {%s}",
layout.ShortDebugString(), shape.ShortDebugString());
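
A note on this hunk: protobuf generates a <Enum>_IsValid(int) predicate for every enum, and an enum field parsed from untrusted bytes can carry an integer outside the defined range, which the INVALID_FORMAT comparison alone would not catch. A standalone sketch of the same pattern; the Format values here are an assumption mirroring what protoc would generate:

// Hypothetical mirror of a generated proto enum and its validity check.
enum Format { INVALID_FORMAT = 0, DENSE = 1, SPARSE = 2 };

bool Format_IsValid(int value) { return value >= 0 && value <= 2; }

bool IsUsableFormat(int raw) {
  // Both tests are needed: INVALID_FORMAT is itself a valid enum value,
  // while a raw wire integer may correspond to no enum value at all.
  return raw != INVALID_FORMAT && Format_IsValid(raw);
}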
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 2b292ed053..f9f741aaee 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -3127,6 +3127,7 @@ cc_library(
":buffer_assignment",
":hlo",
":hlo_proto",
+ ":hlo_verifier",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:util",
],
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 050d28b289..09bcf8a9e7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -305,6 +305,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.tuple_index());
break;
case HloOpcode::kReducePrecision:
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+        << "ReducePrecision instruction should have 1 operand but has "
+ << proto.operand_ids_size();
instruction =
CreateReducePrecision(proto.shape(), operands(0),
proto.exponent_bits(), proto.mantissa_bits());
@@ -312,12 +315,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kInfeed: {
const Shape& data_shape =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
- TF_RET_CHECK(proto.operand_ids_size() == 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+        << "Infeed instruction should have 1 operand but has "
+ << proto.operand_ids_size();
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
} break;
case HloOpcode::kOutfeed:
- TF_RET_CHECK(proto.operand_ids_size() == 2);
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+        << "Outfeed instruction should have 2 operands but has "
+ << proto.operand_ids_size();
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
@@ -349,6 +356,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kCollectivePermute: {
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+        << "CollectivePermute instruction should have 1 operand but has "
+ << proto.operand_ids_size();
std::vector<std::pair<int64, int64>> source_target_pairs(
proto.source_target_pairs_size());
for (int i = 0; i < source_target_pairs.size(); i++) {
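
The TF_RET_CHECK additions in this file all guard the operands(0) / operands(1) accesses that follow them: a malformed proto with too few operand_ids would otherwise index past the deserialized operand list. A minimal sketch of the pattern, using a toy Status type rather than TensorFlow's real one:

#include <string>

// Toy Status for illustration: an empty error string means OK.
struct Status {
  std::string error;
  bool ok() const { return error.empty(); }
};

// Validate counts up front and return a descriptive error instead of
// crashing later on an out-of-bounds operand access.
Status ValidateOperandCount(int actual, int expected, const char* op) {
  if (actual != expected) {
    return {std::string(op) + " instruction should have " +
            std::to_string(expected) + " operand(s) but has " +
            std::to_string(actual)};
  }
  return {};
}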
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc
index b9c0b0c4ee..026a0e8fba 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_proto_util.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include <string>
@@ -36,6 +37,17 @@ HloProto MakeHloProto(const HloModule& module) {
return proto;
}
+StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+ const HloModuleProto& proto, const HloModuleConfig& module_config) {
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
+ HloModule::CreateFromProto(proto, module_config));
+ TF_RETURN_IF_ERROR(
+ HloVerifier(/*layout_sensitive=*/true, /*allow_mixed_precision=*/false)
+ .Run(module.get())
+ .status());
+ return std::move(module);
+}
+
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
const HloProto& hlo_proto) {
if (!hlo_proto.has_hlo_module()) {
diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h
index 3d9c375cd5..1db82dd6fc 100644
--- a/tensorflow/compiler/xla/service/hlo_proto_util.h
+++ b/tensorflow/compiler/xla/service/hlo_proto_util.h
@@ -35,6 +35,12 @@ HloProto MakeHloProto(const HloModule& module,
// will not be included in the output.
HloProto MakeHloProto(const HloModule& module);
+// Creates an HLO module from its serialized representation. In addition to
+// creating the module with HloModule::CreateFromProto(...) it also
+// uses HloVerifier to ensure basic invariants hold.
+StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
+ const HloModuleProto& proto, const HloModuleConfig& module_config);
+
// Returns the shapes of the parameters of the entry computation. Shape pointers
// refer to shapes inside of the given HloProto.
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index be3bee5975..620458855f 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -27,6 +27,15 @@ limitations under the License.
namespace xla {
+static Status CheckOperandCount(const HloInstruction* hlo, int expected) {
+ if (hlo->operand_count() != expected) {
+ return InternalError("Expected %d operands for %s instruction: %s",
+ expected, HloOpcodeString(hlo->opcode()),
+ hlo->ToString());
+ }
+ return Status::OK();
+}
+
Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) {
return CheckUnaryShape(hlo);
}
@@ -58,12 +67,14 @@ Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) {
}
Status ShapeVerifier::HandleConvert(HloInstruction* convert) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1));
return CheckShape(convert, ShapeInference::InferConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
}
Status ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(convert, 1));
return CheckShape(convert, ShapeInference::InferBitcastConvertShape(
convert->operand(0)->shape(),
convert->shape().element_type()));
@@ -74,6 +85,7 @@ Status ShapeVerifier::HandleCopy(HloInstruction* copy) {
}
Status ShapeVerifier::HandleDot(HloInstruction* dot) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(dot, 2));
TF_ASSIGN_OR_RETURN(const Shape expected,
ShapeInference::InferDotOpShape(
dot->operand(0)->shape(), dot->operand(1)->shape(),
@@ -82,6 +94,7 @@ Status ShapeVerifier::HandleDot(HloInstruction* dot) {
}
Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(convolution, 2));
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferConvolveShape(
@@ -92,6 +105,7 @@ Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) {
}
Status ShapeVerifier::HandleFft(HloInstruction* fft) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(fft, 1));
TF_ASSIGN_OR_RETURN(
const Shape expected,
ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(),
@@ -118,11 +132,13 @@ Status ShapeVerifier::HandleAllToAll(HloInstruction* hlo) {
}
Status ShapeVerifier::HandleCollectivePermute(HloInstruction* hlo) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
return CheckShape(hlo, ShapeInference::InferCollectivePermuteShape(
hlo->operand(0)->shape()));
}
Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(reduce_precision, 1));
return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape(
reduce_precision->operand(0)->shape(),
reduce_precision->exponent_bits(),
@@ -156,6 +172,7 @@ Status ShapeVerifier::CheckOperandAndParameter(
}
Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1));
HloInfeedInstruction* infeed = Cast<HloInfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 0));
@@ -166,6 +183,7 @@ Status ShapeVerifier::HandleInfeed(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
HloOutfeedInstruction* outfeed = Cast<HloOutfeedInstruction>(instruction);
TF_RETURN_IF_ERROR(CheckIsTokenOperand(instruction, 1));
@@ -192,10 +210,7 @@ bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
}
Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
- if (instruction->operand_count() != 2) {
- return InternalError("Expected two operands for Rng instruction: %s",
- instruction->ToString());
- }
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
const Shape& shape_0 = instruction->operand(0)->shape();
const Shape& shape_1 = instruction->operand(1)->shape();
@@ -244,12 +259,17 @@ Status ShapeVerifier::HandleRng(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleReverse(HloInstruction* reverse) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(reverse, 1));
return CheckShape(
reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(),
reverse->dimensions()));
}
Status ShapeVerifier::HandleSort(HloInstruction* sort) {
+ if (sort->operand_count() < 1 || sort->operand_count() > 2) {
+ return InternalError("Expected 1 or 2 operands for %s instruction: %s",
+ HloOpcodeString(sort->opcode()), sort->ToString());
+ }
if (sort->operand_count() == 2 &&
!ShapeUtil::SameDimensions(sort->operand(0)->shape(),
sort->operand(1)->shape())) {
@@ -263,10 +283,12 @@ Status ShapeVerifier::HandleSort(HloInstruction* sort) {
}
Status ShapeVerifier::HandleConstant(HloInstruction* constant) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(constant, 0));
return CheckShape(constant, constant->literal().shape());
}
Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 0));
auto* iota = Cast<HloIotaInstruction>(instruction);
const int64 rank = ShapeUtil::Rank(iota->shape());
if (rank == 0) {
@@ -281,6 +303,7 @@ Status ShapeVerifier::HandleIota(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(get_tuple_element, 1));
return CheckShape(get_tuple_element,
ShapeInference::InferGetTupleElementShape(
get_tuple_element->operand(0)->shape(),
@@ -288,6 +311,12 @@ Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) {
}
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
+ if (reduce->operand_count() % 2 != 0) {
+ return InternalError(
+ "Expected an even number of operands for %s instruction: %s",
+ HloOpcodeString(reduce->opcode()), reduce->ToString());
+ }
+
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : reduce->operands()) {
operand_shapes.push_back(&operand->shape());
@@ -298,10 +327,12 @@ Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(bitcast, 1));
return Status::OK();
}
Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(broadcast, 1));
// HLO broadcast has no exact analog at the proto level so there is no
// ShapeInference method. Check the output shape explicitly.
const Shape& operand_shape = broadcast->operand(0)->shape();
@@ -322,6 +353,7 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(reshape, 1));
// Check for mixed precision.
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
@@ -330,12 +362,14 @@ Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
}
Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(transpose, 1));
return CheckShape(
transpose, ShapeInference::InferTransposeShape(
transpose->operand(0)->shape(), transpose->dimensions()));
}
Status ShapeVerifier::HandleParameter(HloInstruction* hlo) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 0));
return Status::OK();
}
@@ -383,6 +417,7 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(slice, 1));
return CheckShape(slice,
ShapeInference::InferSliceShape(
slice->operand(0)->shape(), slice->slice_starts(),
@@ -390,6 +425,7 @@ Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
}
Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_slice, 2));
return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape(
dynamic_slice->operand(0)->shape(),
dynamic_slice->operand(1)->shape(),
@@ -398,6 +434,7 @@ Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) {
Status ShapeVerifier::HandleDynamicUpdateSlice(
HloInstruction* dynamic_update_slice) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(dynamic_update_slice, 3));
return CheckShape(dynamic_update_slice,
ShapeInference::InferDynamicUpdateSliceShape(
dynamic_update_slice->operand(0)->shape(),
@@ -427,6 +464,7 @@ Status ShapeVerifier::HandleMap(HloInstruction* map) {
}
Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2));
return CheckShape(
reduce_window,
ShapeInference::InferReduceWindowShape(
@@ -436,6 +474,7 @@ Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
}
Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3));
return CheckShape(
instruction,
ShapeInference::InferSelectAndScatterShape(
@@ -446,6 +485,7 @@ Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
}
Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(xla_while, 1));
TF_RETURN_IF_ERROR(
CheckOperandAndParameter(xla_while, 0, xla_while->while_body(), 0));
TF_RETURN_IF_ERROR(
@@ -465,6 +505,7 @@ Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) {
}
Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(conditional, 3));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
conditional, 1, conditional->true_computation(), 0));
TF_RETURN_IF_ERROR(CheckOperandAndParameter(
@@ -479,12 +520,14 @@ Status ShapeVerifier::HandleConditional(HloInstruction* conditional) {
}
Status ShapeVerifier::HandlePad(HloInstruction* pad) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(pad, 2));
return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(),
pad->operand(1)->shape(),
pad->padding_config()));
}
Status ShapeVerifier::HandleSend(HloInstruction* send) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(send, 2));
return CheckShape(send,
ShapeUtil::MakeTupleShape({send->operand(0)->shape(),
ShapeUtil::MakeShape(U32, {}),
@@ -492,10 +535,12 @@ Status ShapeVerifier::HandleSend(HloInstruction* send) {
}
Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(send_done, 1));
return CheckShape(send_done, ShapeUtil::MakeTokenShape());
}
Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(recv, 1));
return CheckShape(
recv, ShapeUtil::MakeTupleShape(
{ShapeUtil::GetTupleElementShape(recv->shape(), 0),
@@ -503,6 +548,7 @@ Status ShapeVerifier::HandleRecv(HloInstruction* recv) {
}
Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(recv_done, 1));
return CheckShape(
recv_done,
ShapeUtil::MakeTupleShape(
@@ -512,6 +558,7 @@ Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) {
Status ShapeVerifier::HandleBatchNormTraining(
HloInstruction* batch_norm_training) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_training, 3));
return CheckShape(batch_norm_training,
ShapeInference::InferBatchNormTrainingShape(
batch_norm_training->operand(0)->shape(),
@@ -522,6 +569,7 @@ Status ShapeVerifier::HandleBatchNormTraining(
Status ShapeVerifier::HandleBatchNormInference(
HloInstruction* batch_norm_inference) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_inference, 5));
return CheckShape(batch_norm_inference,
ShapeInference::InferBatchNormInferenceShape(
batch_norm_inference->operand(0)->shape(),
@@ -533,6 +581,7 @@ Status ShapeVerifier::HandleBatchNormInference(
}
Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(batch_norm_grad, 5));
return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape(
batch_norm_grad->operand(0)->shape(),
batch_norm_grad->operand(1)->shape(),
@@ -601,6 +650,7 @@ Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
} // namespace
Status ShapeVerifier::HandleGather(HloInstruction* gather) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(gather, 2));
return CheckShape(
gather,
ShapeInference::InferGatherShape(
@@ -609,6 +659,7 @@ Status ShapeVerifier::HandleGather(HloInstruction* gather) {
}
Status ShapeVerifier::HandleScatter(HloInstruction* scatter) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(scatter, 3));
return CheckShape(
scatter, ShapeInference::InferScatterShape(
scatter->operand(0)->shape(), scatter->operand(1)->shape(),
@@ -696,12 +747,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 1));
return CheckShape(instruction,
ShapeInference::InferUnaryOpShape(instruction->opcode(),
instruction->operand(0)));
}
Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 2));
return CheckShape(
instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(),
instruction->operand(0),
@@ -709,6 +762,7 @@ Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) {
}
Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) {
+ TF_RETURN_IF_ERROR(CheckOperandCount(instruction, 3));
return CheckShape(instruction,
ShapeInference::InferTernaryOpShape(
instruction->opcode(), instruction->operand(0),
@@ -816,6 +870,47 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
return Status::OK();
}
+// Verifies that the entry computation layout matches the characteristics
+// of the entry computation.
+Status CheckEntryComputationLayout(const HloModule& module) {
+ const HloComputation* computation = module.entry_computation();
+ const auto& layout = module.entry_computation_layout();
+
+ // TODO(117498192): Change into a call to Compatible(...).
+ if (!ShapeUtil::CompatibleIgnoringFpPrecision(
+ computation->root_instruction()->shape(),
+ layout.result_layout().shape())) {
+ return InternalError(
+ "Shape of the root instruction of entry computation (%s) should be "
+        "compatible with the one specified in the module's entry computation "
+        "layout (%s)",
+ ShapeUtil::HumanString(computation->root_instruction()->shape()),
+ ShapeUtil::HumanString(layout.result_layout().shape()));
+ }
+
+ if (computation->num_parameters() != layout.parameter_count()) {
+ return InternalError(
+        "Number of parameters in entry computation layout (%d) must be the "
+        "same as the number of parameters of the entry computation (%d)",
+ layout.parameter_count(), computation->num_parameters());
+ }
+
+ for (int i = 0; i < computation->num_parameters(); ++i) {
+ if (!ShapeUtil::Compatible(computation->parameter_instruction(i)->shape(),
+ layout.parameter_shape(i))) {
+ return InternalError(
+          "Shape of the entry computation parameter %d (%s) should be "
+          "compatible with the one specified in the module's entry "
+          "computation layout (%s)",
+ i,
+ ShapeUtil::HumanString(
+ computation->parameter_instruction(i)->shape()),
+ ShapeUtil::HumanString(layout.parameter_shape(i)));
+ }
+ }
+
+ return Status::OK();
+}
+
// Checks if the given two instructions share the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2) {
@@ -1213,6 +1308,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
}
+ TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module));
TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
// If the module has a schedule, it must be valid.
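
With this change the entry-computation-layout check runs on every verifier invocation, including the one CreateModuleFromProto now performs. A hedged sketch of running the verifier directly over an existing module, using the same constructor arguments as hlo_proto_util.cc above; VerifyModule is a hypothetical wrapper:

// Sketch only: verify an already-constructed module.
xla::Status VerifyModule(xla::HloModule* module) {
  xla::HloVerifier verifier(/*layout_sensitive=*/true,
                            /*allow_mixed_precision=*/false);
  // Run() returns StatusOr<bool>; .status() surfaces any verification
  // failure, including the new CheckEntryComputationLayout check.
  TF_RETURN_IF_ERROR(verifier.Run(module).status());
  return xla::Status::OK();
}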