diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-11 04:33:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 04:36:38 -0700 |
commit | 56646a1f5e6773c6637b2477670fcbc4385cf21b (patch) | |
tree | cb58ee18f12b00b41a7ec8338007113e92162f90 /tensorflow/contrib/lite/nnapi_delegate.cc | |
parent | 20b3d4d297318874fd9b94b6bbeb3f90064ca9d4 (diff) |
Add NNAPI 1.1 Div/Mul/Pad/Mean nodes.
PiperOrigin-RevId: 196240584
Diffstat (limited to 'tensorflow/contrib/lite/nnapi_delegate.cc')
-rw-r--r-- | tensorflow/contrib/lite/nnapi_delegate.cc | 63 |
1 files changed, 58 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 1810dfae32..d99c88a26d 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -23,6 +23,10 @@ limitations under the License. #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h" +#ifdef __ANDROID__ +#include <sys/system_properties.h> +#endif + namespace tflite { // TODO(aselle): FATAL leaves resources hanging. @@ -46,6 +50,32 @@ void FATAL(const char* format, ...) { FATAL("Aborting since tflite returned failure."); \ } +namespace { + +int32_t GetAndroidSdkVersion() { +#ifdef __ANDROID__ + const char* sdkProp = "ro.build.version.sdk"; + char sdkVersion[PROP_VALUE_MAX]; + int length = __system_property_get(sdkProp, sdkVersion); + if (length != 0) { + for (int i = 0; i < length; ++i) { + int digit = sdkVersion[i] - '0'; + if (digit < 0 || digit > 9) { + // Non-numeric SDK version, assume it's higher then expected; + return 0xFFFF; + } + } + return atoi(sdkVersion); + } + FATAL("No %s prop", sdkProp); +#endif // __ANDROID__ + return 0; +} + +static const int32_t kAndroidSdkVersion = GetAndroidSdkVersion(); + +} // namespace + NNAPIAllocation::NNAPIAllocation(const char* filename, ErrorReporter* error_reporter) : MMAPAllocation(filename, error_reporter) { @@ -245,6 +275,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_float32(builtin->proj_clip); }; + auto add_mean_params = [&add_scalar_int32](void* data) { + auto builtin = reinterpret_cast<TfLiteMeanParams*>(data); + add_scalar_int32(builtin->keep_dims); + }; + #if 0 auto add_reshape_params = [&](void* data) { auto builtin = reinterpret_cast<TfLiteReshapeParams*>(data); @@ -262,8 +297,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; #endif - + int nnapi_version = 10; ANeuralNetworksOperationType nn_op_type; + switch (builtin) { case tflite::BuiltinOperator_ADD: nn_op_type = ANEURALNETWORKS_ADD; @@ -337,6 +373,23 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nn_op_type = ANEURALNETWORKS_LSTM; break; } + case tflite::BuiltinOperator_PAD: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_PAD; + break; + case tflite::BuiltinOperator_MEAN: + nnapi_version = 11; // require NNAPI 1.1 + add_mean_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_MEAN; + break; + case tflite::BuiltinOperator_DIV: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_DIV; + break; + case tflite::BuiltinOperator_SUB: + nnapi_version = 11; // require NNAPI 1.1 + nn_op_type = ANEURALNETWORKS_SUB; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: case tflite::BuiltinOperator_SVDF: @@ -350,7 +403,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: - case tflite::BuiltinOperator_PAD: case tflite::BuiltinOperator_PADV2: case tflite::BuiltinOperator_RESIZE_BILINEAR: case tflite::BuiltinOperator_CALL: @@ -361,9 +413,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: - case tflite::BuiltinOperator_MEAN: - case tflite::BuiltinOperator_DIV: - case tflite::BuiltinOperator_SUB: case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: @@ -393,6 +442,10 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, break; } + if (nnapi_version == 11 && kAndroidSdkVersion < 28) { + FATAL("Op %d needs NNAPI1.1", builtin); + } + // Add the operation. CHECK_NN(ANeuralNetworksModel_addOperation( nn_model, nn_op_type, static_cast<uint32_t>(augmented_inputs.size()), |