From 3c0c74e0147ef284a6f2cc5533bea8777af1e740 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Mon, 18 Jun 2018 17:34:45 -0700
Subject: Make NNAPI delegation support more ops.

PiperOrigin-RevId: 201090056
---
 .../contrib/lite/delegates/nnapi/nnapi_delegate.cc |  253 ++++++++--
 .../lite/delegates/nnapi/nnapi_delegate_test.cc    |  533 +++++++++++++++++++++
 2 files changed, 745 insertions(+), 41 deletions(-)

diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index 0731d14419..e96ee92376 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -26,6 +26,10 @@ limitations under the License.
 #include "tensorflow/contrib/lite/kernels/kernel_util.h"
 #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
 
+#ifdef __ANDROID__
+#include <sys/system_properties.h>
+#endif
+
 namespace tflite {
 namespace {
 
@@ -37,6 +41,29 @@ namespace {
     return kTfLiteError;                    \
   }
 
+namespace {
+int32_t GetAndroidSdkVersion() {
+#ifdef __ANDROID__
+  const char* sdkProp = "ro.build.version.sdk";
+  char sdkVersion[PROP_VALUE_MAX];
+  int length = __system_property_get(sdkProp, sdkVersion);
+  if (length != 0) {
+    for (int i = 0; i < length; ++i) {
+      int digit = sdkVersion[i] - '0';
+      if (digit < 0 || digit > 9) {
+        // Non-numeric SDK version, assume it's higher than expected.
+        return std::numeric_limits<int32_t>::max();
+      }
+    }
+    return atoi(sdkVersion);
+  }
+#endif  // __ANDROID__
+  return 0;
+}
+
+static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+}  // namespace
+
 // RAII NN API Model Destructor for use with std::unique_ptr
 struct NNFreeModel {
   void operator()(ANeuralNetworksModel* model) {
@@ -71,7 +98,7 @@ class OperandMapping {
   // Add a new mapping from `tflite_index` and return the NN API tensor index.
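// Note (not part of the patch): the resize below now fills new entries with -1
// instead of value-initializing them to 0, presumably so a TFLite tensor index
// that was never mapped can be told apart from a tensor legitimately mapped to
// NN API operand 0.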
int add_new_ann_tensor_index(int tflite_index) { if (tflite_index >= lite_tensor_to_ann_tensor_.size()) { - lite_tensor_to_ann_tensor_.resize(tflite_index + 1); + lite_tensor_to_ann_tensor_.resize(tflite_index + 1, -1); } int new_tensor_index = next_ann_tensor_index_++; lite_tensor_to_ann_tensor_[tflite_index] = new_tensor_index; @@ -98,14 +125,22 @@ class NNAPIOpBuilder { operand_mapping_(tensor_mapping), nn_model_(nn_model) {} - TfLiteStatus AddScalarInt32Operand(int value) { - ANeuralNetworksOperandType operand_type{.type = ANEURALNETWORKS_INT32}; - CHECK_NN(context_, - ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); - int ann_operand = operand_mapping_->add_new_non_tensor_operand(); - CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( - nn_model_, ann_operand, &value, sizeof(int32_t))); - augmented_inputs_.push_back(ann_operand); + TfLiteStatus AddScalarInt32Operand(int32_t value) { + return AddScalarOperand(value, ANEURALNETWORKS_INT32); + } + + TfLiteStatus AddScalarFloat32Operand(float value) { + return AddScalarOperand(value, ANEURALNETWORKS_FLOAT32); + } + + TfLiteStatus AddPoolingParams(void* data) { + auto builtin = reinterpret_cast(data); + AddScalarInt32Operand(builtin->padding); + AddScalarInt32Operand(builtin->stride_width); + AddScalarInt32Operand(builtin->stride_height); + AddScalarInt32Operand(builtin->filter_width); + AddScalarInt32Operand(builtin->filter_height); + AddScalarInt32Operand(builtin->activation); return kTfLiteOk; } @@ -149,7 +184,6 @@ class NNAPIOpBuilder { return kTfLiteOk; case kTfLiteFloat32: nn_type = ANEURALNETWORKS_TENSOR_FLOAT32; - scale = 0.f; break; case kTfLiteUInt8: nn_type = ANEURALNETWORKS_TENSOR_QUANT8_ASYMM; @@ -158,8 +192,8 @@ class NNAPIOpBuilder { break; case kTfLiteInt32: nn_type = ANEURALNETWORKS_TENSOR_INT32; - scale = 0.f; - zeroPoint = 0; + scale = tensor->params.scale; + zeroPoint = tensor->params.zero_point; break; default: context_->ReportError(context_, "Logic error in NN API Delegate.\n"); @@ -192,12 +226,24 @@ class NNAPIOpBuilder { augmented_inputs_.data(), static_cast(augmented_outputs_.size()), augmented_outputs_.data())); - augmented_outputs_.clear(); + augmented_inputs_.clear(); augmented_outputs_.clear(); return kTfLiteOk; } private: + template + TfLiteStatus AddScalarOperand(T value, int32_t nn_type) { + ANeuralNetworksOperandType operand_type{.type = nn_type}; + CHECK_NN(context_, + ANeuralNetworksModel_addOperand(nn_model_, &operand_type)); + int ann_operand = operand_mapping_->add_new_non_tensor_operand(); + CHECK_NN(context_, ANeuralNetworksModel_setOperandValue( + nn_model_, ann_operand, &value, sizeof(T))); + augmented_inputs_.push_back(ann_operand); + return kTfLiteOk; + } + // TfLiteContext for error handling. Must be named context for macros to // work. TfLiteContext* context_; @@ -227,29 +273,143 @@ class NNAPIDelegateKernel { // Return a function that knows how to translate a node into its operands // when called. You can use this function to see if a node is supported // (i.e. that MappingFn is not nullptr). 
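// Illustrative usage, mirroring the Prepare() logic further down in this file
// (no new API is introduced here):
//
//   NNAPIDelegateKernel dummy_kernel;
//   auto mapping_fn = dummy_kernel.Map(context, registration->builtin_code,
//                                      registration->version, node);
//   if (mapping_fn == nullptr) {
//     // Unsupported builtin/version combination; the node stays on the
//     // TFLite CPU kernel.
//   }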
- MappingFn Map(TfLiteContext* context, int builtin_code, TfLiteNode* node) { + MappingFn Map(TfLiteContext* context, int builtin_code, int version, + TfLiteNode* node) { switch (builtin_code) { case kTfLiteBuiltinAdd: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { - auto builtin = reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_ADD; - }; + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_ADD; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMul: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_MUL; + }; + } else { + return nullptr; + } break; case kTfLiteBuiltinAveragePool2d: - return [](TfLiteContext* context, NNAPIOpBuilder* builder, - TfLiteNode* node) -> ANeuralNetworksOperationType { + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_AVERAGE_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinMaxPool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_MAX_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinL2Pool2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + builder->AddPoolingParams(node->builtin_data); + return ANEURALNETWORKS_L2_POOL_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinConv2d: + if (version == 1) { auto builtin = - reinterpret_cast(node->builtin_data); - builder->AddScalarInt32Operand(builtin->padding); - builder->AddScalarInt32Operand(builtin->stride_width); - builder->AddScalarInt32Operand(builtin->stride_height); - builder->AddScalarInt32Operand(builtin->filter_width); - builder->AddScalarInt32Operand(builtin->filter_height); - builder->AddScalarInt32Operand(builtin->activation); - return ANEURALNETWORKS_AVERAGE_POOL_2D; - }; + reinterpret_cast(node->builtin_data); + if (builtin->dilation_width_factor != 1 || + builtin->dilation_height_factor != 1 || node->inputs->size != 3) { + // NNAPI does not support dilated Conv2D. 
+ return nullptr; + } + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinDepthwiseConv2d: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->padding); + builder->AddScalarInt32Operand(builtin->stride_width); + builder->AddScalarInt32Operand(builtin->stride_height); + builder->AddScalarInt32Operand(builtin->depth_multiplier); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_DEPTHWISE_CONV_2D; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinFullyConnected: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = reinterpret_cast( + node->builtin_data); + builder->AddScalarInt32Operand(builtin->activation); + return ANEURALNETWORKS_FULLY_CONNECTED; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinSoftmax: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + auto builtin = + reinterpret_cast(node->builtin_data); + builder->AddScalarFloat32Operand(builtin->beta); + return ANEURALNETWORKS_SOFTMAX; + }; + } else { + return nullptr; + } + break; + case kTfLiteBuiltinReshape: + if (version == 1) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_RESHAPE; + }; + } else { + return nullptr; + } break; default: return nullptr; @@ -292,10 +452,14 @@ class NNAPIDelegateKernel { int relative_input_index = 0; for (auto absolute_input_index : TfLiteIntArrayView(node->inputs)) { TfLiteTensor* tensor = &context->tensors[absolute_input_index]; - CHECK_NN(context, ANeuralNetworksExecution_setInput( - execution, relative_input_index, nullptr, - tensor->data.raw, tensor->bytes)); - relative_input_index++; + // TODO(miaowang): make sure the delegation works with dequantized weights + // as intermediate tensors. + if (tensor->allocation_type != kTfLiteMmapRo) { + CHECK_NN(context, ANeuralNetworksExecution_setInput( + execution, relative_input_index, nullptr, + tensor->data.raw, tensor->bytes)); + relative_input_index++; + } } // Set the output tensor buffers. @@ -345,8 +509,8 @@ class NNAPIDelegateKernel { TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index)); } // Get op type and operands - int nn_op_type = - Map(context, reg->builtin_code, node)(context, &builder, node); + int nn_op_type = Map(context, reg->builtin_code, reg->version, node)( + context, &builder, node); // Map outputs to NN API tensor indices. for (auto output_index : TfLiteIntArrayView(node->outputs)) { TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index)); @@ -368,8 +532,12 @@ class NNAPIDelegateKernel { std::vector outputs; outputs.reserve(output_tensors->size); // Make the TensorFlow lite inputs and outputs to ann_indices. 
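// Note (not part of the patch): the indices pushed below must line up with the
// order used in Invoke() above, where ANeuralNetworksExecution_setInput is
// called with a running relative_input_index. Both loops skip kTfLiteMmapRo
// (constant) tensors, so NNAPI model inputs and execution inputs stay in sync.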
- for (int i : TfLiteIntArrayView(input_tensors)) - inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + for (int i : TfLiteIntArrayView(input_tensors)) { + // Constant tensors are not NNAPI inputs. + if (context->tensors[i].allocation_type != kTfLiteMmapRo) { + inputs.push_back(operand_mapping_.lite_index_to_ann(i)); + } + } for (int i : TfLiteIntArrayView(output_tensors)) outputs.push_back(operand_mapping_.lite_index_to_ann(i)); // Tell ANN to declare inputs/outputs @@ -392,7 +560,8 @@ TfLiteDelegate* NnApiDelegate() { .Prepare = [](TfLiteContext* context, TfLiteDelegate* delegate) -> TfLiteStatus { // Do not check nodes_ if NN API is unavailable. - if (!NNAPIExists()) return kTfLiteOk; + // NN API is only available since Android O-MR1 (API 27). + if (kAndroidSdkVersion < 27 || !NNAPIExists()) return kTfLiteOk; std::vector supported_nodes(1); // We don't care about all nodes_, we only care about ones in the @@ -400,6 +569,7 @@ TfLiteDelegate* NnApiDelegate() { TfLiteIntArray* plan; TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &plan)); int total_supported_nodes = 0; + // Check for every node if it is supported // TODO(b/80625235): Fix this to do more careful checking of versioning. for (int node_index : TfLiteIntArrayView(plan)) { @@ -408,7 +578,8 @@ TfLiteDelegate* NnApiDelegate() { TF_LITE_ENSURE_STATUS(context->GetNodeAndRegistration( context, node_index, &node, ®istration)); NNAPIDelegateKernel dummy_kernel; - if (dummy_kernel.Map(context, registration->builtin_code, node)) { + if (dummy_kernel.Map(context, registration->builtin_code, + registration->version, node)) { supported_nodes.push_back(node_index); } total_supported_nodes += 1; diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index ff2e721423..799e3efe0b 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -21,8 +21,12 @@ limitations under the License. namespace tflite { namespace { +using ::testing::ElementsAre; using ::testing::ElementsAreArray; +// TODO(b/110368244): figure out how to share the existing tests in kernels/ but +// with the delegation on. Also, add more unit tests to improve code coverage. 
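Every op model in the new tests below turns the delegate on the same way. A minimal sketch of that shared pattern (illustrative only; it relies on the SetApplyDelegate hook that the test models in this file use):

    // Attach the NNAPI delegate before the interpreter is invoked. On devices
    // below SDK 27, or without NNAPI, Prepare() delegates nothing and the
    // model silently falls back to the built-in CPU kernels, so the expected
    // outputs are unchanged.
    this->SetApplyDelegate([](Interpreter* interpreter) {
      interpreter->ModifyGraphWithDelegate(NnApiDelegate());
    });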
+ class FloatAddOpModel : public SingleOpModel { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, @@ -72,6 +76,535 @@ TEST(NNAPIDelegate, AddWithRelu) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({0.0, 0.4, 1.0, 1.3})); } +class FloatMulOpModel : public SingleOpModel { + public: + FloatMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, + ActivationFunctionType activation_type) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + input1_ = AddInput(input1); + input2_ = AddInput(input2); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, + CreateMulOptions(builder_, activation_type).Union()); + BuildInterpreter({GetShape(input1_), GetShape(input2_)}); + } + + int input1() { return input1_; } + int input2() { return input2_; } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input1_; + int input2_; + int output_; +}; + +TEST(NNAPIDelegate, MulWithNoActivation) { + FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); + m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); +} + +class FloatPoolingOpModel : public SingleOpModel { + public: + FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, + int filter_width, int filter_height, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + output_ = AddOutput(output); + + SetBuiltinOp( + type, BuiltinOptions_Pool2DOptions, + CreatePool2DOptions(builder_, Padding_VALID, 2, 2, filter_width, + filter_height, ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int output_; +}; + +TEST(NNAPIDelegate, AveragePoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_AVERAGE_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2.75, 5.75})); +} + +TEST(NNAPIDelegate, MaxPoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_MAX_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 10})); +} + +TEST(NNAPIDelegate, L2PoolWithNoActivation) { + FloatPoolingOpModel m(BuiltinOperator_L2_POOL_2D, + /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}, + /*filter_width=*/2, /*filter_height=*/2, + /*output=*/{TensorType_FLOAT32, {}}); + m.SetInput({ + 0, 6, 2, 4, // + 3, 2, 10, 7, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5})); +} + +class BaseConvolutionOpModel : public SingleOpModel { + public: + BaseConvolutionOpModel( + const TensorData& input, const TensorData& filter, + const TensorData& output, int 
stride_width = 2, int stride_height = 2, + enum Padding padding = Padding_VALID, + enum ActivationFunctionType activation = ActivationFunctionType_NONE, + int dilation_width_factor = 1, int dilation_height_factor = 1) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[0]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. + auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + if (input.type != TensorType_FLOAT32) { + // The following is required by quantized inference. It is the unittest's + // responsibility to make sure the output scale falls into the correct + // range. + CHECK_LT(GetScale(input_) * GetScale(filter_), GetScale(output_)); + } + + SetBuiltinOp(BuiltinOperator_CONV_2D, BuiltinOptions_Conv2DOptions, + CreateConv2DOptions( + builder_, padding, stride_width, stride_height, activation, + dilation_width_factor, dilation_height_factor) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +class ConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + std::vector GetOutput() { return ExtractVector(output_); } +}; + +class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { + public: + using BaseConvolutionOpModel::BaseConvolutionOpModel; + + void SetInput(std::initializer_list data) { + QuantizeAndPopulate(input_, data); + } + + void SetFilter(std::initializer_list data) { + QuantizeAndPopulate(filter_, data); + } + + void SetBias(std::initializer_list data) { + QuantizeAndPopulate(bias_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetDequantizedOutput() { + return Dequantize(ExtractVector(output_), + GetScale(output_), GetZeroPoint(output_)); + } +}; + +// In this tests we set the input and output scales so that the results +// match exactly the 'non-quantized' version. 
+TEST(NNAPIDelegate, SimpleTestQuantized) { + QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, + {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, + {TensorType_UINT8, {}, -127, 128}); + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + }, + 1e-5))); + // For good measure, let's also verify the quantized values: + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 145, 129, 132, // + 145, 129, 132, // + 144, 131, 130, // + 164, 131, 130, // + })); +} + +TEST(NNAPIDelegate, Conv2DWithNoActivation) { + ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, + {TensorType_FLOAT32, {3, 2, 2, 1}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + // First batch + 1, 1, 1, 1, // row = 1 + 2, 2, 2, 2, // row = 2 + // Second batch + 1, 2, 3, 4, // row = 1 + 1, 2, 3, 4, // row = 2 + }); + m.SetFilter({ + 1, 2, 3, 4, // first 2x2 filter + -1, 1, -1, 1, // second 2x2 filter + -1, -1, 1, 1, // third 2x2 filter + }); + m.SetBias({1, 2, 3}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 18, 2, 5, // first batch, left + 18, 2, 5, // first batch, right + 17, 4, 3, // second batch, left + 37, 4, 3, // second batch, right + })); +} + +class DepthwiseConvolutionOpModel : public SingleOpModel { + public: + DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, + const TensorData& output) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(input); + filter_ = AddInput(filter); + + int bias_size = GetShape(filter_)[3]; + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {bias_size}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. 
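// Note (not part of the patch): with the affine quantization scheme
// real_value = scale * (quantized_value - zero_point), every product in the
// convolution accumulates into int32 with scale input_scale * filter_scale,
// so the int32 bias must use exactly that scale and a zero point of 0. That is
// what the next two lines set up.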
+ auto bias_scale = GetScale(input_) * GetScale(filter_); + TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + int input_depth = GetShape(input_)[3]; + int output_depth = GetShape(filter_)[3]; + int depth_mul = output_depth / input_depth; + + SetBuiltinOp( + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOptions_DepthwiseConv2DOptions, + CreateDepthwiseConv2DOptions(builder_, Padding_VALID, 1, 1, depth_mul, + ActivationFunctionType_NONE) + .Union()); + + BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); + } + + void SetFilter(std::initializer_list f) { PopulateTensor(filter_, f); } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int filter_; + int bias_; + int output_; +}; + +TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { + DepthwiseConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 2, 2}}, + {TensorType_FLOAT32, {1, 2, 2, 4}}, + {TensorType_FLOAT32, {}}); + + m.SetInput({ + 1, 2, 7, 8, // column 1 + 3, 4, 9, 10, // column 2 + 5, 6, 11, 12, // column 3 + }); + m.SetFilter({ + 1, 2, 3, 4, // + -9, 10, -11, 12, // + 5, 6, 7, 8, // + 13, -14, 15, -16, // + }); + m.SetBias({1, 2, 3, 4}); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({ + 71, -34, 99, -20, // + 91, -26, 127, -4, // + })); +} + +class FloatFullyConnectedOpModel : public SingleOpModel { + public: + FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, + const TensorData& output = {TensorType_FLOAT32}) + : batches_(batches), units_(units) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + int total_input_size = 1; + for (int i = 0; i < input.shape.size(); ++i) { + total_input_size *= input.shape[i]; + } + input_size_ = total_input_size / batches_; + + input_ = AddInput(input); + weights_ = + AddInput({input.type, {units_, input_size_}, input.min, input.max}); + + if (input.type == TensorType_FLOAT32) { + bias_ = AddInput({TensorType_FLOAT32, {units_}}); + } else { + // This is a quantized version. The scale of 'bias' depends on the scales + // of input and filter. Supposedly this is correctly set during quantized + // training. 
+ auto bias_scale = GetScale(input_) * GetScale(weights_); + TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale}; + bias_ = AddInput(bias); + } + + output_ = AddOutput(output); + + SetBuiltinOp( + BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions, + CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)}); + } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + protected: + int input_; + int weights_; + int bias_; + int output_; + + int batches_; + int units_; + int input_size_; +}; + +TEST(NNAPIDelegate, FullyConnectedSimpleTest) { + FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2, + /*input=*/{TensorType_FLOAT32, {2, 10}}); + m.SetWeights({ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1 + }); + m.SetBias({1, 2, 3}); + + m.SetInput({ + 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0 + 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); +} + +class SoftmaxOpModel : public SingleOpModel { + public: + SoftmaxOpModel(int batches, int size, float beta) + : batches_(batches), input_size_(size), beta_(beta) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, + CreateSoftmaxOptions(builder_, beta_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; + float beta_; +}; + +TEST(NNAPIDelegate, SoftmaxSimpleTest) { + SoftmaxOpModel m(/*batches=*/2, /*size=*/5, /*beta=*/1.0); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 0 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {0.011656231, 0.031684921, 0.086128544, 0.234121657, 0.636408647, + 0.636408647, 0.234121657, 0.086128544, 0.031684921, 0.011656231}, + 1e-6))); +} + +class ReshapeOpModel : public SingleOpModel { + public: + ReshapeOpModel(std::initializer_list input_shape, + std::initializer_list new_shape) { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate()); + }); + + input_ = AddInput(TensorType_FLOAT32); + new_shape_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_RESHAPE, BuiltinOptions_ReshapeOptions, + CreateReshapeOptions(builder_, builder_.CreateVector(new_shape)) + .Union()); + BuildInterpreter({input_shape, 
{static_cast<int>(new_shape.size())}}); + PopulateTensor(new_shape_, new_shape); + } + + void SetInput(std::initializer_list<float> data) { + PopulateTensor(input_, data); + } + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int new_shape_; + int output_; +}; + +TEST(NNAPIDelegate, ReshapeSimpleTest) { + ReshapeOpModel m({1, 2, 4, 1}, {2, 2, 2}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); +} + } // namespace } // namespace tflite -- cgit v1.2.3
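For context, a minimal sketch of how an application would pick up this delegate, mirroring the test setup above (the header path and the model file name are assumptions based on the file locations in this patch, and error checking is omitted):

    #include <memory>

    #include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
    #include "tensorflow/contrib/lite/interpreter.h"
    #include "tensorflow/contrib/lite/kernels/register.h"
    #include "tensorflow/contrib/lite/model.h"

    // Hand every supported node to NNAPI; unsupported nodes, or devices below
    // SDK 27, keep running on the built-in CPU kernels.
    int RunWithNnApi() {
      auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
      tflite::ops::builtin::BuiltinOpResolver resolver;
      std::unique_ptr<tflite::Interpreter> interpreter;
      tflite::InterpreterBuilder(*model, resolver)(&interpreter);
      interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate());
      interpreter->AllocateTensors();
      // Fill input tensors here, then run inference.
      interpreter->Invoke();
      return 0;
    }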