diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-07 08:52:39 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-07 08:56:57 -0800 |
commit | 2b211b681ac6264c61372d10c496e234bf2eda9b (patch) | |
tree | 009cef946edbcb2ff4c9ba7dc6f74bf0ea802b4a /tensorflow/contrib/lite/kernels/dequantize.cc | |
parent | fbc2b857e45c4fe8fcd2ce016b3bb63ea9b9f924 (diff) |
Add support for the "DEQUANTIZE" op. This cover only ops that are generated by TOCO
in order to handle UINT8 input to floating-point models.
PiperOrigin-RevId: 188182372
Diffstat (limited to 'tensorflow/contrib/lite/kernels/dequantize.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/dequantize.cc | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc new file mode 100644 index 0000000000..e685f2465f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <string.h> +#include <vector> + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace dequantize { + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpContext op_context(context, node); + + TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8); + + op_context.output->type = kTfLiteFloat32; + return context->ResizeTensor(context, op_context.output, + TfLiteIntArrayCopy(op_context.input->dims)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + auto zero_point = op_context.input->params.zero_point; + auto scale = op_context.input->params.scale; + + optimized_ops::Dequantize(GetTensorData<uint8_t>(op_context.input), + GetTensorDims(op_context.input), zero_point, scale, + GetTensorData<float>(op_context.output), + GetTensorDims(op_context.output)); + return kTfLiteOk; +} + +} // namespace dequantize + +TfLiteRegistration* Register_DEQUANTIZE_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare, + dequantize::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEQUANTIZE() { return Register_DEQUANTIZE_OPT(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite |