diff options
author | 2018-06-26 14:16:37 -0700 | |
---|---|---|
committer | 2018-06-28 21:37:43 -0700 | |
commit | 110ddc2103d7c86084ff52994998575113862542 (patch) | |
tree | 503f944630e7ea4d2cd9ca5fbca4621c7f555db6 /tensorflow/contrib/lite/kernels/fully_connected.cc | |
parent | 92221c68cdcf27607969089e5b6c06fdeeae8ae8 (diff) |
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op.
PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/fully_connected.cc | 69 |
1 file changed, 63 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc index f6fc0f5b6a..b40294709b 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected.cc @@ -63,6 +63,7 @@ constexpr int kInputTensor = 0; constexpr int kWeightsTensor = 1; constexpr int kBiasTensor = 2; constexpr int kOutputTensor = 0; +constexpr int kShuffledInputWorkspaceTensor = 1; constexpr int kScratchBufferTensor = 1; void* Init(TfLiteContext* context, const char* buffer, size_t length) { @@ -87,7 +88,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 3); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); + // Shuffled formats need a workspace to store the shuffled input activations. + const int expected_outputs_count = + params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 
1 + : 2; + TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor); @@ -121,9 +126,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { QuantizeMultiplierSmallerThanOneExp( real_multiplier, &data->output_multiplier, &data->output_shift); data->output_shift *= -1; - CalculateActivationRangeUint8(params->activation, output, - &data->output_activation_min, - &data->output_activation_max); + TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized( + context, params->activation, output, &data->output_activation_min, + &data->output_activation_max)); } // If we have to perform on-the-fly quantization (with quantized weights and @@ -309,6 +314,44 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, } template <KernelType kernel_type> +TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteFullyConnectedParams* params, + OpData* data, const TfLiteTensor* input, + const TfLiteTensor* filter, + const TfLiteTensor* bias, + TfLiteTensor* output, + TfLiteTensor* shuffled_input_workspace) { + gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context); + + // TODO(b/110697972) decide more consistently if / how / where we want + // to perform this kind of runtime data type checks. 
+ if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 || + bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 || + shuffled_input_workspace->type != kTfLiteUInt8) { + context->ReportError(context, "Unexpected data type"); + return kTfLiteError; + } + +#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \ + type::ShuffledFullyConnected( \ + GetTensorData<uint8_t>(input), GetTensorDims(input), \ + GetTensorData<uint8_t>(filter), GetTensorDims(filter), \ + GetTensorData<int32_t>(bias), GetTensorDims(bias), \ + data->output_multiplier, data->output_shift, \ + data->output_activation_min, data->output_activation_max, \ + GetTensorData<int16_t>(output), GetTensorDims(output), \ + GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context) + if (kernel_type == kReference) { + TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops); + } else { + TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops); + } +#undef TF_LITE_SHUFFLED_FULLY_CONNECTED + + return kTfLiteOk; +} + +template <KernelType kernel_type> TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node, TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input, const TfLiteTensor* filter, @@ -352,8 +395,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { return EvalFloat<kernel_type>(context, node, params, data, input, filter, bias, output); case kTfLiteUInt8: - return EvalQuantized<kernel_type>(context, node, params, data, input, - filter, bias, output); + if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) { + TfLiteTensor* shuffled_input_workspace = + GetOutput(context, node, kShuffledInputWorkspaceTensor); + return EvalShuffledQuantized<kernel_type>(context, node, params, data, + input, filter, bias, output, + shuffled_input_workspace); + } else if (params->weights_format == + kTfLiteFullyConnectedWeightsFormatDefault) { + return EvalQuantized<kernel_type>(context, node, params, data, input, + filter, bias, output); + } else { + 
context->ReportError(context, + "Unhandled fully-connected weights format"); + return kTfLiteError; + } default: context->ReportError(context, "Type %d not currently supported.", filter->type); |