about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/lite/kernels/fully_connected.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-26 14:16:37 -0700
committer: Gunhan Gulsoy <gunan@google.com> 2018-06-28 21:37:43 -0700
commit: 110ddc2103d7c86084ff52994998575113862542 (patch)
tree: 503f944630e7ea4d2cd9ca5fbca4621c7f555db6 /tensorflow/contrib/lite/kernels/fully_connected.cc
parent: 92221c68cdcf27607969089e5b6c06fdeeae8ae8 (diff)
Un-fused quantized Babelfish LSTM cell support in TFLite
including support for shuffled-weights fully-connected op. PiperOrigin-RevId: 202192299
Diffstat (limited to 'tensorflow/contrib/lite/kernels/fully_connected.cc')
-rw-r--r-- tensorflow/contrib/lite/kernels/fully_connected.cc 69
1 files changed, 63 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/fully_connected.cc b/tensorflow/contrib/lite/kernels/fully_connected.cc
index f6fc0f5b6a..b40294709b 100644
--- a/tensorflow/contrib/lite/kernels/fully_connected.cc
+++ b/tensorflow/contrib/lite/kernels/fully_connected.cc
@@ -63,6 +63,7 @@ constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
+constexpr int kShuffledInputWorkspaceTensor = 1;
constexpr int kScratchBufferTensor = 1;
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -87,7 +88,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 3);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
+ // Shuffled formats need a workspace to store the shuffled input activations.
+ const int expected_outputs_count =
+ params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
+ : 2;
+ TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* filter = GetInput(context, node, kWeightsTensor);
@@ -121,9 +126,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
QuantizeMultiplierSmallerThanOneExp(
real_multiplier, &data->output_multiplier, &data->output_shift);
data->output_shift *= -1;
- CalculateActivationRangeUint8(params->activation, output,
- &data->output_activation_min,
- &data->output_activation_max);
+ TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
+ context, params->activation, output, &data->output_activation_min,
+ &data->output_activation_max));
}
// If we have to perform on-the-fly quantization (with quantized weights and
@@ -309,6 +314,44 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,  // Validates tensor types, then calls the reference_ops or optimized_ops ShuffledFullyConnected kernel.
+ TfLiteFullyConnectedParams* params,  // `node` and `params` are accepted but not read in this body.
+ OpData* data, const TfLiteTensor* input,  // `data` supplies the output multiplier/shift and activation min/max passed to the kernel.
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias,  // NOTE(review): `bias->type` is read below without a null check -- confirm callers never pass a null bias.
+ TfLiteTensor* output,
+ TfLiteTensor* shuffled_input_workspace) {  // Its uint8 data pointer is forwarded to the kernel next to the output arguments.
+ gemmlowp::GemmContext* gemm_context = gemm_support::GetFromContext(context);  // Forwarded as the last kernel argument.
+
+ // TODO(b/110697972) decide more consistently if / how / where we want
+ // to perform this kind of runtime data type checks.
+ if (input->type != kTfLiteUInt8 || filter->type != kTfLiteUInt8 ||  // Accepted combination: uint8 input, filter and workspace; int32 bias; int16 output.
+ bias->type != kTfLiteInt32 || output->type != kTfLiteInt16 ||
+ shuffled_input_workspace->type != kTfLiteUInt8) {
+ context->ReportError(context, "Unexpected data type");
+ return kTfLiteError;  // Any other type combination is rejected before the kernel runs.
+ }
+
+#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type) \
+ type::ShuffledFullyConnected( \
+ GetTensorData<uint8_t>(input), GetTensorDims(input), \
+ GetTensorData<uint8_t>(filter), GetTensorDims(filter), \
+ GetTensorData<int32_t>(bias), GetTensorDims(bias), \
+ data->output_multiplier, data->output_shift, \
+ data->output_activation_min, data->output_activation_max, \
+ GetTensorData<int16_t>(output), GetTensorDims(output), \
+ GetTensorData<uint8_t>(shuffled_input_workspace), gemm_context)
+ if (kernel_type == kReference) {  // `kernel_type` is the template parameter; it selects the namespace the macro expands into.
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(reference_ops);
+ } else {
+ TF_LITE_SHUFFLED_FULLY_CONNECTED(optimized_ops);
+ }
+#undef TF_LITE_SHUFFLED_FULLY_CONNECTED
+
+ return kTfLiteOk;  // The kernel call's result is not inspected; reaching this point always reports success.
+}
+
+template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
TfLiteFullyConnectedParams* params, OpData* data,
const TfLiteTensor* input, const TfLiteTensor* filter,
@@ -352,8 +395,22 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
return EvalFloat<kernel_type>(context, node, params, data, input, filter,
bias, output);
case kTfLiteUInt8:
- return EvalQuantized<kernel_type>(context, node, params, data, input,
- filter, bias, output);
+ if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
+ TfLiteTensor* shuffled_input_workspace =
+ GetOutput(context, node, kShuffledInputWorkspaceTensor);
+ return EvalShuffledQuantized<kernel_type>(context, node, params, data,
+ input, filter, bias, output,
+ shuffled_input_workspace);
+ } else if (params->weights_format ==
+ kTfLiteFullyConnectedWeightsFormatDefault) {
+ return EvalQuantized<kernel_type>(context, node, params, data, input,
+ filter, bias, output);
+ } else {
+ context->ReportError(context,
+ "Unhandled fully-connected weights format");
+ return kTfLiteError;
+ }
default:
context->ReportError(context, "Type %d not currently supported.",
filter->type);