about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/lite/kernels/sub.cc
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-03-20 11:27:54 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-03-20 11:32:51 -0700
commit: 4e5900eb874668e569cfa1b75c463a9f0b15738f (patch)
tree: 0394cf49e8e11eb6326c3eb7f528242ae14cac18 /tensorflow/contrib/lite/kernels/sub.cc
parent: beaf17d4b2b2e79e97b08b0382b302771ae6081e (diff)
The Quantized BroadcastSub portion of #17123
PiperOrigin-RevId: 189776376
Diffstat (limited to 'tensorflow/contrib/lite/kernels/sub.cc')
-rw-r--r-- tensorflow/contrib/lite/kernels/sub.cc 56
1 files changed, 56 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc
index c15a7a50a4..66b06aeaec 100644
--- a/tensorflow/contrib/lite/kernels/sub.cc
+++ b/tensorflow/contrib/lite/kernels/sub.cc
@@ -107,6 +107,59 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node,
}
template <KernelType kernel_type>
+void EvalQuantized(TfLiteContext* context, TfLiteNode* node,  // Quantized (uint8) Sub: derives rescaling parameters from the tensors' quantization params, then calls BroadcastSub.
+ TfLiteSubParams* params, const OpData* data,  // Only params->activation is read below; context, node and data are not referenced in this body.
+ TfLiteTensor* input1, TfLiteTensor* input2,
+ TfLiteTensor* output) {
+ auto input1_offset = -input1->params.zero_point;  // Inputs use the negated zero point as their offset...
+ auto input2_offset = -input2->params.zero_point;
+ auto output_offset = output->params.zero_point;  // ...while the output offset is the zero point itself.
+ const int left_shift = 20;  // Passed to BroadcastSub as its first argument; the output multiplier below is divided by 2^left_shift.
+ const double twice_max_input_scale =  // Common normalizer: twice the larger of the two input scales.
+ 2 * std::max(input1->params.scale, input2->params.scale);
+ const double real_input1_multiplier =  // Each input scale relative to the normalizer (at most 0.5 when scales are positive).
+ input1->params.scale / twice_max_input_scale;
+ const double real_input2_multiplier =
+ input2->params.scale / twice_max_input_scale;
+ const double real_output_multiplier =  // Normalizer over the output scale, with the 2^left_shift factor divided out.
+ twice_max_input_scale / ((1 << left_shift) * output->params.scale);
+
+ int32 input1_multiplier;  // NOTE(review): QuantizeMultiplierSmallerThanOne is defined elsewhere; presumably it encodes a real multiplier as an int32 value plus a shift - confirm.
+ int input1_shift;
+ QuantizeMultiplierSmallerThanOne(real_input1_multiplier, &input1_multiplier,
+ &input1_shift);
+ int32 input2_multiplier;
+ int input2_shift;
+ QuantizeMultiplierSmallerThanOne(real_input2_multiplier, &input2_multiplier,
+ &input2_shift);
+ int32 output_multiplier;  // NOTE(review): unlike the input multipliers, real_output_multiplier is not bounded by construction; confirm it always lies in the range the helper accepts.
+ int output_shift;
+ QuantizeMultiplierSmallerThanOne(real_output_multiplier, &output_multiplier,
+ &output_shift);
+
+ int32 output_activation_min, output_activation_max;  // Clamp bounds; presumably filled in by CalculateActivationRangeUint8 from params->activation and the output tensor.
+ CalculateActivationRangeUint8(params->activation, output,
+ &output_activation_min, &output_activation_max);
+
+#define TF_LITE_SUB(type, opname) \
+ type::opname(left_shift, GetTensorData<uint8_t>(input1), \
+ GetTensorDims(input1), input1_offset, input1_multiplier, \
+ input1_shift, GetTensorData<uint8_t>(input2), \
+ GetTensorDims(input2), input2_offset, input2_multiplier, \
+ input2_shift, output_offset, output_multiplier, output_shift, \
+ output_activation_min, output_activation_max, \
+ GetTensorData<uint8_t>(output), GetTensorDims(output));
+ // The quantized version of Sub doesn't support activations, so we
+ // always use BroadcastSub.
+ if (kernel_type == kReference) {  // Same argument list on both paths; only the implementation namespace differs.
+ TF_LITE_SUB(reference_ops, BroadcastSub);
+ } else {
+ TF_LITE_SUB(optimized_ops, BroadcastSub);
+ }
+#undef TF_LITE_SUB
+}
+
+template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteSubParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
@@ -117,6 +170,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
if (output->type == kTfLiteFloat32) {
EvalFloat<kernel_type>(context, node, params, data, input1, input2, output);
+ } else if (output->type == kTfLiteUInt8) {
+ EvalQuantized<kernel_type>(context, node, params, data, input1, input2,
+ output);
} else {
context->ReportError(context, "Inputs and outputs not all float types.");
return kTfLiteError;