about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-18 19:14:12 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-18 19:17:25 -0700
commit: 98a829817c027b9681a728160c746bcc63ad86b9 (patch)
tree: fd4b75f8a17bed95d8af5c45b559cb076a6f5d08
parent: 36bf4a43248077fd5635b13e2def636be299e435 (diff)
HloInstruction::CreateFromProto should not crash on CHECK, instead needs to return error status.
PiperOrigin-RevId: 201100918
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 100
1 file changed, 73 insertions(+), 27 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 58a33f5229..1dd2ce40da 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -70,25 +70,33 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
switch (opcode) {
// Ops migrated to subclasses.
case HloOpcode::kBatchNormTraining:
- CHECK_EQ(proto.operand_ids_size(), 3);
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "BatchNormTraining instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
instruction = CreateBatchNormTraining(
proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(),
proto.feature_index());
break;
case HloOpcode::kBatchNormInference:
- CHECK_EQ(proto.operand_ids_size(), 5);
+ TF_RET_CHECK(proto.operand_ids_size() == 5)
+ << "BatchNormInference instruction should have 5 operands but sees "
+ << proto.operand_ids_size();
instruction = CreateBatchNormInference(
proto.shape(), operands(0), operands(1), operands(2), operands(3),
operands(4), proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kBatchNormGrad:
- CHECK_EQ(proto.operand_ids_size(), 5);
+ TF_RET_CHECK(proto.operand_ids_size() == 5)
+ << "BatchNormGrad instruction should have 5 operands but sees "
+ << proto.operand_ids_size();
instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
operands(2), operands(3), operands(4),
proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kFft: {
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Fft instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
std::vector<int64> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
@@ -96,30 +104,42 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kSend:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Send instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction = CreateSend(operands(0), proto.channel_id());
break;
case HloOpcode::kSendDone:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "SendDone instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction = CreateSendDone(operands(0));
break;
case HloOpcode::kRecv:
- CHECK_EQ(proto.operand_ids_size(), 0);
+ TF_RET_CHECK(proto.operand_ids_size() == 0)
+ << "Recv instruction should have 0 operand but sees "
+ << proto.operand_ids_size();
instruction =
CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id());
break;
case HloOpcode::kRecvDone:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "RecvDone instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction = CreateRecvDone(operands(0));
break;
case HloOpcode::kReverse:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Reverse instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction = CreateReverse(proto.shape(), operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kConcatenate: {
- CHECK_EQ(proto.dimensions_size(), 1);
+ TF_RET_CHECK(proto.dimensions_size() == 1)
+ << "Concatenate instruction should have 1 dimension but sees "
+ << proto.dimensions_size();
std::vector<HloInstruction*> concat_operands(proto.operand_ids_size());
std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
concat_operands.begin(),
@@ -131,29 +151,39 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kReduce:
- CHECK_EQ(proto.operand_ids_size(), 2);
- CHECK_EQ(proto.called_computation_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Reduce instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Reduce instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
instruction = CreateReduce(proto.shape(), operands(0), operands(1),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()),
computations(0));
break;
case HloOpcode::kTranspose:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Transpose instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction =
CreateTranspose(proto.shape(), operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kBroadcast:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Broadcast instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction =
CreateBroadcast(proto.shape(), operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
case HloOpcode::kMap: {
- CHECK_EQ(proto.called_computation_ids_size(), 1);
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "Map instruction should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
std::vector<HloInstruction*> map_operands(proto.operand_ids_size());
std::transform(proto.operand_ids().begin(), proto.operand_ids().end(),
map_operands.begin(),
@@ -164,7 +194,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kSlice: {
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "Slice instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
std::vector<int64> slice_starts, slice_limits, slice_strides;
for (const HloInstructionProto::SliceDimensions& slice_dimensions :
proto.slice_dimensions()) {
@@ -191,7 +223,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Trace instruction should have 1 operand but sees "
<< proto.operand_ids_size();
- CHECK(proto.has_literal());
+ TF_RET_CHECK(proto.has_literal());
TF_ASSIGN_OR_RETURN(auto literal,
Literal::CreateFromProto(proto.literal()));
instruction = CreateTrace(literal->GetR1U8AsString(), operands(0));
@@ -207,7 +239,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// Find the fused computation and set its fusion instruction.
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
- << "Expect 1 called computation for fusion instruction, but sees "
+ << "Expect 1 called computation for fusion instruction but sees "
<< proto.called_computation_ids_size();
const int64 fusion_id = proto.called_computation_ids(0);
auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
@@ -237,7 +269,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.name());
break;
case HloOpcode::kGetTupleElement:
- CHECK_EQ(proto.operand_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 1)
+ << "GetTupleElement instruction should have 1 operand but sees "
+ << proto.operand_ids_size();
instruction = CreateGetTupleElement(proto.shape(), operands(0),
proto.tuple_index());
break;
@@ -254,7 +288,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.outfeed_config());
break;
case HloOpcode::kCrossReplicaSum: {
- CHECK_EQ(proto.called_computation_ids_size(), 1);
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "CrossReplicaSum should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
std::vector<HloInstruction*> all_operands(proto.operand_ids_size());
c_transform(proto.operand_ids(), all_operands.begin(),
[&instruction_map](int64 operand_id) {
@@ -274,22 +310,32 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kConvolution:
- CHECK_EQ(proto.operand_ids_size(), 2);
- CHECK(proto.has_window());
- CHECK(proto.has_convolution_dimension_numbers());
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "Convolution instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_window());
+ TF_RET_CHECK(proto.has_convolution_dimension_numbers());
instruction =
CreateConvolve(proto.shape(), operands(0), operands(1),
proto.window(), proto.convolution_dimension_numbers());
break;
case HloOpcode::kReduceWindow:
- CHECK_EQ(proto.operand_ids_size(), 2);
- CHECK_EQ(proto.called_computation_ids_size(), 1);
+ TF_RET_CHECK(proto.operand_ids_size() == 2)
+ << "ReduceWindow instruction should have 2 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 1)
+ << "ReduceWindow should have 1 called computation but sees "
+ << proto.called_computation_ids_size();
instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1),
proto.window(), computations(0));
break;
case HloOpcode::kSelectAndScatter:
- CHECK_EQ(proto.operand_ids_size(), 3);
- CHECK_EQ(proto.called_computation_ids_size(), 2);
+ TF_RET_CHECK(proto.operand_ids_size() == 3)
+ << "SelectAndScatter instruction should have 3 operands but sees "
+ << proto.operand_ids_size();
+ TF_RET_CHECK(proto.called_computation_ids_size() == 2)
+ << "SelectAndScatter should have 2 called computations but sees "
+ << proto.called_computation_ids_size();
instruction = CreateSelectAndScatter(
proto.shape(), operands(0), computations(0), proto.window(),
operands(1), operands(2), computations(1));