aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/model.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-05-13 19:52:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-13 19:55:02 -0700
commit699b217cd6c5ddc0832be8471dde47999829e435 (patch)
tree035167be1ec270dded665347d20ec9385bed0fcc /tensorflow/contrib/lite/model.cc
parent2fbc0c5a45955c877e0a165bb561fc2f01518321 (diff)
Introduce op version into TFLite
PiperOrigin-RevId: 196448769
Diffstat (limited to 'tensorflow/contrib/lite/model.cc')
-rw-r--r--tensorflow/contrib/lite/model.cc27
1 files changed, 17 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 1fbf965004..5d0fe3839e 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -186,6 +186,8 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
for (const OperatorCode* opcode : *opcodes) {
TfLiteRegistration* registration = nullptr;
auto builtin_code = opcode->builtin_code();
+ int version = opcode->version();
+
if (builtin_code > BuiltinOperator_MAX ||
builtin_code < BuiltinOperator_MIN) {
error_reporter_->Report(
@@ -194,8 +196,7 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
builtin_code);
status = kTfLiteError;
} else if (builtin_code != BuiltinOperator_CUSTOM) {
- flatbuffer_op_index_to_registration_types_.push_back(builtin_code);
- registration = op_resolver_.FindOp(builtin_code);
+ registration = op_resolver_.FindOp(builtin_code, version);
if (registration == nullptr) {
error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
EnumNameBuiltinOperator(builtin_code));
@@ -207,11 +208,13 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
status = kTfLiteError;
} else {
const char* name = opcode->custom_code()->c_str();
- registration = op_resolver_.FindOp(name);
+ registration = op_resolver_.FindOp(name, version);
flatbuffer_op_index_to_registration_types_.push_back(
BuiltinOperator_CUSTOM);
if (registration == nullptr) {
- error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
+ error_reporter_->Report(
+ "Didn't find custom op for name '%s' with version %d\n", name,
+ version);
status = kTfLiteError;
}
}
@@ -333,6 +336,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
params->stride_height = conv_params->stride_h();
params->activation =
parse_activation(conv_params->fused_activation_function());
+
params->dilation_width_factor = conv_params->dilation_w_factor();
params->dilation_height_factor = conv_params->dilation_h_factor();
}
@@ -707,27 +711,30 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
status = kTfLiteError;
continue;
}
- const TfLiteRegistration* reg =
+
+ TfLiteRegistration* registration =
flatbuffer_op_index_to_registration_[op->opcode_index()];
- if (reg == nullptr) {
+ if (registration == nullptr) {
error_reporter_->Report("Skipping op for opcode_index %d\n", index);
status = kTfLiteError;
continue;
}
- auto op_type =
- flatbuffer_op_index_to_registration_types_[op->opcode_index()];
+ BuiltinOperator op_type =
+ static_cast<BuiltinOperator>(registration->builtin_code);
+
if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
error_reporter_->Report(
"Found builtin operator %s with custom options.\n",
EnumNameBuiltinOperator(op_type));
}
+
if (op->custom_options()) {
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
reinterpret_cast<const char*>(op->custom_options()->data()),
- op->custom_options()->size(), nullptr, reg);
+ op->custom_options()->size(), nullptr, registration);
} else {
void* builtin_data = nullptr;
TF_LITE_ENSURE_STATUS(
@@ -735,7 +742,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
interpreter->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
- reg);
+ registration);
}
}