aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-23 08:03:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 08:06:09 -0700
commit32fe0302b0cf02d6cc3ae6cf67b233ad65c74bfe (patch)
treedf53fee69a2c7755c5aedb6d9cbdcb06347f9eb1 /tensorflow
parent8647db865ce41361413a2eb4c3b4d0ba404dd4e0 (diff)
Delegate L2Norm to nnapi.
PiperOrigin-RevId: 205661557
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc12
-rw-r--r--tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc35
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc9
3 files changed, 55 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
index f0d16575ec..0c7f6d3125 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate.cc
@@ -452,6 +452,18 @@ class NNAPIDelegateKernel {
} else {
return nullptr;
}
+ case kTfLiteBuiltinL2Normalization: {
+ auto builtin =
+ reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
+ if (builtin->activation != kTfLiteActNone) {
+ // NNAPI does not support activations
+ return nullptr;
+ }
+ return [](TfLiteContext* context, NNAPIOpBuilder* builder,
+ TfLiteNode* node) -> ANeuralNetworksOperationType {
+ return ANEURALNETWORKS_L2_NORMALIZATION;
+ };
+ }
case kTfLiteBuiltinTranspose:
// Transpose requires NNAPI1.1. Also note that the permutation input
// tensor value dictates the output dimensions.
diff --git a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
index ab2181e8ff..baf8046f9b 100644
--- a/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/contrib/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -641,6 +641,41 @@ TEST(NNAPIDelegate, SqueezeWithAxisTest) {
17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
}
+class L2NormOpModel : public SingleOpModelWithNNAPI {
+ public:
+ L2NormOpModel(const TensorData& input, const TensorData& output,
+ ActivationFunctionType activation_type) {
+ input_ = AddInput(input);
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
+ CreateL2NormOptions(builder_, activation_type).Union());
+ BuildInterpreter({GetShape(input_)});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int new_shape_;
+ int output_;
+};
+
+TEST(NNAPIDelegate, L2NormSimpleTest) {
+ std::initializer_list<float> data = {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1};
+ L2NormOpModel m({TensorType_FLOAT32, {1, 1, 1, 6}},
+ {TensorType_FLOAT32, {1, 1, 1, 6}},
+ ActivationFunctionType_NONE);
+ m.SetInput(data);
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 1, 1, 6}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
+}
+
class TransposeSimpleModel : public SingleOpModelWithNNAPI {
public:
TransposeSimpleModel(std::initializer_list<int> input_shape,
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 710ce1632e..659230e033 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -560,6 +560,14 @@ TfLiteStatus AddOpsAndParams(
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_TRANSPOSE;
break;
+ case tflite::BuiltinOperator_L2_NORMALIZATION:
+ nn_op_type = ANEURALNETWORKS_L2_NORMALIZATION;
+ if (reinterpret_cast<TfLiteL2NormParams*>(node.builtin_data)
+ ->activation != kTfLiteActNone) {
+ FATAL(
+ "NNAPI does not support L2Normalization with fused activations");
+ }
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
@@ -568,7 +576,6 @@ TfLiteStatus AddOpsAndParams(
case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE:
case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
- case tflite::BuiltinOperator_L2_NORMALIZATION:
case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
case tflite::BuiltinOperator_PADV2:
case tflite::BuiltinOperator_RESIZE_BILINEAR: