aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/nnapi_delegate.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-11 04:33:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 04:36:38 -0700
commit56646a1f5e6773c6637b2477670fcbc4385cf21b (patch)
treecb58ee18f12b00b41a7ec8338007113e92162f90 /tensorflow/contrib/lite/nnapi_delegate.cc
parent20b3d4d297318874fd9b94b6bbeb3f90064ca9d4 (diff)
Add NNAPI 1.1 Div/Mul/Pad/Mean nodes.
PiperOrigin-RevId: 196240584
Diffstat (limited to 'tensorflow/contrib/lite/nnapi_delegate.cc')
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc63
1 files changed, 58 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 1810dfae32..d99c88a26d 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -23,6 +23,10 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h"
+#ifdef __ANDROID__
+#include <sys/system_properties.h>
+#endif
+
namespace tflite {
// TODO(aselle): FATAL leaves resources hanging.
@@ -46,6 +50,32 @@ void FATAL(const char* format, ...) {
FATAL("Aborting since tflite returned failure."); \
}
+namespace {
+
+int32_t GetAndroidSdkVersion() {
+#ifdef __ANDROID__
+ const char* sdkProp = "ro.build.version.sdk";
+ char sdkVersion[PROP_VALUE_MAX];
+ int length = __system_property_get(sdkProp, sdkVersion);
+ if (length != 0) {
+ for (int i = 0; i < length; ++i) {
+ int digit = sdkVersion[i] - '0';
+ if (digit < 0 || digit > 9) {
+ // Non-numeric SDK version, assume it's higher then expected;
+ return 0xFFFF;
+ }
+ }
+ return atoi(sdkVersion);
+ }
+ FATAL("No %s prop", sdkProp);
+#endif // __ANDROID__
+ return 0;
+}
+
+static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion();
+
+} // namespace
+
NNAPIAllocation::NNAPIAllocation(const char* filename,
ErrorReporter* error_reporter)
: MMAPAllocation(filename, error_reporter) {
@@ -245,6 +275,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_scalar_float32(builtin->proj_clip);
};
+ auto add_mean_params = [&add_scalar_int32](void* data) {
+ auto builtin = reinterpret_cast<TfLiteMeanParams*>(data);
+ add_scalar_int32(builtin->keep_dims);
+ };
+
#if 0
auto add_reshape_params = [&](void* data) {
auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data);
@@ -262,8 +297,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
augmented_inputs.push_back(next_id++);
};
#endif
-
+ int nnapi_version = 10;
ANeuralNetworksOperationType nn_op_type;
+
switch (builtin) {
case tflite::BuiltinOperator_ADD:
nn_op_type = ANEURALNETWORKS_ADD;
@@ -337,6 +373,23 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
nn_op_type = ANEURALNETWORKS_LSTM;
break;
}
+ case tflite::BuiltinOperator_PAD:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_PAD;
+ break;
+ case tflite::BuiltinOperator_MEAN:
+ nnapi_version = 11; // require NNAPI 1.1
+ add_mean_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_MEAN;
+ break;
+ case tflite::BuiltinOperator_DIV:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_DIV;
+ break;
+ case tflite::BuiltinOperator_SUB:
+ nnapi_version = 11; // require NNAPI 1.1
+ nn_op_type = ANEURALNETWORKS_SUB;
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_SVDF:
@@ -350,7 +403,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case tflite::BuiltinOperator_L2_NORMALIZATION:
case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION:
- case tflite::BuiltinOperator_PAD:
case tflite::BuiltinOperator_PADV2:
case tflite::BuiltinOperator_RESIZE_BILINEAR:
case tflite::BuiltinOperator_CALL:
@@ -361,9 +413,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
case tflite::BuiltinOperator_TOPK_V2:
case tflite::BuiltinOperator_TRANSPOSE:
- case tflite::BuiltinOperator_MEAN:
- case tflite::BuiltinOperator_DIV:
- case tflite::BuiltinOperator_SUB:
case tflite::BuiltinOperator_SPLIT:
case tflite::BuiltinOperator_SQUEEZE:
case tflite::BuiltinOperator_STRIDED_SLICE:
@@ -393,6 +442,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
break;
}
+ if (nnapi_version == 11 && kAndroidSdkVersion < 28) {
+ FATAL("Op %d needs NNAPI1.1", builtin);
+ }
+
// Add the operation.
CHECK_NN(ANeuralNetworksModel_addOperation(
nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()),