path: root/tensorflow/contrib/lite/kernels/mul.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-20 14:09:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-06-20 14:12:16 -0700
commit    7ec196c4a28352008d0c947e4a0f0bb404953f98
tree      5ac502ac0aed7020c8a5c899e01c6a3bd45400e3 /tensorflow/contrib/lite/kernels/mul.cc
parent    c1ff1164e30186d847f7d4f9e9ce5d40936a2c1c
16-bit quantized Mul support in TFLite interpreter
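This change adds 16-bit quantized paths to the Mul kernel: int16 x int16 -> int16 and int16 x int16 -> uint8, alongside the existing uint8 path. The per-node requantization parameters (activation range, output multiplier, output shift) move from EvalQuantized into Prepare, so they are computed once when the graph is prepared rather than on every invocation. The multiplier encodes real_multiplier = input1_scale * input2_scale / output_scale as a Q31 fixed-point value; QuantizeMultiplierSmallerThanOneExp returns a left-shift exponent, which Prepare negates (output_shift *= -1) to obtain the right shift the kernels expect. A minimal standalone sketch of that normalization (an illustrative stand-in, not the TFLite helper's implementation; the scales are made up):

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for QuantizeMultiplierSmallerThanOneExp: expresses a
// real multiplier in (0, 1) as a Q31 fixed-point value plus a right shift.
void QuantizeMultiplierSketch(double real_multiplier, int32_t* quantized,
                              int* right_shift) {
  int shift = 0;
  while (real_multiplier < 0.5) {  // normalize into [0.5, 1)
    real_multiplier *= 2.0;
    ++shift;
  }
  *quantized = static_cast<int32_t>(std::round(real_multiplier * (1ll << 31)));
  *right_shift = shift;
}

int main() {
  int32_t multiplier;
  int shift;
  // Hypothetical scales: 0.02 * 0.03 / 0.05 = 0.012.
  QuantizeMultiplierSketch(0.02 * 0.03 / 0.05, &multiplier, &shift);
  // Prints multiplier=1649267442 right_shift=6, i.e. 0.012 ~= 0.768 * 2^-6.
  std::printf("multiplier=%d right_shift=%d\n", multiplier, shift);
  return 0;
}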
PiperOrigin-RevId: 201413223
Diffstat (limited to 'tensorflow/contrib/lite/kernels/mul.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/mul.cc | 118
1 file changed, 80 insertions(+), 38 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc
index b69a221447..9e01b73c49 100644
--- a/tensorflow/contrib/lite/kernels/mul.cc
+++ b/tensorflow/contrib/lite/kernels/mul.cc
@@ -39,6 +39,14 @@ constexpr int kOutputTensor = 0;
struct OpData {
bool requires_broadcast;
+
+ // Parameters used in the quantized paths where the output is 8-bit
+ int32 output_activation_min;
+ int32 output_activation_max;
+
+ // Parameters used in all quantized paths
+ int32_t output_multiplier;
+ int output_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -52,6 +60,7 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
@@ -62,7 +71,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
- output->type = input2->type;
data->requires_broadcast = !HaveSameShapes(input1, input2);
@@ -74,6 +82,20 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
output_size = TfLiteIntArrayCopy(input1->dims);
}
+ if (output->type == kTfLiteUInt8) {
+ CalculateActivationRangeUint8(params->activation, output,
+ &data->output_activation_min,
+ &data->output_activation_max);
+ }
+
+ if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ double real_multiplier =
+ input1->params.scale * input2->params.scale / output->params.scale;
+ QuantizeMultiplierSmallerThanOneExp(
+ real_multiplier, &data->output_multiplier, &data->output_shift);
+ data->output_shift *= -1;
+ }
+
return context->ResizeTensor(context, output, output_size);
}
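Prepare also caches the clamped activation range for the 8-bit output path, so EvalQuantized no longer recomputes it per call. For orientation, a hedged approximation of what CalculateActivationRangeUint8 yields for a ReLU activation (the real helper lives in TFLite's kernel_util; this sketch is not its source):

#include <algorithm>
#include <cstdint>

// Hedged approximation: for a uint8 output with the given zero point, ReLU
// restricts the representable range [0, 255] so that dequantized values stay
// >= 0, which in the quantized domain means q >= zero_point.
void ActivationRangeUint8ReluSketch(int32_t zero_point, int32_t* act_min,
                                    int32_t* act_max) {
  *act_min = std::max<int32_t>(0, zero_point);  // quantized value of real 0.0
  *act_max = 255;  // ReLU leaves the uint8 upper bound unchanged
}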
@@ -107,42 +129,60 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
-void EvalQuantized(TfLiteContext* context, TfLiteNode* node,
- TfLiteMulParams* params, const OpData* data,
- const TfLiteTensor* input1, const TfLiteTensor* input2,
- TfLiteTensor* output) {
- auto input1_offset = -input1->params.zero_point;
- auto input2_offset = -input2->params.zero_point;
- auto output_offset = output->params.zero_point;
-
- int32_t output_multiplier;
- int output_shift;
-
- double real_multiplier =
- input1->params.scale * input2->params.scale / output->params.scale;
- QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
- &output_shift);
- output_shift *= -1;
-
- int32 output_activation_min, output_activation_max;
- CalculateActivationRangeUint8(params->activation, output,
- &output_activation_min, &output_activation_max);
-
-#define TF_LITE_MUL(type, opname) \
- type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
- input1_offset, GetTensorData<uint8_t>(input2), \
- GetTensorDims(input2), input2_offset, output_offset, \
- output_multiplier, output_shift, output_activation_min, \
- output_activation_max, GetTensorData<uint8_t>(output), \
+TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
+ TfLiteMulParams* params, const OpData* data,
+ const TfLiteTensor* input1,
+ const TfLiteTensor* input2, TfLiteTensor* output) {
+ if (input1->type == kTfLiteUInt8 && input2->type == kTfLiteUInt8 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<uint8_t>(input1), GetTensorDims(input1), \
+ -input1->params.zero_point, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), -input2->params.zero_point, \
+ output->params.zero_point, data->output_multiplier, \
+ data->output_shift, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
GetTensorDims(output));
- // The quantized version of Mul doesn't support activations, so we
- // always use BroadcastMul.
- if (kernel_type == kReference) {
- TF_LITE_MUL(reference_ops, BroadcastMul);
+ // The quantized version of Mul doesn't support activations, so we
+ // always use BroadcastMul.
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, BroadcastMul);
+ } else {
+ TF_LITE_MUL(optimized_ops, BroadcastMul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteInt16) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ GetTensorData<int16_t>(output), GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
+ } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
+ output->type == kTfLiteUInt8) {
+#define TF_LITE_MUL(type, opname) \
+ type::opname(GetTensorData<int16_t>(input1), GetTensorDims(input1), \
+ GetTensorData<int16_t>(input2), GetTensorDims(input2), \
+ output->params.zero_point, data->output_activation_min, \
+ data->output_activation_max, GetTensorData<uint8_t>(output), \
+ GetTensorDims(output));
+ if (kernel_type == kReference) {
+ TF_LITE_MUL(reference_ops, Mul);
+ } else {
+ TF_LITE_MUL(optimized_ops, Mul);
+ }
+#undef TF_LITE_MUL
} else {
- TF_LITE_MUL(optimized_ops, BroadcastMul);
+ context->ReportError(
+ context, "Unsupported combination of input and output types in Mul.");
+ return kTfLiteError;
}
-#undef TF_LITE_MUL
+ return kTfLiteOk;
}
template <KernelType kernel_type>
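Note that the pure int16 path passes no zero points, multiplier, or activation bounds: the operands are treated as symmetric Q15 fixed-point values. A hedged sketch of one element of such a multiply, modeled on gemmlowp's saturating rounding doubling high multiply (an assumption about the behavior of reference_ops::Mul here, not its source):

#include <cstdint>
#include <limits>

// One int16 x int16 -> int16 element treated as Q0.15 fixed point: round to
// nearest and saturate the single overflow case (-1.0 * -1.0). Assumption,
// not the TFLite kernel source.
int16_t MulQ15Sketch(int16_t a, int16_t b) {
  if (a == std::numeric_limits<int16_t>::min() &&
      b == std::numeric_limits<int16_t>::min()) {
    return std::numeric_limits<int16_t>::max();  // saturate -1.0 * -1.0
  }
  int32_t prod = static_cast<int32_t>(a) * static_cast<int32_t>(b);
  int32_t nudge = prod >= 0 ? (1 << 14) : (1 - (1 << 14));
  return static_cast<int16_t>((prod + nudge) / (1 << 15));
}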
@@ -156,12 +196,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
- } else if (output->type == kTfLiteUInt8) {
- EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
- output);
+ } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
+ TF_LITE_ENSURE_OK(
+ context, EvalQuantized<kernel_type>(context, node, params, data, input1,
+ input2, output));
} else {
context->ReportError(
- context, "Mul only supports FLOAT32 and quantized UINT8 now, got %d.",
+ context,
+ "Mul only supports FLOAT32 and quantized UINT8 and INT16 now, got %d.",
output->type);
return kTfLiteError;
}
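With EvalQuantized now returning TfLiteStatus, Eval propagates failures instead of ignoring them. TF_LITE_ENSURE_OK behaves roughly as sketched below (a paraphrase of the macro's contract, not a quotation of the TFLite header):

// Rough paraphrase: evaluate the status expression once and early-return it
// from the enclosing function if it is not kTfLiteOk.
#define TF_LITE_ENSURE_OK_SKETCH(context, status) \
  do {                                            \
    TfLiteStatus s_ = (status);                   \
    if (s_ != kTfLiteOk) return s_;               \
  } while (0)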