diff options
author | 2018-07-09 15:51:18 -0700 | |
---|---|---|
committer | 2018-07-09 15:54:45 -0700 | |
commit | bc2674d09efbd87ae81ae41b81f1d152f37fac2a (patch) | |
tree | 03682ab97190210cb3426d0957fe1c1e76ad7648 /tensorflow/contrib/lite/delegates | |
parent | 3eb314bbcdc6e850eefc21c7ecbd91e28bd04a70 (diff) |
Add NNAPI transpose op support
PiperOrigin-RevId: 203845358
Diffstat (limited to 'tensorflow/contrib/lite/delegates')
-rw-r--r-- | tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc | 16 | ||||
-rw-r--r-- | tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc | 94 |
2 files changed, 69 insertions, 41 deletions
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc index fd798c209e..f0d16575ec 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc @@ -452,6 +452,22 @@ class NNAPIDelegateKernel { } else { return nullptr; } + case kTfLiteBuiltinTranspose: + // Transpose requires NNAPI1.1. Also note that the permutation input + // tensor value dictates the output dimensions. + // TODO(b/110888333): Support dynamically-sized tensors in delegates. + if ((version == 1) && + (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) && + (node->inputs->size > 1) && + (context->tensors[node->inputs->data[1]].allocation_type == + kTfLiteMmapRo)) { + return [](TfLiteContext* context, NNAPIOpBuilder* builder, + TfLiteNode* node) -> ANeuralNetworksOperationType { + return ANEURALNETWORKS_TRANSPOSE; + }; + } else { + return nullptr; + } break; default: return nullptr; diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc index aad10c9ce7..ab2181e8ff 100644 --- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc +++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc @@ -27,14 +27,20 @@ using ::testing::ElementsAreArray; // TODO(b/110368244): figure out how to share the existing tests in kernels/ but // with the delegation on. Also, add more unit tests to improve code coverage. -class FloatAddOpModel : public SingleOpModel { +class SingleOpModelWithNNAPI : public SingleOpModel { + public: + SingleOpModelWithNNAPI() { + this->SetApplyDelegate([](Interpreter* interpreter) { + interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false); + }); + } +}; + +class FloatAddOpModel : public SingleOpModelWithNNAPI { public: FloatAddOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, ActivationFunctionType activation_type) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); @@ -81,9 +87,6 @@ class FloatMulOpModel : public SingleOpModel { FloatMulOpModel(const TensorData& input1, const TensorData& input2, const TensorData& output, ActivationFunctionType activation_type) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); input1_ = AddInput(input1); input2_ = AddInput(input2); output_ = AddOutput(output); @@ -114,15 +117,11 @@ TEST(NNAPIDelegate, MulWithNoActivation) { ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); } -class FloatPoolingOpModel : public SingleOpModel { +class FloatPoolingOpModel : public SingleOpModelWithNNAPI { public: FloatPoolingOpModel(BuiltinOperator type, const TensorData& input, int filter_width, int filter_height, const TensorData& output) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); output_ = AddOutput(output); @@ -193,10 +192,6 @@ class BaseConvolutionOpModel : public SingleOpModel { enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE, int dilation_width_factor = 1, int dilation_height_factor = 1) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); filter_ = AddInput(filter); @@ -344,14 +339,10 @@ TEST(NNAPIDelegate, Conv2DWithNoActivation) { })); } -class DepthwiseConvolutionOpModel : public SingleOpModel { +class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI { public: DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter, const TensorData& output) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); filter_ = AddInput(filter); @@ -426,15 +417,11 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) { })); } -class FloatFullyConnectedOpModel : public SingleOpModel { +class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI { public: FloatFullyConnectedOpModel(int units, int batches, const TensorData& input, const TensorData& output = {TensorType_FLOAT32}) : batches_(batches), units_(units) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - int total_input_size = 1; for (int i = 0; i < input.shape.size(); ++i) { total_input_size *= input.shape[i]; @@ -515,14 +502,10 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) { EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60)); } -class SoftmaxOpModel : public SingleOpModel { +class SoftmaxOpModel : public SingleOpModelWithNNAPI { public: SoftmaxOpModel(int batches, int size, float beta) : batches_(batches), input_size_(size), beta_(beta) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions, @@ -566,14 +549,10 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) { 1e-6))); } -class ReshapeOpModel : public SingleOpModel { +class ReshapeOpModel : public SingleOpModelWithNNAPI { public: ReshapeOpModel(std::initializer_list<int> input_shape, std::initializer_list<int> new_shape) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(TensorType_FLOAT32); new_shape_ = AddInput(TensorType_INT32); output_ = AddOutput(TensorType_FLOAT32); @@ -605,14 +584,10 @@ TEST(NNAPIDelegate, ReshapeSimpleTest) { EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2})); } -class SqueezeOpModel : public SingleOpModel { +class SqueezeOpModel : public SingleOpModelWithNNAPI { public: SqueezeOpModel(const TensorData& input, const TensorData& output, std::initializer_list<int> axis) { - this->SetApplyDelegate([](Interpreter* interpreter) { - interpreter->ModifyGraphWithDelegate(NnApiDelegate()); - }); - input_ = AddInput(input); output_ = AddOutput(output); SetBuiltinOp( @@ -666,6 +641,43 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) { 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0})); } +class TransposeSimpleModel : public SingleOpModelWithNNAPI { + public: + TransposeSimpleModel(std::initializer_list<int> input_shape, + std::initializer_list<int> perm_shape, + std::initializer_list<int> perm) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddConstInput(TensorType_INT32, perm, perm_shape); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape, perm_shape}); + } + + void SetInput(std::initializer_list<float> data) { + PopulateTensor<float>(input_, data); + } + + std::vector<float> GetOutput() { return ExtractVector<float>(output_); } + std::vector<int> GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int perm_; + int output_; +}; + +TEST(NNAPIDelegate, TransposeSimpleTest) { + TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1}); + m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23})); +} + } // namespace } // namespace tflite |