author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-28 09:57:48 -0700
committer Gunhan Gulsoy <gunan@google.com>                  2018-06-28 21:37:43 -0700
commit    6c89b5f07ea8ab73c29b8dde8fbdbd8289ade3c6 (patch)
tree      6bde3c9fa95a0250d7cc8b44d5d0f053d5c7ad43 /tensorflow/contrib/lite/kernels/fully_connected.cc
parent    cfbcc61eec4d703d34c758cd45469de22fad9740 (diff)
More un-fused quantized LSTM support in TFLite interpreter
PiperOrigin-RevId: 202496488
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/fully_connected.cc  49
1 file changed, 34 insertions(+), 15 deletions(-)
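The diff below extends the quantized FullyConnected kernel so its output can be written as either uint8 or int16: the TF_LITE_FULLY_CONNECTED macro gains an output_data_type parameter that is forwarded to GetTensorData<T>, and a switch on output->type selects the instantiation at run time (the int16 path is what the un-fused quantized LSTM work named in the commit message builds on). As a minimal standalone sketch of that dispatch pattern — hypothetical names (OutputType, FullyConnectedImpl, EvalQuantizedSketch), not the actual TFLite API:

// A minimal sketch (not TFLite source) of the pattern this commit uses:
// a runtime type tag selects which template instantiation writes output.
#include <cstdint>
#include <cstdio>

enum class OutputType { kUInt8, kInt16, kFloat32 };

template <typename T>
void FullyConnectedImpl(const uint8_t* input, int n, T* output) {
  // Stand-in for the real kernel: widen each quantized input value.
  for (int i = 0; i < n; ++i) output[i] = static_cast<T>(input[i]);
}

bool EvalQuantizedSketch(OutputType type, const uint8_t* input, int n,
                         void* output) {
  switch (type) {
    case OutputType::kUInt8:
      FullyConnectedImpl(input, n, static_cast<uint8_t*>(output));
      return true;
    case OutputType::kInt16:
      FullyConnectedImpl(input, n, static_cast<int16_t*>(output));
      return true;
    default:
      // Mirrors the kernel's ReportError path for unsupported types.
      std::fprintf(stderr, "sketch: output must be uint8 or int16\n");
      return false;
  }
}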
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index cc4ae6ec6e..3b203dd480 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -283,30 +283,49 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
int32_t input_offset = -input->params.zero_point;
int32_t filter_offset = -filter->params.zero_point;
int32_t output_offset = output->params.zero_point;
-#define TF_LITE_FULLY_CONNECTED(type) \
+#define TF_LITE_FULLY_CONNECTED(type, output_data_type) \
type::FullyConnected( \
GetTensorData<uint8_t>(input), GetTensorDims(input), input_offset, \
GetTensorData<uint8_t>(filter), GetTensorDims(filter), filter_offset, \
GetTensorData<int32_t>(bias), GetTensorDims(bias), output_offset, \
data->output_multiplier, data->output_shift, \
data->output_activation_min, data->output_activation_max, \
- GetTensorData<uint8_t>(output), GetTensorDims(output), gemm_context)
+ GetTensorData<output_data_type>(output), GetTensorDims(output), \
+ gemm_context)
if (kernel_type == kReference) {
- TF_LITE_FULLY_CONNECTED(reference_ops);
- } else if (kernel_type == kPie) {
- if (input->type == kTfLiteFloat32) {
- // Pie currently only supports quantized models and float inputs/outputs.
- TfLiteTensor* input_quantized =
- &context->tensors[node->temporaries->data[0]];
- return EvalPieQuantized(context, node, params, data, input, filter, bias,
- input_quantized, output);
- } else {
- // TODO(ahentz): we don't have a quantized version of the PIE kernels, so
- // we just defer to the MINI ones.
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(reference_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(reference_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
}
+ } else if (kernel_type == kPie && input->type == kTfLiteFloat32) {
+ // Pie currently only supports quantized models and float inputs/outputs.
+ TfLiteTensor* input_quantized =
+ &context->tensors[node->temporaries->data[0]];
+ return EvalPieQuantized(context, node, params, data, input, filter, bias,
+ input_quantized, output);
} else {
- TF_LITE_FULLY_CONNECTED(optimized_ops);
+ switch (output->type) {
+ case kTfLiteUInt8:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, uint8_t);
+ break;
+ case kTfLiteInt16:
+ TF_LITE_FULLY_CONNECTED(optimized_ops, int16_t);
+ break;
+ default:
+ context->ReportError(
+ context,
+ "Quantized FullyConnected expects output data type uint8 or int16");
+ return kTfLiteError;
+ }
}
#undef TF_LITE_FULLY_CONNECTED
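
A hypothetical driver for the sketch above, showing how the output element type picks the instantiation, just as output->type picks the macro expansion in the real kernel:

int main() {
  const uint8_t input[4] = {1, 2, 3, 4};
  int16_t out16[4];
  // OutputType::kInt16 routes through FullyConnectedImpl<int16_t>, the
  // analogue of output->type == kTfLiteInt16 selecting
  // TF_LITE_FULLY_CONNECTED(..., int16_t) above.
  return EvalQuantizedSketch(OutputType::kInt16, input, 4, out16) ? 0 : 1;
}

The macro-plus-switch shape exists because GetTensorData<T> needs its type at compile time while output->type is only known at run time; the switch performs that conversion once per kernel family (reference_ops vs. optimized_ops), so the long argument list lives only in the macro.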