diff options
author | 2018-06-27 09:40:26 -0700 | |
---|---|---|
committer | 2018-06-27 09:44:48 -0700 | |
commit | 80bc59b99bca7f9bc167975bab1c295bc4793c9a (patch) | |
tree | 04fac0d6d231cfc19de24b3370909212b0170471 /tensorflow/contrib/lite/nnapi_delegate.cc | |
parent | 67cd3f7e5b63c69e447421587fe86f771e700448 (diff) |
Add NNAPI squeeze op support
PiperOrigin-RevId: 202323072
Diffstat (limited to 'tensorflow/contrib/lite/nnapi_delegate.cc')
-rw-r--r-- | tensorflow/contrib/lite/nnapi_delegate.cc | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 748c2f1a04..7627d89c09 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -215,6 +215,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, augmented_inputs.push_back(next_id++); }; + auto add_vector_int32 = [&](const int* values, uint32_t num_values) { + ANeuralNetworksOperandType operand_type{ + .type = ANEURALNETWORKS_TENSOR_INT32, + .dimensionCount = 1, + .dimensions = &num_values}; + CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type)) + CHECK_NN(ANeuralNetworksModel_setOperandValue( + nn_model, next_id, values, sizeof(int32_t) * num_values)); + augmented_inputs.push_back(next_id++); + }; + // Handle state tensors of RNN, LSTM, SVDF. // For each state_out tensor, a corresponding state_in operand needs to be // created for NNAPI. @@ -327,6 +338,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, add_scalar_int32(builtin->activation); }; + auto add_squeeze_params = [&](void* data) { + const auto* builtin = reinterpret_cast<TfLiteSqueezeParams*>(data); + // Note that we add the squeeze dimensions even if the dimensions were + // unspecified (empty), as NNAPI requires the operand. + add_vector_int32(builtin->squeeze_dims, + static_cast<uint32_t>(builtin->num_squeeze_dims)); + }; + // Handle optional input tensors. auto add_optional_tensors = [&nn_model, &augmented_inputs, &next_id](int nn_type) { @@ -453,6 +472,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, nnapi_version = 11; // require NNAPI 1.1 nn_op_type = ANEURALNETWORKS_SUB; break; + case tflite::BuiltinOperator_SQUEEZE: + nnapi_version = 11; // requires NNAPI 1.1 + add_squeeze_params(node.builtin_data); + nn_op_type = ANEURALNETWORKS_SQUEEZE; + break; case tflite::BuiltinOperator_CONCAT_EMBEDDINGS: case tflite::BuiltinOperator_LSH_PROJECTION: case tflite::BuiltinOperator_HASHTABLE_LOOKUP: @@ -474,7 +498,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_SPLIT: - case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: |