diff options
author | 2018-03-20 11:27:54 -0700 | |
---|---|---|
committer | 2018-03-20 11:32:51 -0700 | |
commit | 4e5900eb874668e569cfa1b75c463a9f0b15738f (patch) | |
tree | 0394cf49e8e11eb6326c3eb7f528242ae14cac18 /tensorflow/contrib/lite/kernels/sub.cc | |
parent | beaf17d4b2b2e79e97b08b0382b302771ae6081e (diff) |
The Quantized BroadcastSub portion of #17123
PiperOrigin-RevId: 189776376
Diffstat (limited to 'tensorflow/contrib/lite/kernels/sub.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/sub.cc | 56 |
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc index c15a7a50a4..66b06aeaec 100644 --- a/tensorflow/contrib/lite/kernels/sub.cc +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -107,6 +107,59 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, } template <KernelType kernel_type> +void EvalQuantized(TfLiteContext* context, TfLiteNode* node, + TfLiteSubParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { + auto input1_offset = -input1->params.zero_point; + auto input2_offset = -input2->params.zero_point; + auto output_offset = output->params.zero_point; + const int left_shift = 20; + const double twice_max_input_scale = + 2 * std::max(input1->params.scale, input2->params.scale); + const double real_input1_multiplier = + input1->params.scale / twice_max_input_scale; + const double real_input2_multiplier = + input2->params.scale / twice_max_input_scale; + const double real_output_multiplier = + twice_max_input_scale / ((1 << left_shift) * output->params.scale); + + int32 input1_multiplier; + int input1_shift; + QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier, + &input1_shift); + int32 input2_multiplier; + int input2_shift; + QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier, + &input2_shift); + int32 output_multiplier; + int output_shift; + QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier, + &output_shift); + + int32 output_activation_min, output_activation_max; + CalculateActivationRangeUint8(params->activation, output, + &output_activation_min, &output_activation_max); + +#define TF_LITE_SUB(type, opname) \ + type::opname(left_shift, GetTensorData<uint8_t>(input1), \ + GetTensorDims(input1), input1_offset, input1_multiplier, \ + input1_shift, GetTensorData<uint8_t>(input2), \ + GetTensorDims(input2), input2_offset, input2_multiplier, \ + input2_shift, output_offset, output_multiplier, output_shift, \ + output_activation_min, output_activation_max, \ + GetTensorData<uint8_t>(output), GetTensorDims(output)); + // The quantized version of Sub doesn't support activations, so we + // always use BroadcastSub. + if (kernel_type == kReference) { + TF_LITE_SUB(reference_ops, BroadcastSub); + } else { + TF_LITE_SUB(optimized_ops, BroadcastSub); + } +#undef TF_LITE_SUB +} + +template <KernelType kernel_type> TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data); OpData* data = reinterpret_cast<OpData*>(node->user_data); @@ -117,6 +170,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { if (output->type == kTfLiteFloat32) { EvalFloat<kernel_type>(context, node, params, data, input1, input2, output); + } else if (output->type == kTfLiteUInt8) { + EvalQuantized<kernel_type>(context, node, params, data, input1, input2, + output); } else { context->ReportError(context, "Inputs and outputs not all float types."); return kTfLiteError; |