aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/nnapi_delegate.cc
diff options
context:
space:
mode:
authorGravatar Jared Duke <jdduke@google.com>2018-06-27 09:40:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-27 09:44:48 -0700
commit80bc59b99bca7f9bc167975bab1c295bc4793c9a (patch)
tree04fac0d6d231cfc19de24b3370909212b0170471 /tensorflow/contrib/lite/nnapi_delegate.cc
parent67cd3f7e5b63c69e447421587fe86f771e700448 (diff)
Add NNAPI squeeze op support
PiperOrigin-RevId: 202323072
Diffstat (limited to 'tensorflow/contrib/lite/nnapi_delegate.cc')
-rw-r--r--tensorflow/contrib/lite/nnapi_delegate.cc25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc
index 748c2f1a04..7627d89c09 100644
--- a/tensorflow/contrib/lite/nnapi_delegate.cc
+++ b/tensorflow/contrib/lite/nnapi_delegate.cc
@@ -215,6 +215,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
augmented_inputs.push_back(next_id++);
};
+ auto add_vector_int32 = [&](const int* values, uint32_t num_values) {
+ ANeuralNetworksOperandType operand_type{
+ .type = ANEURALNETWORKS_TENSOR_INT32,
+ .dimensionCount = 1,
+ .dimensions = &num_values};
+ CHECK_NN(ANeuralNetworksModel_addOperand(nn_model, &operand_type))
+ CHECK_NN(ANeuralNetworksModel_setOperandValue(
+ nn_model, next_id, values, sizeof(int32_t) * num_values));
+ augmented_inputs.push_back(next_id++);
+ };
+
// Handle state tensors of RNN, LSTM, SVDF.
// For each state_out tensor, a corresponding state_in operand needs to be
// created for NNAPI.
@@ -327,6 +338,14 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
add_scalar_int32(builtin->activation);
};
+ auto add_squeeze_params = [&](void* data) {
+ const auto* builtin = reinterpret_cast<TfLiteSqueezeParams*>(data);
+ // Note that we add the squeeze dimensions even if the dimensions were
+ // unspecified (empty), as NNAPI requires the operand.
+ add_vector_int32(builtin->squeeze_dims,
+ static_cast<uint32_t>(builtin->num_squeeze_dims));
+ };
+
// Handle optional input tensors.
auto add_optional_tensors = [&nn_model, &augmented_inputs,
&next_id](int nn_type) {
@@ -453,6 +472,11 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
nnapi_version = 11; // require NNAPI 1.1
nn_op_type = ANEURALNETWORKS_SUB;
break;
+ case tflite::BuiltinOperator_SQUEEZE:
+ nnapi_version = 11; // requires NNAPI 1.1
+ add_squeeze_params(node.builtin_data);
+ nn_op_type = ANEURALNETWORKS_SQUEEZE;
+ break;
case tflite::BuiltinOperator_CONCAT_EMBEDDINGS:
case tflite::BuiltinOperator_LSH_PROJECTION:
case tflite::BuiltinOperator_HASHTABLE_LOOKUP:
@@ -474,7 +498,6 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
case tflite::BuiltinOperator_TOPK_V2:
case tflite::BuiltinOperator_TRANSPOSE:
case tflite::BuiltinOperator_SPLIT:
- case tflite::BuiltinOperator_SQUEEZE:
case tflite::BuiltinOperator_STRIDED_SLICE:
case tflite::BuiltinOperator_EXP:
case tflite::BuiltinOperator_LOG_SOFTMAX: