aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc223
1 files changed, 223 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
new file mode 100644
index 0000000000..b89e3f5310
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType A>
+void DequantizeBuffer(Array* array) {
+ const auto old_data = array->GetBuffer<A>().data;
+ array->buffer = nullptr;
+ array->data_type = ArrayDataType::kFloat;
+ auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ new_data.resize(old_data.size());
+ const auto& qparams = array->GetQuantizationParams();
+ for (int i = 0; i < old_data.size(); i++) {
+ new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
+ }
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
+ Model* model, const string& array_name) {
+ for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
+ for (const auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model->operators.end();
+}
+
+void ClearArrayQuantizationParams(const string& array_name, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ CHECK(array->quantization_params);
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == array_name) {
+ auto& qparams = *array->quantization_params;
+ const double new_std_value = 1. / qparams.scale;
+ const double new_mean_value = qparams.zero_point;
+ if (input_array.has_std_value()) {
+ CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
+ } else {
+ input_array.set_std_value(new_std_value);
+ }
+ if (input_array.has_mean_value()) {
+ CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
+ } else {
+ input_array.set_mean_value(new_mean_value);
+ }
+ }
+ }
+ array->quantization_params = nullptr;
+}
+
+bool DequantizeArray(const string& array_name,
+ GraphTransformation* transformation, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ if (!array->quantization_params) {
+ return false;
+ }
+ transformation->AddMessageF("Dequantizing array: %s", array_name);
+
+ // Dequantize any buffer
+ if (array->buffer) {
+ if (array->data_type == ArrayDataType::kUint8) {
+ DequantizeBuffer<ArrayDataType::kUint8>(array);
+ } else if (array->data_type == ArrayDataType::kInt32) {
+ DequantizeBuffer<ArrayDataType::kInt32>(array);
+ } else {
+ LOG(FATAL) << "Unhandled data type";
+ }
+ CHECK(array->data_type == ArrayDataType::kFloat);
+ CHECK(array->buffer->type == ArrayDataType::kFloat);
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+ return true;
+ } else {
+ array->data_type = ArrayDataType::kFloat;
+ }
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+
+ if (array->buffer) {
+ return true;
+ }
+
+ auto* op_outputting_array = GetOpWithOutput(*model, array_name);
+ if (op_outputting_array) {
+ if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ return true;
+ }
+ }
+
+ // If there was no minmax info, we can return now. Indeed,
+ // the below only serves to create a FakeQuant node, but some arrays are
+ // quantized without MinMax (see the CHECK above) and that corresponds to
+ // places where a FakeQuant node is actually not wanted, because the
+ // quantization params are meant to be inferred in another way (e.g. bias
+ // vector for a Conv op, see their special-casing in quantize.cc).
+ if (!array->minmax) {
+ return true;
+ }
+
+ // Determine whether to insert a FakeQuant before or after
+ // this array.
+ bool must_insert_fakequant_before = false;
+ bool must_insert_fakequant_after = false;
+ if (IsInputArray(*model, array_name)) {
+ must_insert_fakequant_after = true;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (array_name == output_array) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ must_insert_fakequant_after = true;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
+
+ // Create and insert the FakeQuant node
+ auto* fakequant_op = new FakeQuantOperator;
+ model->operators.emplace(FindFirstOpWithInput(model, array_name),
+ fakequant_op);
+ const string& new_array_name = AvailableArrayName(*model, array_name);
+ auto& new_array = model->GetOrCreateArray(new_array_name);
+ new_array.data_type = ArrayDataType::kFloat;
+ new_array.copy_shape(array->shape());
+ new_array.GetOrCreateMinMax() = array->GetMinMax();
+ fakequant_op->minmax.reset(new MinMax);
+ *fakequant_op->minmax = array->GetMinMax();
+ if (must_insert_fakequant_before) {
+ for (const auto& op : model->operators) {
+ for (string& output : op->outputs) {
+ if (output == array_name) {
+ output = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {new_array_name};
+ fakequant_op->outputs = {array_name};
+ } else {
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == array_name) {
+ input = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {array_name};
+ fakequant_op->outputs = {new_array_name};
+ }
+ return true;
+}
+
+} // namespace
+
+bool Dequantize::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+
+ if (op->type == OperatorType::kDequantize) {
+ auto& input_array = model->GetArray(op->inputs[0]);
+ if (input_array.data_type == ArrayDataType::kFloat) {
+ return false;
+ }
+ if (input_array.final_data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ input_array.data_type = ArrayDataType::kFloat;
+ input_array.quantization_params = nullptr;
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.data_type = ArrayDataType::kFloat;
+ output_array.quantization_params = nullptr;
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+
+ std::vector<string> arrays;
+ for (const string& input : op->inputs) {
+ arrays.push_back(input);
+ }
+ for (const string& output : op->outputs) {
+ arrays.push_back(output);
+ }
+ bool changed = false;
+ for (const string& array : arrays) {
+ changed |= DequantizeArray(array, this, model);
+ }
+
+ return changed;
+}
+
+} // namespace toco