diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 050d28b289..09bcf8a9e7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -305,6 +305,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.tuple_index()); break; case HloOpcode::kReducePrecision: + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "ReducePrecision instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReducePrecision(proto.shape(), operands(0), proto.exponent_bits(), proto.mantissa_bits()); @@ -312,12 +315,16 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( case HloOpcode::kInfeed: { const Shape& data_shape = ShapeUtil::GetTupleElementShape(proto.shape(), 0); - TF_RET_CHECK(proto.operand_ids_size() == 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Infeed instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } break; case HloOpcode::kOutfeed: - TF_RET_CHECK(proto.operand_ids_size() == 2); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Outfeed instruction should have 2 operands but sees " + << proto.operand_ids_size(); TF_RETURN_IF_ERROR( ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape())); instruction = CreateOutfeed(proto.outfeed_shape(), operands(0), @@ -349,6 +356,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kCollectivePermute: { + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "CollectivePermute instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector<std::pair<int64, int64>> source_target_pairs( proto.source_target_pairs_size()); for (int i = 0; i < source_target_pairs.size(); i++) { |