aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/delegates
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2018-06-03 12:43:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-03 12:46:08 -0700
commit45198062b58245711d7446aa389f3b9aa2c1535f (patch)
tree68dd4385a80286959f5cc6923aac5a86521b45b1 /tensorflow/contrib/lite/delegates
parentd23f115d89ad6111674f53135d669cb2d2c086f0 (diff)
New NN API interface that uses the TensorFlow Lite delegate API.
- Make nn_api a delegate in its own directory. - Use the delegate API to rewrite the graph. - Use only on static APIs right now. - This is initial preview of the delegate that only supports add and conv. PiperOrigin-RevId: 199055747
Diffstat (limited to 'tensorflow/contrib/lite/delegates')
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/BUILD31
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc464
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h31
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc82
4 files changed, 608 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/delegates/nnapi/BUILD b/tensorflow/contrib/lite/delegates/nnapi/BUILD
new file mode 100644
index 0000000000..35a8f6ca41
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = [
+ "//visibility:public",
+])
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+licenses(["notice"]) # Apache 2.0
+
+cc_library(
+ name = "nnapi_delegate",
+ srcs = ["nnapi_delegate.cc"],
+ hdrs = ["nnapi_delegate.h"],
+ deps = [
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite:kernel_api",
+ "//tensorflow/contrib/lite/kernels:kernel_util",
+ "//tensorflow/contrib/lite/nnapi:nnapi_lib",
+ ],
+)
+
+tf_cc_test(
+ name = "nnapi_delegate_test",
+ size = "small",
+ srcs = ["nnapi_delegate_test.cc"],
+ deps = [
+ ":nnapi_delegate",
+ "//tensorflow/contrib/lite:framework",
+ "//tensorflow/contrib/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
new file mode 100644
index 0000000000..0731d14419
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -0,0 +1,464 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdarg>
+#include <iostream>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/contrib/lite/allocation.h"
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/builtin_ops.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/context_util.h"
+#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+
+namespace tflite {
+namespace {
+
+// TODO(b/80621585): Consider printing error string, but don't for now to
+// minimize binary size.
+#define CHECK_NN(context, code) \
+ if (code != ANEURALNETWORKS_NO_ERROR) { \
+ context->ReportError(context, "NN API returned error (%d).\n", code); \
+ return kTfLiteError; \
+ }
+
+// RAII NN API Model Destructor for use with std::unique_ptr
+struct NNFreeModel {
+ void operator()(ANeuralNetworksModel* model) {
+ ANeuralNetworksModel_free(model);
+ }
+};
+// RAII NN API Compilation Destructor for use with std::unique_ptr
+struct NNFreeCompilation {
+ void operator()(ANeuralNetworksCompilation* model) {
+ ANeuralNetworksCompilation_free(model);
+ }
+};
+
+// Track tensor indices to NN API tensor indices mapping.
+class OperandMapping {
+ public:
+ // Given a TFLite index return the ANN index. If it doesn't exist
+ // return -1.
+ int lite_index_to_ann(int index) const {
+ if (index < lite_tensor_to_ann_tensor_.size())
+ return lite_tensor_to_ann_tensor_[index];
+ else
+ return -1;
+ }
+
+ // NN API uses non tensor operands instead of structs. This creates one
+ // and returns the index. It uses a std::vector and resizes it as needed
+ // keeping -1 to unmapped values. Intermediate tensors likely will not
+ // be mapped.
+ int add_new_non_tensor_operand() { return next_ann_tensor_index_++; }
+
+ // Add a new mapping from `tflite_index` and return the NN API tensor index.
+ int add_new_ann_tensor_index(int tflite_index) {
+ if (tflite_index >= lite_tensor_to_ann_tensor_.size()) {
+ lite_tensor_to_ann_tensor_.resize(tflite_index + 1);
+ }
+ int new_tensor_index = next_ann_tensor_index_++;
+ lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index;
+ return new_tensor_index;
+ }
+
+ private:
+ // Next index of ann tensor
+ int next_ann_tensor_index_ = 0;
+
+ // Mapping from lite index. Use a std::vector for speed and code size
+ // rather than a map.
+ std::vector<int> lite_tensor_to_ann_tensor_;
+};
+
+// Abstract builder for building an op in the NN API graph. This handles
+// the disparity between TFLite and NN API operand types. NN API has singular
+// operands for both tensors and parameters, and TFLite separates the two.
+class NNAPIOpBuilder {
+ public:
+ NNAPIOpBuilder(TfLiteContext* context, OperandMapping* tensor_mapping,
+ ANeuralNetworksModel* nn_model)
+ : context_(context),
+ operand_mapping_(tensor_mapping),
+ nn_model_(nn_model) {}
+
+ TfLiteStatus AddScalarInt32Operand(int value) {
+ ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32};
+ CHECK_NN(context_,
+ ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+ int ann_operand = operand_mapping_->add_new_non_tensor_operand();
+ CHECK_NN(context_, ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_operand, &value, sizeof(int32_t)));
+ augmented_inputs_.push_back(ann_operand);
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus AddTensorInput(int tensor_index) {
+ int ann_index;
+ TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
+ augmented_inputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus AddTensorOutput(int tensor_index) {
+ int ann_index;
+ TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
+ augmented_outputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
+
+ // Adds a new NN API tensor that shadows the TF Lite tensor `tensor_index`.
+ // This returns the NN API tensor index corresponding to the created tensor.
+ // If another caller previously created a NN API tensor for `tensor_index`
+ // then the existing one is returned.
+ TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) {
+ int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
+ if (ann_tensor_index != -1) {
+ *ann_tensor_index_out = ann_tensor_index;
+ return kTfLiteOk;
+ }
+ // Allocate a new tensor index
+ ann_tensor_index = operand_mapping_->add_new_ann_tensor_index(tensor_index);
+
+ // Parameters needed for new type.
+ int32_t nn_type = 0;
+ float scale = 0.0f;
+ int32_t zeroPoint = 0;
+ TfLiteTensor* tensor = &context_->tensors[tensor_index];
+ switch (tensor->type) {
+ case kTfLiteNoType:
+ // Tensors added during initialization of Ops don't have a type yet and
+ // should not be registered with the NNAPI.
+ *ann_tensor_index_out = -1;
+ return kTfLiteOk;
+ case kTfLiteFloat32:
+ nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
+ scale = 0.f;
+ break;
+ case kTfLiteUInt8:
+ nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM;
+ scale = tensor->params.scale;
+ zeroPoint = tensor->params.zero_point;
+ break;
+ case kTfLiteInt32:
+ nn_type = ANEURALNETWORKS_TENSOR_INT32;
+ scale = 0.f;
+ zeroPoint = 0;
+ break;
+ default:
+ context_->ReportError(context_, "Logic error in NN API Delegate.\n");
+ return kTfLiteError;
+ }
+
+ ANeuralNetworksOperandType operand_type{
+ nn_type, static_cast<uint32_t>(tensor->dims->size),
+ reinterpret_cast<uint32_t*>(tensor->dims->data), scale, zeroPoint};
+ CHECK_NN(context_,
+ ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+
+ if (tensor->allocation_type == kTfLiteMmapRo) {
+ // TODO(b/80630405): Use NNAPIAllocation.
+ CHECK_NN(context_, ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_tensor_index, tensor->data.raw,
+ tensor->bytes));
+ }
+
+ *ann_tensor_index_out = ann_tensor_index;
+ return kTfLiteOk;
+ }
+
+ // Finish emitting the op (of type `type`) into the NN API.
+ TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) {
+ // Actually add a NN API operation
+ CHECK_NN(context_, ANeuralNetworksModel_addOperation(
+ nn_model_, type,
+ static_cast<uint32_t>(augmented_inputs_.size()),
+ augmented_inputs_.data(),
+ static_cast<uint32_t>(augmented_outputs_.size()),
+ augmented_outputs_.data()));
+ augmented_outputs_.clear();
+ augmented_outputs_.clear();
+ return kTfLiteOk;
+ }
+
+ private:
+ // TfLiteContext for error handling. Must be named context for macros to
+ // work.
+ TfLiteContext* context_;
+
+ // Tracks relationship between indices
+ OperandMapping* operand_mapping_;
+
+ // The model
+ ANeuralNetworksModel* nn_model_;
+
+ // Inputs and outputs for the current op. These are augmented in the sense
+ // that NN API uses operands for all arguments, not just tensors, unlike
+ // TensorFlow lite.
+ std::vector<uint32_t> augmented_inputs_;
+ std::vector<uint32_t> augmented_outputs_;
+};
+
+// The kernel that represents the subgraph of TF Lite being run on NN API.
+class NNAPIDelegateKernel {
+ public:
+ NNAPIDelegateKernel() = default;
+
+ typedef ANeuralNetworksOperationType (*MappingFn)(TfLiteContext*,
+ NNAPIOpBuilder* builder,
+ TfLiteNode* node);
+
+ // Return a function that knows how to translate a node into its operands
+ // when called. You can use this function to see if a node is supported
+ // (i.e. that MappingFn is not nullptr).
+ MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) {
+ switch (builtin_code) {
+ case kTfLiteBuiltinAdd:
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin = reinterpret_cast<TfLiteAddParams*>(node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_ADD;
+ };
+ break;
+ case kTfLiteBuiltinAveragePool2d:
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ auto builtin =
+ reinterpret_cast<TfLitePoolParams*>(node->builtin_data);
+ builder->AddScalarInt32Operand(builtin->padding);
+ builder->AddScalarInt32Operand(builtin->stride_width);
+ builder->AddScalarInt32Operand(builtin->stride_height);
+ builder->AddScalarInt32Operand(builtin->filter_width);
+ builder->AddScalarInt32Operand(builtin->filter_height);
+ builder->AddScalarInt32Operand(builtin->activation);
+ return ANEURALNETWORKS_AVERAGE_POOL_2D;
+ };
+ break;
+ default:
+ return nullptr;
+ }
+ }
+
+ // Initialize the kernel (a NN model).
+ TfLiteStatus Init(TfLiteContext* context,
+ const TfLiteDelegateParams* params) {
+ for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
+ nodes_.push_back(node_index);
+ }
+
+ if (!nn_model_) {
+ ANeuralNetworksModel* model;
+ CHECK_NN(context, ANeuralNetworksModel_create(&model));
+ nn_model_.reset(model);
+
+ TF_LITE_ENSURE_STATUS(
+ BuildGraph(context, params->input_tensors, params->output_tensors));
+ }
+
+ if (!nn_compilation_) {
+ ANeuralNetworksCompilation* compilation;
+ CHECK_NN(context, ANeuralNetworksCompilation_create(nn_model_.get(),
+ &compilation));
+ CHECK_NN(context, ANeuralNetworksCompilation_finish(compilation));
+ nn_compilation_.reset(compilation);
+ }
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus Invoke(TfLiteContext* context, TfLiteNode* node) {
+ ANeuralNetworksExecution* execution = nullptr;
+ CHECK_NN(context, ANeuralNetworksExecution_create(nn_compilation_.get(),
+ &execution));
+
+ // Set the input tensor buffers. Note: we access tflite tensors using
+ // absolute indices but NN api indices inputs by relative indices.
+ int relative_input_index = 0;
+ for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) {
+ TfLiteTensor* tensor = &context->tensors[absolute_input_index];
+ CHECK_NN(context, ANeuralNetworksExecution_setInput(
+ execution, relative_input_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_input_index++;
+ }
+
+ // Set the output tensor buffers.
+ int relative_output_index = 0;
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ TfLiteTensor* tensor = &context->tensors[output_index];
+ CHECK_NN(context, ANeuralNetworksExecution_setOutput(
+ execution, relative_output_index, nullptr,
+ tensor->data.raw, tensor->bytes));
+ relative_output_index++;
+ }
+ // Invoke ANN in blocking fashion.
+ ANeuralNetworksEvent* event = nullptr;
+ CHECK_NN(context, ANeuralNetworksExecution_startCompute(execution, &event));
+ CHECK_NN(context, ANeuralNetworksEvent_wait(event));
+ ANeuralNetworksEvent_free(event);
+ ANeuralNetworksExecution_free(execution);
+
+ return kTfLiteOk;
+ }
+
+ private:
+ // ANN API state.
+ std::unique_ptr<ANeuralNetworksModel, NNFreeModel> nn_model_;
+ std::unique_ptr<ANeuralNetworksCompilation, NNFreeCompilation>
+ nn_compilation_;
+ // Node indices that this delegate is responsible for. Indices here
+ // indexes into the nodes array in the TfLiteContext.
+ std::vector<int> nodes_;
+ // Track indices we use
+ OperandMapping operand_mapping_;
+
+ TfLiteStatus AddOpsAndTensors(TfLiteContext* context) {
+ // The operand builder allows creating a single op. We create it at this
+ // reduced power position rather than in the for loop to avoid reallocating
+ // the vectors.
+ NNAPIOpBuilder builder(context, &operand_mapping_, nn_model_.get());
+ // Add Tensors
+ // allocate outside to avoid realloc
+ for (auto node_index : nodes_) {
+ // Obtain the op and registration.
+ TfLiteNode* node;
+ TfLiteRegistration* reg;
+ context->GetNodeAndRegistration(context, node_index, &node, &reg);
+ // Map inputs to NN API tensor indices.
+ for (auto input_index : TfLiteIntArrayView(node->inputs)) {
+ TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index));
+ }
+ // Get op type and operands
+ int nn_op_type =
+ Map(context, reg->builtin_code, node)(context, &builder, node);
+ // Map outputs to NN API tensor indices.
+ for (auto output_index : TfLiteIntArrayView(node->outputs)) {
+ TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
+ }
+
+ builder.FinalizeAddOperation(nn_op_type);
+ }
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus BuildGraph(TfLiteContext* context,
+ const TfLiteIntArray* input_tensors,
+ const TfLiteIntArray* output_tensors) {
+ // Build the ops and tensors.
+ TF_LITE_ENSURE_STATUS(AddOpsAndTensors(context));
+ // Map input and output tensor indices to ANN
+ std::vector<uint32_t> inputs;
+ inputs.reserve(input_tensors->size);
+ std::vector<uint32_t> outputs;
+ outputs.reserve(output_tensors->size);
+ // Make the TensorFlow lite inputs and outputs to ann_indices.
+ for (int i : TfLiteIntArrayView(input_tensors))
+ inputs.push_back(operand_mapping_.lite_index_to_ann(i));
+ for (int i : TfLiteIntArrayView(output_tensors))
+ outputs.push_back(operand_mapping_.lite_index_to_ann(i));
+ // Tell ANN to declare inputs/outputs
+ CHECK_NN(context, ANeuralNetworksModel_identifyInputsAndOutputs(
+ nn_model_.get(), inputs.size(), inputs.data(),
+ outputs.size(), outputs.data()));
+ // Finalize the model
+ CHECK_NN(context, ANeuralNetworksModel_finish(nn_model_.get()));
+
+ return kTfLiteOk;
+ }
+};
+
+} // namespace
+
+// Return a NN API Delegate struct that can check for support of ops.
+TfLiteDelegate* NnApiDelegate() {
+ static TfLiteDelegate delegate = {
+ .data_ = nullptr,
+ .Prepare = [](TfLiteContext* context,
+ TfLiteDelegate* delegate) -> TfLiteStatus {
+ // Do not check nodes_ if NN API is unavailable.
+ if (!NNAPIExists()) return kTfLiteOk;
+
+ std::vector<int> supported_nodes(1);
+ // We don't care about all nodes_, we only care about ones in the
+ // current plan.
+ TfLiteIntArray* plan;
+ TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan));
+ int total_supported_nodes = 0;
+ // Check for every node if it is supported
+ // TODO(b/80625235): Fix this to do more careful checking of versioning.
+ for (int node_index : TfLiteIntArrayView(plan)) {
+ TfLiteNode* node;
+ TfLiteRegistration* registration;
+ TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration(
+ context, node_index, &node, &registration));
+ NNAPIDelegateKernel dummy_kernel;
+ if (dummy_kernel.Map(context, registration->builtin_code, node)) {
+ supported_nodes.push_back(node_index);
+ }
+ total_supported_nodes += 1;
+ }
+ // Put the size at the beginning of the array.
+ supported_nodes[0] = supported_nodes.size() - 1;
+
+ // NN API Delegate Registration (the pseudo kernel that will invoke NN
+ // API subgraphs)
+ static const TfLiteRegistration nnapi_delegate_kernel = {
+ .init = [](TfLiteContext* context, const char* buffer,
+ size_t length) -> void* {
+ const TfLiteDelegateParams* params =
+ reinterpret_cast<const TfLiteDelegateParams*>(buffer);
+ NNAPIDelegateKernel* kernel_state = new NNAPIDelegateKernel;
+ kernel_state->Init(context, params);
+ return kernel_state;
+ },
+
+ .free = [](TfLiteContext* context, void* buffer) -> void {
+ delete reinterpret_cast<NNAPIDelegateKernel*>(buffer);
+ },
+
+ .prepare = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ // Since the underlying resize happened ahead of delegation
+ // worked. This does nothing.
+ return kTfLiteOk;
+ },
+
+ .invoke = [](TfLiteContext* context,
+ TfLiteNode* node) -> TfLiteStatus {
+ NNAPIDelegateKernel* state =
+ reinterpret_cast<NNAPIDelegateKernel*>(node->user_data);
+ return state->Invoke(context, node);
+ },
+
+ .builtin_code = kTfLiteBuiltinDelegate,
+ };
+
+ // Request TFLite to partition the graph and make kernels
+ // for each independent subgraph a new nnapi_delegate_kernel.
+ context->ReplaceSubgraphsWithDelegateKernels(
+ context, nnapi_delegate_kernel,
+ reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()),
+ delegate);
+ return kTfLiteOk;
+ }};
+
+ return &delegate;
+}
+
+} // namespace tflite
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
new file mode 100644
index 0000000000..44cca2fd28
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h
@@ -0,0 +1,31 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
+#define TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
+
+#include "tensorflow/contrib/lite/context.h"
+
+namespace tflite {
+
+// Return a delegate that can be used to use the NN API.
+// e.g.
+// NnApiDelegate* delegate = NnApiDelegate();
+// interpreter->ModifyGraphWithDelegate(&delegate);
+// NnApiDelegate() returns a singleton, so you should not free this
+// pointer or worry about its lifetime.
+TfLiteDelegate* NnApiDelegate();
+} // namespace tflite
+
+#endif // TENSORFLOW_CONTRIB_LITE_DELEGATES_NNAPI_NNAPI_DELEGATE_H_
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
new file mode 100644
index 0000000000..ff2e721423
--- /dev/null
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -0,0 +1,82 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FloatAddOpModel : public SingleOpModel {
+ public:
+ FloatAddOpModel(const TensorData& input1, const TensorData& input2,
+ const TensorData& output,
+ ActivationFunctionType activation_type) {
+ this->SetApplyDelegate([](Interpreter* interpreter) {
+ interpreter->ModifyGraphWithDelegate(NnApiDelegate());
+ });
+ input1_ = AddInput(input1);
+ input2_ = AddInput(input2);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions,
+ CreateAddOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input1_), GetShape(input2_)});
+ }
+
+ int input1() { return input1_; }
+ int input2() { return input2_; }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+
+ protected:
+ int input1_;
+ int input2_;
+ int output_;
+};
+
+// Do a test with the NN API using no activation.
+TEST(NNAPIDelegate, AddWithNoActivation) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3}));
+}
+
+// Do a test with the NN api with relu.
+TEST(NNAPIDelegate, AddWithRelu) {
+ FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU);
+ m.PopulateTensor<float>(m.input1(), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}