diff options
Diffstat (limited to 'tensorflow/contrib/lite/model.cc')
-rw-r--r-- | tensorflow/contrib/lite/model.cc | 673 |
1 files changed, 673 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc new file mode 100644 index 0000000000..f8208f6f98 --- /dev/null +++ b/tensorflow/contrib/lite/model.cc @@ -0,0 +1,673 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <fcntl.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <sys/mman.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <unistd.h> + +#include "tensorflow/contrib/lite/allocation.h" +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/nnapi_delegate.h" +#include "tensorflow/contrib/lite/version.h" + +namespace tflite { + +const char* kEmptyTensorName = ""; + +std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromFile( + const char* filename, ErrorReporter* error_reporter) { + std::unique_ptr<FlatBufferModel> model; + model.reset(new FlatBufferModel(filename, /*mmap_file=*/true, error_reporter, + /*use_nnapi=*/true)); + if (!model->initialized()) model.reset(); + return model; +} + +std::unique_ptr<FlatBufferModel> FlatBufferModel::BuildFromBuffer( + const char* buffer, size_t buffer_size, ErrorReporter* error_reporter) { + std::unique_ptr<FlatBufferModel> model; + model.reset(new FlatBufferModel(buffer, buffer_size, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + +FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, + ErrorReporter* error_reporter, bool use_nnapi) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + if (mmap_file) { + if (use_nnapi && NNAPIExists()) + allocation_ = new NNAPIAllocation(filename, error_reporter); + else + allocation_ = new MMAPAllocation(filename, error_reporter); + } else { + allocation_ = new FileCopyAllocation(filename, error_reporter); + } + if (!allocation_->valid()) return; + if (!CheckModelIdentifier()) return; + + model_ = ::tflite::GetModel(allocation_->base()); +} + +bool FlatBufferModel::CheckModelIdentifier() const { + if (!tflite::ModelBufferHasIdentifier(allocation_->base())) { + const char* ident = flatbuffers::GetBufferIdentifier(allocation_->base()); + error_reporter_->Report( + "Model provided has model identifier '%c%c%c%c', should be '%s'\n", + ident[0], ident[1], ident[2], ident[3], tflite::ModelIdentifier()); + return false; + } + return true; +} + +FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); + if (!allocation_->valid()) return; + model_ = ::tflite::GetModel(allocation_->base()); +} + +FlatBufferModel::~FlatBufferModel() { delete allocation_; } + +InterpreterBuilder::InterpreterBuilder(const FlatBufferModel& model, + const OpResolver& op_resolver) + : model_(model.GetModel()), + op_resolver_(op_resolver), + error_reporter_(model.error_reporter()), + allocation_(model.allocation()) {} + +InterpreterBuilder::InterpreterBuilder(const ::tflite::Model* model, + const OpResolver& op_resolver, + ErrorReporter* error_reporter) + : model_(model), + op_resolver_(op_resolver), + error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) {} + +TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { + TfLiteStatus status = kTfLiteOk; + auto opcodes = model_->operator_codes(); + for (const OperatorCode* opcode : *opcodes) { + TfLiteRegistration* registration = nullptr; + + if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { + auto x = opcode->builtin_code(); + flatbuffer_op_index_to_registration_types_.push_back(x); + registration = op_resolver_.FindOp(x); + if (registration == nullptr) { + error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", + EnumNameBuiltinOperator(x)); + status = kTfLiteError; + } + } else if (!opcode->custom_code()) { + error_reporter_->Report( + "Operator with builtin_code==0 has no custom_code.\n"); + status = kTfLiteError; + } else { + const char* name = opcode->custom_code()->c_str(); + registration = op_resolver_.FindOp(name); + flatbuffer_op_index_to_registration_types_.push_back( + BuiltinOperator_CUSTOM); + if (registration == nullptr) { + error_reporter_->Report("Didn't find custom op for name '%s'\n", name); + status = kTfLiteError; + } + } + flatbuffer_op_index_to_registration_.push_back(registration); + } + return status; +} + +namespace { +template <class T> +std::vector<int> FlatBufferIntArrayToVector(T* flat_array) { + std::vector<int> ret(flat_array->Length()); + for (int i = 0; i < flat_array->Length(); i++) { + ret[i] = flat_array->Get(i); + } + return ret; +} + +// Allocate a structure using C malloc, but make sure the structure is a +// POD structure that doesn't require constructors to run. The reason we do +// this, is that Interpreter's C extension part will take ownership and wants +// to use malloc() and free(). +template <class T> +T* MallocPOD() { + static_assert(std::is_pod<T>::value, "Builtin data structure must be POD."); + return static_cast<T*>(malloc(sizeof(T))); +} + +// Parse the appropriate data out of the op. +// +// This handles builtin data explicitly as there are flatbuffer schemas. +// +// Returns memory that must be feed. +void* ParseOpData(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter) { + auto parse_padding = [](Padding padding) { + switch (padding) { + case Padding_SAME: + return kTfLitePaddingSame; + case Padding_VALID: + return kTfLitePaddingValid; + } + return kTfLitePaddingUnknown; + }; + auto parse_activation = [](ActivationFunctionType activation) { + switch (activation) { + case ActivationFunctionType_NONE: + return kTfLiteActNone; + case ActivationFunctionType_RELU: + return kTfLiteActRelu; + case ActivationFunctionType_RELU1: + return kTfLiteActRelu1; + case ActivationFunctionType_RELU6: + return kTfLiteActRelu6; + case ActivationFunctionType_TANH: + return kTfLiteActTanh; + case ActivationFunctionType_SIGN_BIT: + return kTfLiteActSignBit; + } + return kTfLiteActNone; + }; + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + void* builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_CALL: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. + break; + case BuiltinOperator_CUSTOM: + break; + case BuiltinOperator_CONV_2D: { + TfLiteConvParams* params = MallocPOD<TfLiteConvParams>(); + if (auto* conv_params = op->builtin_options_as_Conv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_TANH: + case BuiltinOperator_LOGISTIC: + case BuiltinOperator_RELU: + case BuiltinOperator_RELU1: + case BuiltinOperator_RELU6: + case BuiltinOperator_CONCAT_EMBEDDINGS: + break; + case BuiltinOperator_LSH_PROJECTION: { + TfLiteLSHProjectionParams* params = + MallocPOD<TfLiteLSHProjectionParams>(); + if (auto* lshParams = op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_AVERAGE_POOL_2D: + case BuiltinOperator_MAX_POOL_2D: + case BuiltinOperator_L2_POOL_2D: { + TfLitePoolParams* params = MallocPOD<TfLitePoolParams>(); + if (auto* pool_params = op->builtin_options_as_Pool2DOptions()) { + params->padding = parse_padding(pool_params->padding()); + params->stride_width = pool_params->stride_w(); + params->stride_height = pool_params->stride_h(); + params->filter_width = pool_params->filter_width(); + params->filter_height = pool_params->filter_height(); + params->activation = + parse_activation(pool_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_DEPTHWISE_CONV_2D: { + TfLiteDepthwiseConvParams* params = + MallocPOD<TfLiteDepthwiseConvParams>(); + if (auto* conv_params = op->builtin_options_as_DepthwiseConv2DOptions()) { + params->padding = parse_padding(conv_params->padding()); + params->stride_width = conv_params->stride_w(); + params->stride_height = conv_params->stride_h(); + params->depth_multiplier = conv_params->depth_multiplier(); + params->activation = + parse_activation(conv_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SVDF: { + TfLiteSVDFParams* params = MallocPOD<TfLiteSVDFParams>(); + if (auto* svdf_params = op->builtin_options_as_SVDFOptions()) { + params->rank = svdf_params->rank(); + params->activation = + parse_activation(svdf_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RNN: { + TfLiteRNNParams* params = MallocPOD<TfLiteRNNParams>(); + if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + parse_activation(rnn_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_EMBEDDING_LOOKUP: + // no-op. + break; + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + TfLiteEmbeddingLookupSparseParams* params = + MallocPOD<TfLiteEmbeddingLookupSparseParams>(); + if (auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_FULLY_CONNECTED: { + TfLiteFullyConnectedParams* params = + MallocPOD<TfLiteFullyConnectedParams>(); + if (auto* fully_connected_params = + op->builtin_options_as_FullyConnectedOptions()) { + params->activation = parse_activation( + fully_connected_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + break; + case BuiltinOperator_SOFTMAX: { + TfLiteSoftmaxParams* params = MallocPOD<TfLiteSoftmaxParams>(); + if (auto* softmax_params = op->builtin_options_as_SoftmaxOptions()) { + params->beta = softmax_params->beta(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_CONCATENATION: { + TfLiteConcatenationParams* params = + MallocPOD<TfLiteConcatenationParams>(); + if (auto* concatenation_params = + op->builtin_options_as_ConcatenationOptions()) { + params->activation = + parse_activation(concatenation_params->fused_activation_function()); + params->axis = concatenation_params->axis(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_MUL: { + auto* params = MallocPOD<TfLiteMulParams>(); + if (auto* schema_params = op->builtin_options_as_MulOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_ADD: { + auto* params = MallocPOD<TfLiteAddParams>(); + if (auto* schema_params = op->builtin_options_as_AddOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_L2_NORMALIZATION: { + auto* params = MallocPOD<TfLiteL2NormParams>(); + if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto* params = MallocPOD<TfLiteLocalResponseNormParams>(); + if (auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_LSTM: { + TfLiteLSTMParams* params = MallocPOD<TfLiteLSTMParams>(); + if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + parse_activation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto* params = MallocPOD<TfLiteResizeBilinearParams>(); + if (auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->new_height = schema_params->new_height(); + params->new_width = schema_params->new_width(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_RESHAPE: { + auto* params = MallocPOD<TfLiteReshapeParams>(); + if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { + auto* new_shape = schema_params->new_shape(); + if (!new_shape) { + error_reporter->Report("No new_shape provided for Reshape\n"); + } else { + params->num_dimensions = new_shape->Length(); + if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in Reshape's new_shape\n"); + } else { + for (int i = 0; i < params->num_dimensions; ++i) { + params->shape[i] = new_shape->Get(i); + } + } + } + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SKIP_GRAM: { + TfLiteSkipGramParams* params = MallocPOD<TfLiteSkipGramParams>(); + if (auto* skip_gram_params = op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto* params = MallocPOD<TfLiteSpaceToDepthParams>(); + if (auto* schema_params = op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + builtin_data = reinterpret_cast<void*>(params); + break; + } + } + return builtin_data; +} + +} // namespace + +TfLiteStatus InterpreterBuilder::ParseNodes( + const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + for (int i = 0; i < operators->Length(); ++i) { + const auto* op = operators->Get(i); + int index = op->opcode_index(); + if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) { + error_reporter_->Report("Missing registration for opcode_index %d\n", + index); + status = kTfLiteError; + continue; + } + const TfLiteRegistration* reg = + flatbuffer_op_index_to_registration_[op->opcode_index()]; + if (reg == nullptr) { + error_reporter_->Report("Skipping op for opcode_index %d\n", index); + status = kTfLiteError; + continue; + } + + auto op_type = + flatbuffer_op_index_to_registration_types_[op->opcode_index()]; + if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) { + error_reporter_->Report( + "Found builtin operator %s with custom options.\n", + EnumNameBuiltinOperator(op_type)); + } + if (op->custom_options()) { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), + reinterpret_cast<const char*>(op->custom_options()->data()), + op->custom_options()->size(), nullptr, reg); + } else { + interpreter->AddNodeWithParameters( + FlatBufferIntArrayToVector(op->inputs()), + FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, + ParseOpData(op, op_type, error_reporter_), reg); + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::ParseTensors( + const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, + const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, + Interpreter* interpreter) { + TfLiteStatus status = kTfLiteOk; + + // A little helper to get the names of inputs and outputs. Note that they + // must outlive the interpreter. + auto get_name = [](const tflite::Tensor* t) -> const char* { + auto name = t->name(); + if (name) return name->c_str(); + return kEmptyTensorName; + }; + + for (int i = 0; i < tensors->Length(); ++i) { + const auto* tensor = tensors->Get(i); + std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape()); + + TfLiteQuantizationParams quantization; + quantization.scale = 0; + quantization.zero_point = 0; + auto* q_params = tensor->quantization(); + if (q_params) { + // Note that the schema could hold per-channel quantization parameters + // but we really only support one value for the whole tensor. + // TODO(aselle): This breaks as well if these are nullptr's. + // TODO(aselle): This assumes non per-channel quantization. + if (q_params->scale()) quantization.scale = q_params->scale()->Get(0); + if (q_params->zero_point()) + quantization.zero_point = q_params->zero_point()->Get(0); + } + + TfLiteType type; + switch (tensor->type()) { + case TensorType_FLOAT32: + type = kTfLiteFloat32; + break; + case TensorType_INT32: + type = kTfLiteInt32; + break; + case TensorType_UINT8: + type = kTfLiteUInt8; + break; + case TensorType_INT64: + type = kTfLiteInt64; + break; + case TensorType_STRING: + type = kTfLiteString; + break; + default: + // tensorType = ArrayType::NONE; + error_reporter_->Report("Unimplemented data type %s (%d) in tensor\n", + EnumNameTensorType(tensor->type()), + tensor->type()); + status = kTfLiteError; + continue; + } + auto get_readonly_data = [&](const char** buffer_data, + size_t* buffer_size) { + // TODO(aselle): Check what happens if we have an unspecified size + // constant. + *buffer_data = nullptr; + if (tensor->buffer() == 0) return kTfLiteOk; + if (tensor->buffer() >= buffers->size()) { + error_reporter_->Report( + "Tensor %d specifies out of range buffer %d (only %d buffers).\n", + i, tensor->buffer(), buffers->size()); + return kTfLiteError; + } + if (auto* buffer = (*buffers)[tensor->buffer()]) { + if (auto* array = buffer->data()) { + if (size_t size = array->size()) { + *buffer_size = size; + *buffer_data = reinterpret_cast<const char*>(array->data()); + return kTfLiteOk; + } + } + } + return kTfLiteOk; + }; + size_t buffer_size = 0; + const char* buffer_ptr; + TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size)); + + if (buffer_ptr) { + if (interpreter->SetTensorParametersReadOnly( + i, type, get_name(tensor), dims, quantization, buffer_ptr, + buffer_size, allocation_) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } else { + if (interpreter->SetTensorParametersReadWrite( + i, type, get_name(tensor), dims, quantization) != kTfLiteOk) { + error_reporter_->Report("Tensor %d is invalidly specified in schema.\n", + i); + status = kTfLiteError; + } + } + } + + return status; +} + +TfLiteStatus InterpreterBuilder::operator()( + std::unique_ptr<Interpreter>* interpreter) { + if (!interpreter) { + error_reporter_->Report( + "Null output pointer passed to InterpreterBuilder."); + return kTfLiteError; + } + + // Safe exit by deleting partially created interpreter, to reduce verbosity + // on error conditions. Use by return cleanup_on_error(); + auto cleanup_and_error = [&interpreter]() { + interpreter->reset(); + return kTfLiteError; + }; + + if (!model_) { + error_reporter_->Report("Null pointer passed in as model."); + return cleanup_and_error(); + } + + if (model_->version() != TFLITE_SCHEMA_VERSION) { + error_reporter_->Report( + "Model provided is schema version %d not equal " + "to supported version %d.\n", + model_->version(), TFLITE_SCHEMA_VERSION); + return cleanup_and_error(); + } + + if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) { + error_reporter_->Report("Registration failed.\n"); + return cleanup_and_error(); + } + + // Flatbuffer model schemas define a list of opcodes independent of the graph. + // We first map those to registrations. This reduces string lookups for custom + // ops since we only do it once per custom op rather than once per custom op + // invocation in the model graph. + // Construct interpreter with correct number of tensors and operators. + auto* subgraphs = model_->subgraphs(); + auto* buffers = model_->buffers(); + if (subgraphs->size() != 1) { + error_reporter_->Report("Only 1 subgraph is currently supported.\n"); + return cleanup_and_error(); + } + const tflite::SubGraph* subgraph = (*subgraphs)[0]; + auto operators = subgraph->operators(); + auto tensors = subgraph->tensors(); + if (!operators || !tensors || !buffers) { + error_reporter_->Report( + "Did not get operators, tensors, or buffers in input flat buffer.\n"); + return cleanup_and_error(); + } + interpreter->reset(new Interpreter(error_reporter_)); + if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) { + return cleanup_and_error(); + } + + // Parse inputs/outputs + (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); + (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); + + // Finally setup nodes and tensors + if (ParseNodes(operators, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk) + return cleanup_and_error(); + + return kTfLiteOk; +} + +} // namespace tflite |