| author | 2018-06-28 09:57:48 -0700 |
| --- | --- |
| committer | 2018-06-28 21:37:43 -0700 |
| commit | 6c89b5f07ea8ab73c29b8dde8fbdbd8289ade3c6 (patch) |
| tree | 6bde3c9fa95a0250d7cc8b44d5d0f053d5c7ad43 /tensorflow/contrib/lite/kernels/fully_connected.cc |
| parent | cfbcc61eec4d703d34c758cd45469de22fad9740 (diff) |
More un-fused quantized LSTM support in TFLite interpreter
PiperOrigin-RevId: 202496488
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
| -rw-r--r-- | tensorflow/contrib/lite/kernels/fully_connected.cc | 49 |
| --- | --- | --- |

1 file changed, 34 insertions(+), 15 deletions(-)
```diff
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index cc4ae6ec6e..3b203dd480 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -283,30 +283,49 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
   int32_t input_offset = -input->params.zero_point;
   int32_t filter_offset = -filter->params.zero_point;
   int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type)                                       \
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type)                     \
   type::FullyConnected(                                                     \
       GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset,    \
       GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
       GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset,     \
       data->output_multiplier, data->output_shift,                          \
       data->output_activation_min, data->output_activation_max,             \
-      GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+      GetTensorData<output_data_type>(output), GetTensorDims(output),       \
+      gemm_context)
 
   if (kernel_type == kReference) {
-    TF_LITE_FULLY_CONNECTED(reference_ops);
-  } else if (kernel_type == kPie) {
-    if (input->type == kTfLiteFloat32) {
-      // Pie currently only supports quantized models and float inputs/outputs.
-      TfLiteTensor* input_quantized =
-          &context->tensors[node->temporaries->data[0]];
-      return EvalPieQuantized(context, node, params, data, input, filter, bias,
-                              input_quantized, output);
-    } else {
-      // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
-      // we just defer to the MINI ones.
-      TF_LITE_FULLY_CONNECTED(optimized_ops);
+    switch (output->type) {
+      case kTfLiteUInt8:
+        TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
+        break;
+      case kTfLiteInt16:
+        TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
+        break;
+      default:
+        context->ReportError(
+            context,
+            "Quantized FullyConnected expects output data type uint8 or int16");
+        return kTfLiteError;
     }
+  } else if (kernel_type == kPie && input->type == kTfLiteFloat32) {
+    // Pie currently only supports quantized models and float inputs/outputs.
+    TfLiteTensor* input_quantized =
+        &context->tensors[node->temporaries->data[0]];
+    return EvalPieQuantized(context, node, params, data, input, filter, bias,
+                            input_quantized, output);
   } else {
-    TF_LITE_FULLY_CONNECTED(optimized_ops);
+    switch (output->type) {
+      case kTfLiteUInt8:
+        TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
+        break;
+      case kTfLiteInt16:
+        TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
+        break;
+      default:
+        context->ReportError(
+            context,
+            "Quantized FullyConnected expects output data type uint8 or int16");
+        return kTfLiteError;
+    }
   }
 #undef TF_LITE_FULLY_CONNECTED
```
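The change threads the output tensor's data type through the `TF_LITE_FULLY_CONNECTED` macro and dispatches on `output->type`, so the quantized FullyConnected kernel can write int16 as well as uint8 output, presumably for the 16-bit activations used along the un-fused quantized LSTM path. As a rough illustration of what the forwarded parameters (offsets, output multiplier/shift, activation clamp) mean, here is a minimal, self-contained sketch of the per-output arithmetic such a quantized fully-connected kernel performs. `NaiveQuantizedFullyConnected` is a hypothetical helper written for this example, not the `reference_ops::FullyConnected` implementation; in particular, real TFLite rescales with a pre-computed integer `output_multiplier`/`output_shift` pair rather than the plain `double` used here.

```cpp
// Hypothetical sketch of quantized fully-connected arithmetic; NOT the
// reference_ops::FullyConnected code. real_multiplier stands in for the
// integer output_multiplier/output_shift rescaling TFLite actually uses.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

template <typename OutputT>  // uint8_t, or int16_t after this change
std::vector<OutputT> NaiveQuantizedFullyConnected(
    const std::vector<uint8_t>& input, int32_t input_offset,
    const std::vector<uint8_t>& filter, int32_t filter_offset,
    const std::vector<int32_t>& bias, int32_t output_offset,
    int accum_depth, int output_depth, double real_multiplier,
    int32_t activation_min, int32_t activation_max) {
  std::vector<OutputT> output(output_depth);
  for (int out = 0; out < output_depth; ++out) {
    int32_t acc = bias[out];
    for (int d = 0; d < accum_depth; ++d) {
      // The offsets re-center the uint8 values around their zero points.
      const int32_t input_val = input[d] + input_offset;
      const int32_t filter_val = filter[out * accum_depth + d] + filter_offset;
      acc += input_val * filter_val;
    }
    // Rescale to the output scale, shift by the output zero point, and clamp
    // to the fused-activation range.
    int32_t result =
        static_cast<int32_t>(acc * real_multiplier) + output_offset;
    result = std::min(std::max(result, activation_min), activation_max);
    output[out] = static_cast<OutputT>(result);
  }
  return output;
}

int main() {
  // A 2-deep input against a 3x2 filter, with int16 output as enabled by
  // this change. All quantization parameters here are made up.
  const std::vector<uint8_t> input = {130, 120};
  const std::vector<uint8_t> filter = {140, 110, 128, 128, 150, 100};
  const std::vector<int32_t> bias = {100, 0, -100};
  const auto out = NaiveQuantizedFullyConnected<int16_t>(
      input, /*input_offset=*/-128, filter, /*filter_offset=*/-128, bias,
      /*output_offset=*/0, /*accum_depth=*/2, /*output_depth=*/3,
      /*real_multiplier=*/0.5, /*activation_min=*/-32768,
      /*activation_max=*/32767);
  for (int16_t v : out) std::cout << v << "\n";
}
```

With a uint8 output the clamp range would instead be 0..255 and the output zero point nonzero; the switch added in the diff simply selects which `GetTensorData<T>(output)` instantiation receives the result.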