diff options
author | 2018-05-13 19:52:18 -0700 | |
---|---|---|
committer | 2018-05-13 19:55:02 -0700 | |
commit | 699b217cd6c5ddc0832be8471dde47999829e435 (patch) | |
tree | 035167be1ec270dded665347d20ec9385bed0fcc /tensorflow/contrib/lite/model.cc | |
parent | 2fbc0c5a45955c877e0a165bb561fc2f01518321 (diff) |
Introduce op version into TFLite
PiperOrigin-RevId: 196448769
Diffstat (limited to 'tensorflow/contrib/lite/model.cc')
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 1fbf965004..5d0fe3839e 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -186,6 +186,8 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { for (const OperatorCode* opcode : *opcodes) { TfLiteRegistration* registration = nullptr; auto builtin_code = opcode->builtin_code(); + int version = opcode->version(); + if (builtin_code > BuiltinOperator_MAX || builtin_code < BuiltinOperator_MIN) { error_reporter_->Report( @@ -194,8 +196,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { builtin_code); status = kTfLiteError; } else if (builtin_code != BuiltinOperator_CUSTOM) { - flatbuffer_op_index_to_registration_types_.push_back(builtin_code); - registration = op_resolver_.FindOp(builtin_code); + registration = op_resolver_.FindOp(builtin_code, version); if (registration == nullptr) { error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", EnumNameBuiltinOperator(builtin_code)); @@ -207,11 +208,13 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { status = kTfLiteError; } else { const char* name = opcode->custom_code()->c_str(); - registration = op_resolver_.FindOp(name); + registration = op_resolver_.FindOp(name, version); flatbuffer_op_index_to_registration_types_.push_back( BuiltinOperator_CUSTOM); if (registration == nullptr) { - error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + error_reporter_->Report( + "Didn't find custom op for name '%s' with version %d\n", name, + version); status = kTfLiteError; } } @@ -333,6 +336,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, params->stride_height = conv_params->stride_h(); params->activation = parse_activation(conv_params->fused_activation_function()); + params->dilation_width_factor = conv_params->dilation_w_factor(); params->dilation_height_factor = conv_params->dilation_h_factor(); } @@ -707,27 +711,30 @@ TfLiteStatus InterpreterBuilder::ParseNodes( status = kTfLiteError; continue; } - const TfLiteRegistration* reg = + + TfLiteRegistration* registration = flatbuffer_op_index_to_registration_[op->opcode_index()]; - if (reg == nullptr) { + if (registration == nullptr) { error_reporter_->Report("Skipping op for opcode_index %d\n", index); status = kTfLiteError; continue; } - auto op_type = - flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + BuiltinOperator op_type = + static_cast<BuiltinOperator>(registration->builtin_code); + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { error_reporter_->Report( "Found builtin operator %s with custom options.\n", EnumNameBuiltinOperator(op_type)); } + if (op->custom_options()) { interpreter->AddNodeWithParameters( FlatBufferIntArrayToVector(op->inputs()), FlatBufferIntArrayToVector(op->outputs()), reinterpret_cast<const char*>(op->custom_options()->data()), - op->custom_options()->size(), nullptr, reg); + op->custom_options()->size(), nullptr, registration); } else { void* builtin_data = nullptr; TF_LITE_ENSURE_STATUS( @@ -735,7 +742,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes( interpreter->AddNodeWithParameters( FlatBufferIntArrayToVector(op->inputs()), FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data, - reg); + registration); } } |