diff options
author | 2018-06-18 19:14:12 -0700 | |
---|---|---|
committer | 2018-06-18 19:17:25 -0700 | |
commit | 98a829817c027b9681a728160c746bcc63ad86b9 (patch) | |
tree | fd4b75f8a17bed95d8af5c45b559cb076a6f5d08 | |
parent | 36bf4a43248077fd5635b13e2def636be299e435 (diff) |
HloInstruction::CreateFromProto should not crash on CHECK, instead needs to return error status.
PiperOrigin-RevId: 201100918
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 100 |
1 files changed, 73 insertions, 27 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 58a33f5229..1dd2ce40da 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -70,25 +70,33 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( switch (opcode) { // Ops migrated to subclasses. case HloOpcode::kBatchNormTraining: - CHECK_EQ(proto.operand_ids_size(), 3); + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "BatchNormTraining instruction should have 3 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormTraining( proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormInference: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormInference instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormInference( proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kBatchNormGrad: - CHECK_EQ(proto.operand_ids_size(), 5); + TF_RET_CHECK(proto.operand_ids_size() == 5) + << "BatchNormGrad instruction should have 5 operands but sees " + << proto.operand_ids_size(); instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1), operands(2), operands(3), operands(4), proto.epsilon(), proto.feature_index()); break; case HloOpcode::kFft: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Fft instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector<int64> fft_length(proto.fft_length().begin(), proto.fft_length().end()); instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(), @@ -96,30 +104,42 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Send instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateSend(operands(0), proto.channel_id()); break; case HloOpcode::kSendDone: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "SendDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateSendDone(operands(0)); break; case HloOpcode::kRecv: - CHECK_EQ(proto.operand_ids_size(), 0); + TF_RET_CHECK(proto.operand_ids_size() == 0) + << "Recv instruction should have 0 operand but sees " + << proto.operand_ids_size(); instruction = CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); break; case HloOpcode::kRecvDone: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "RecvDone instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateRecvDone(operands(0)); break; case HloOpcode::kReverse: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Reverse instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateReverse(proto.shape(), operands(0), std::vector<int64>(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kConcatenate: { - CHECK_EQ(proto.dimensions_size(), 1); + TF_RET_CHECK(proto.dimensions_size() == 1) + << "Concatenate instruction should have 1 dimension but sees " + << proto.dimensions_size(); std::vector<HloInstruction*> concat_operands(proto.operand_ids_size()); std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), concat_operands.begin(), @@ -131,29 +151,39 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kReduce: - CHECK_EQ(proto.operand_ids_size(), 2); - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Reduce instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Reduce instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); instruction = CreateReduce(proto.shape(), operands(0), operands(1), std::vector<int64>(proto.dimensions().begin(), proto.dimensions().end()), computations(0)); break; case HloOpcode::kTranspose: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Transpose instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateTranspose(proto.shape(), operands(0), std::vector<int64>(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kBroadcast: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Broadcast instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateBroadcast(proto.shape(), operands(0), std::vector<int64>(proto.dimensions().begin(), proto.dimensions().end())); break; case HloOpcode::kMap: { - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "Map instruction should have 1 called computation but sees " + << proto.called_computation_ids_size(); std::vector<HloInstruction*> map_operands(proto.operand_ids_size()); std::transform(proto.operand_ids().begin(), proto.operand_ids().end(), map_operands.begin(), @@ -164,7 +194,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kSlice: { - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Slice instruction should have 1 operand but sees " + << proto.operand_ids_size(); std::vector<int64> slice_starts, slice_limits, slice_strides; for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { @@ -191,7 +223,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( TF_RET_CHECK(proto.operand_ids_size() == 1) << "Trace instruction should have 1 operand but sees " << proto.operand_ids_size(); - CHECK(proto.has_literal()); + TF_RET_CHECK(proto.has_literal()); TF_ASSIGN_OR_RETURN(auto literal, Literal::CreateFromProto(proto.literal())); instruction = CreateTrace(literal->GetR1U8AsString(), operands(0)); @@ -207,7 +239,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( // Find the fused computation and set its fusion instruction. TF_RET_CHECK(proto.called_computation_ids_size() == 1) - << "Expect 1 called computation for fusion instruction, but sees " + << "Expect 1 called computation for fusion instruction but sees " << proto.called_computation_ids_size(); const int64 fusion_id = proto.called_computation_ids(0); auto* fused_computation = FindPtrOrNull(computation_map, fusion_id); @@ -237,7 +269,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.name()); break; case HloOpcode::kGetTupleElement: - CHECK_EQ(proto.operand_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "GetTupleElement instruction should have 1 operand but sees " + << proto.operand_ids_size(); instruction = CreateGetTupleElement(proto.shape(), operands(0), proto.tuple_index()); break; @@ -254,7 +288,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.outfeed_config()); break; case HloOpcode::kCrossReplicaSum: { - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "CrossReplicaSum should have 1 called computation but sees " + << proto.called_computation_ids_size(); std::vector<HloInstruction*> all_operands(proto.operand_ids_size()); c_transform(proto.operand_ids(), all_operands.begin(), [&instruction_map](int64 operand_id) { @@ -274,22 +310,32 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kConvolution: - CHECK_EQ(proto.operand_ids_size(), 2); - CHECK(proto.has_window()); - CHECK(proto.has_convolution_dimension_numbers()); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Convolution instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_window()); + TF_RET_CHECK(proto.has_convolution_dimension_numbers()); instruction = CreateConvolve(proto.shape(), operands(0), operands(1), proto.window(), proto.convolution_dimension_numbers()); break; case HloOpcode::kReduceWindow: - CHECK_EQ(proto.operand_ids_size(), 2); - CHECK_EQ(proto.called_computation_ids_size(), 1); + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "ReduceWindow instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 1) + << "ReduceWindow should have 1 called computation but sees " + << proto.called_computation_ids_size(); instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1), proto.window(), computations(0)); break; case HloOpcode::kSelectAndScatter: - CHECK_EQ(proto.operand_ids_size(), 3); - CHECK_EQ(proto.called_computation_ids_size(), 2); + TF_RET_CHECK(proto.operand_ids_size() == 3) + << "SelectAndScatter instruction should have 3 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.called_computation_ids_size() == 2) + << "SelectAndScatter should have 2 called computations but sees " + << proto.called_computation_ids_size(); instruction = CreateSelectAndScatter( proto.shape(), operands(0), computations(0), proto.window(), operands(1), operands(2), computations(1)); |