author     Jared Duke <jdduke@google.com>                   2018-07-09 15:51:18 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-09 15:54:45 -0700
commit     bc2674d09efbd87ae81ae41b81f1d152f37fac2a (patch)
tree       03682ab97190210cb3426d0957fe1c1e76ad7648 /tensorflow/contrib/lite/delegates
parent     3eb314bbcdc6e850eefc21c7ecbd91e28bd04a70 (diff)
Add NNAPI transpose op support
PiperOrigin-RevId: 203845358
Diffstat (limited to 'tensorflow/contrib/lite/delegates')
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc       16
-rw-r--r--  tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc  94
2 files changed, 69 insertions(+), 41 deletions(-)
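
For context, a minimal sketch of how the NNAPI delegate touched by this change is applied to an interpreter, mirroring the SetApplyDelegate lambda the test diff below introduces. The wrapper function name is illustrative; NnApiDelegate() and ModifyGraphWithDelegate() are the calls used in the tests, and the second argument matches the false passed there.

#include "tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.h"
#include "tensorflow/contrib/lite/interpreter.h"

// Illustrative helper (not part of this change): attach the NNAPI delegate so
// that supported ops -- now including TRANSPOSE on NNAPI 1.1+ devices -- are
// offloaded, while unsupported ops keep running on the TFLite CPU kernels.
TfLiteStatus ApplyNnApiDelegate(tflite::Interpreter* interpreter) {
  return interpreter->ModifyGraphWithDelegate(tflite::NnApiDelegate(), false);
}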
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index fd798c209e..f0d16575ec 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -452,6 +452,22 @@ class NNAPIDelegateKernel {
} else {
return nullptr;
}
+ case kTfLiteBuiltinTranspose:
+ // Transpose requires NNAPI1.1. Also note that the permutation input
+ // tensor value dictates the output dimensions.
+ // TODO(b/110888333): Support dynamically-sized tensors in delegates.
+ if ((version == 1) &&
+ (kAndroidSdkVersion >= kMinSdkVersionForNNAPI11) &&
+ (node->inputs->size > 1) &&
+ (context->tensors[node->inputs->data[1]].allocation_type ==
+ kTfLiteMmapRo)) {
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_TRANSPOSE;
+ };
+ } else {
+ return nullptr;
+ }
break;
default:
return nullptr;
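
The comment in the hunk above notes that the permutation input dictates the output dimensions, which is why the delegate only accepts a TRANSPOSE node whose permutation tensor is constant (kTfLiteMmapRo): with a non-constant permutation the output shape would only be known at execution time, which the delegate does not yet support (see the TODO on dynamically-sized tensors). A hedged sketch of that shape relationship, with an illustrative helper name that is not part of the change:

#include <cstddef>
#include <vector>

// Output dimension i of a transpose is input dimension perm[i].
std::vector<int> TransposedShape(const std::vector<int>& input_shape,
                                 const std::vector<int>& perm) {
  std::vector<int> output_shape(perm.size());
  for (std::size_t i = 0; i < perm.size(); ++i) {
    output_shape[i] = input_shape[perm[i]];
  }
  return output_shape;
}
// Example matching the test added below: input shape {2, 3, 4} with
// perm {2, 0, 1} yields {4, 2, 3}.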
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index aad10c9ce7..ab2181e8ff 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -27,14 +27,20 @@ using ::testing::ElementsAreArray;
// TODO(b/110368244): figure out how to share the existing tests in kernels/ but
// with the delegation on. Also, add more unit tests to improve code coverage.
-class FloatAddOpModel : public SingleOpModel {
+class SingleOpModelWithNNAPI : public SingleOpModel {
+ public:
+ SingleOpModelWithNNAPI() {
+ this->SetApplyDelegate([](Interpreter* interpreter) {
+ interpreter->ModifyGraphWithDelegate(NnApiDelegate(), false);
+ });
+ }
+};
+
+class FloatAddOpModel : public SingleOpModelWithNNAPI {
public:
FloatAddOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -81,9 +87,6 @@ class FloatMulOpModel : public SingleOpModel {
FloatMulOpModel(const TensorData& input1, const TensorData& input2,
const TensorData& output,
ActivationFunctionType activation_type) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
input1_ = AddInput(input1);
input2_ = AddInput(input2);
output_ = AddOutput(output);
@@ -114,15 +117,11 @@ TEST(NNAPIDelegate, MulWithNoActivation) {
ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4})));
}
-class FloatPoolingOpModel : public SingleOpModel {
+class FloatPoolingOpModel : public SingleOpModelWithNNAPI {
public:
FloatPoolingOpModel(BuiltinOperator type, const TensorData& input,
int filter_width, int filter_height,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
@@ -193,10 +192,6 @@ class BaseConvolutionOpModel : public SingleOpModel {
enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
int dilation_width_factor = 1, int dilation_height_factor = 1) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -344,14 +339,10 @@ TEST(NNAPIDelegate, Conv2DWithNoActivation) {
}));
}
-class DepthwiseConvolutionOpModel : public SingleOpModel {
+class DepthwiseConvolutionOpModel : public SingleOpModelWithNNAPI {
public:
DepthwiseConvolutionOpModel(const TensorData& input, const TensorData& filter,
const TensorData& output) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -426,15 +417,11 @@ TEST(NNAPIDelegate, DepthwiseConv2DWithNoActivation) {
}));
}
-class FloatFullyConnectedOpModel : public SingleOpModel {
+class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI {
public:
FloatFullyConnectedOpModel(int units, int batches, const TensorData& input,
const TensorData& output = {TensorType_FLOAT32})
: batches_(batches), units_(units) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
int total_input_size = 1;
for (int i = 0; i < input.shape.size(); ++i) {
total_input_size *= input.shape[i];
@@ -515,14 +502,10 @@ TEST(NNAPIDelegate, FullyConnectedSimpleTest) {
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
-class SoftmaxOpModel : public SingleOpModel {
+class SoftmaxOpModel : public SingleOpModelWithNNAPI {
public:
SoftmaxOpModel(int batches, int size, float beta)
: batches_(batches), input_size_(size), beta_(beta) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
SetBuiltinOp(BuiltinOperator_SOFTMAX, BuiltinOptions_SoftmaxOptions,
@@ -566,14 +549,10 @@ TEST(NNAPIDelegate, SoftmaxSimpleTest) {
1e-6)));
}
-class ReshapeOpModel : public SingleOpModel {
+class ReshapeOpModel : public SingleOpModelWithNNAPI {
public:
ReshapeOpModel(std::initializer_list<int> input_shape,
std::initializer_list<int> new_shape) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(TensorType_FLOAT32);
new_shape_ = AddInput(TensorType_INT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -605,14 +584,10 @@ TEST(NNAPIDelegate, ReshapeSimpleTest) {
EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2, 2}));
}
-class SqueezeOpModel : public SingleOpModel {
+class SqueezeOpModel : public SingleOpModelWithNNAPI {
public:
SqueezeOpModel(const TensorData& input, const TensorData& output,
std::initializer_list<int> axis) {
- this->SetApplyDelegate([](Interpreter* interpreter) {
- interpreter->ModifyGraphWithDelegate(NnApiDelegate());
- });
-
input_ = AddInput(input);
output_ = AddOutput(output);
SetBuiltinOp(
@@ -666,6 +641,43 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+class TransposeSimpleModel : public SingleOpModelWithNNAPI {
+ public:
+ TransposeSimpleModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> perm_shape,
+ std::initializer_list<int> perm) {
+ input_ = AddInput(TensorType_FLOAT32);
+ perm_ = AddConstInput(TensorType_INT32, perm, perm_shape);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+ CreateTransposeOptions(builder_).Union());
+ BuildInterpreter({input_shape, perm_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int perm_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, TransposeSimpleTest) {
+ TransposeSimpleModel m({2, 3, 4}, {3}, {2, 0, 1});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+ 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
} // namespace
} // namespace tflite