author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-04-17 11:53:29 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-04-17 11:55:26 -0700
commit d7b6cb66c0fc346cf55020042931c07208713c60 (patch)
tree   9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
parent 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff)
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc  307
1 file changed, 307 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
new file mode 100644
index 0000000000..0bce183c18
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -0,0 +1,307 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
+ ArrayDataType new_data_type,
+ const MinMax* new_minmax) {
+ // Ensure the array ends up in the new type (if it hasn't yet been quantized).
+ array->final_data_type = new_data_type;
+
+ if (array->minmax && array->quantization_params) {
+ // The array is already quantized and has min/max info.
+ // As we are changing the data type we need to fix up the existing min/max
+ // to the new data type range.
+
+ double old_quantized_min, old_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(
+ array->data_type, &old_quantized_min, &old_quantized_max))
+ << "Existing data type is not quantized: "
+ << ArrayDataTypeName(array->data_type);
+ double new_quantized_min, new_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min,
+ &new_quantized_max))
+ << "New data type is not quantized: "
+ << ArrayDataTypeName(new_data_type);
+
+ // Compute new minmax values.
+ double min = (old_quantized_min - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ double max =
+ (old_quantized_max + 1 - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ // Back max off slightly so it stays strictly below the upper edge of the
+ // old top quantization bucket computed above.
+ max = max - 1.0 / (new_quantized_max + 1);
+
+ auto& array_minmax = array->GetOrCreateMinMax();
+ transformation->AddMessageF(
+ "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
+ array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
+ ArrayDataTypeName(new_data_type));
+
+ array_minmax.min = min;
+ array_minmax.max = max;
+ // Rederive the quantization params from the rescaled minmax (computed over
+ // the int16 grid).
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+ array_minmax, array->quantization_params.get());
+
+ // Directly change the type as the array was already quantized.
+ array->data_type = new_data_type;
+ } else {
+ // Array has not yet been quantized so we can just set the final data type
+ // and assign the new min/max value (if provided).
+ CHECK(!array->quantization_params);
+
+ if (!array->minmax && new_minmax) {
+ transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
+ new_minmax->min, new_minmax->max,
+ ArrayDataTypeName(new_data_type));
+ auto& array_minmax = array->GetOrCreateMinMax();
+ array_minmax.min = new_minmax->min;
+ array_minmax.max = new_minmax->max;
+ }
+ }
+}
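+
+// Worked example for the rescale above (illustrative numbers, assuming
+// GetQuantizedDataTypeNumericalRange reports [0, 255] for kUint8 and
+// [-32768, 32767] for kInt16): changing a uint8 array with scale = 1/255 and
+// zero_point = 0 over to int16 gives
+//   min = (0 - 0) * 1/255 = 0.0
+//   max = (255 + 1 - 0) * 1/255 - 1/32768 ≈ 1.0039
+// and the int16 quantization params are then rederived from that minmax.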
+
+// Returns true if the op blocks our backward recursive data type propagation.
+bool DoesOpBlockBackwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kConcatenation:
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Concat shouldn't block propagation, but we do expect that all inputs
+ // have the same range.
+ return false;
+ case OperatorType::kDequantize:
+ // Dequantize ops are inserted between the value we care about and the
+ // FakeQuant so make sure we move across them.
+ case OperatorType::kGather:
+ // Gathers need their parameters changed to the appropriate data type.
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Reshapes and transposes don't change values.
+ return false;
+ default:
+ return true;
+ }
+}
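+
+// For example, given the producer chain
+//   const_weights -> Dequantize -> (input of the op being propagated),
+// the backward walk crosses the Dequantize and can retype const_weights too,
+// while a producing FullyConnected (the default case above) stops it.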
+
+// Returns true if the input of an op blocks our backward recursive data type
+// propagation.
+bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
+ switch (op.type) {
+ case OperatorType::kGather:
+ // Ignore gather indices.
+ return input_index != 0;
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Ignore reshape/transpose shapes/dimensions.
+ return input_index != 0;
+ default:
+ return false;
+ }
+}
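+
+// For example, for a Gather op only input 0 (the params array) is followed;
+// the indices input is skipped and keeps its integer type.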
+
+// Propagates the data type up into the input arrays if they are model inputs
+// that may need their type changed. May act recursively if the inputs are
+// produced by ops that we can move over (such as Dequantize).
+bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type,
+ const MinMax& new_minmax) {
+ bool did_change = false;
+ for (int input_index = 0; input_index < op->inputs.size(); ++input_index) {
+ const auto& input = op->inputs[input_index];
+ auto& input_array = model->GetArray(input);
+ if (input_array.final_data_type == new_data_type) {
+ // Final data type is already the new type - skip.
+ continue;
+ }
+
+ // Prevent moving into constant param args that we don't want to modify.
+ if (DoesOpInputBlockBackwardPropagation(*op, input_index)) {
+ continue;
+ }
+
+ if (input_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting input final data type of array %s from %s to %s", input,
+ ArrayDataTypeName(input_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &input_array, new_data_type,
+ &new_minmax);
+
+ // Walk up into all ops producing the inputs to this op.
+ for (auto& producing_op : model->operators) {
+ if (!DoesOpBlockBackwardPropagation(*producing_op)) {
+ for (const auto& output : producing_op->outputs) {
+ if (input == output) {
+ did_change |= RecursivelyBackwardPropagateDataType(
+ transformation, model, producing_op.get(), new_data_type,
+ new_minmax);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+// Returns true if the op blocks our forward recursive data type propagation.
+bool DoesOpBlockForwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kFakeQuant:
+ // Always stop at another FakeQuant, as it will likely have different
+ // parameters.
+ return true;
+ default:
+ return false;
+ }
+}
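+
+// For example, forward propagation never crosses a downstream FakeQuant: the
+// arrays beyond it are governed by that op's own num_bits.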
+
+// Recurses down the graph, setting the data type of all arrays, until it
+// reaches an op that blocks propagation (like another FakeQuant) or an array
+// whose final_data_type is already specified.
+bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type) {
+ bool did_change = false;
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (output_array.final_data_type == new_data_type) {
+ // Final data type is already the new type - skip.
+ continue;
+ }
+
+ if (output_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting output final data type of array %s from %s to %s", output,
+ ArrayDataTypeName(output_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &output_array, new_data_type,
+ nullptr);
+
+ // Walk down into all ops consuming the output of this op.
+ for (auto& consuming_op : model->operators) {
+ if (!DoesOpBlockForwardPropagation(*consuming_op)) {
+ for (const auto& input : consuming_op->inputs) {
+ if (input == output) {
+ did_change |= RecursivelyForwardPropagateDataType(
+ transformation, model, consuming_op.get(), new_data_type);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+} // namespace
+
+// Propagates the num_bits on a FakeQuant operator into the final data types
+// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know
+// the output must be int16 and assume all inputs up until the preceding op
+// are also 16-bit.
+//
+// This can be thought of as a bidirectional flood-fill of the num_bits implied
+// final_data_type that terminates at other FakeQuant ops (and a few others as
+// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation).
+// Once all FakeQuant ops have been visited the arrays should all have
+// appropriate final_data_types if the source graph was annotated with the
+// proper FakeQuant ops.
+//
+// Annotating a graph requires following a few hard rules, illustrated below:
+// - every input MUST have a FakeQuant immediately following it
+// - every output MUST have a FakeQuant immediately preceding it
+// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant
+// immediately following them
+// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc)
+// MUST have FakeQuants between them and the consuming op
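+// For example, a minimal graph annotated per these rules (a sketch) could
+// look like:
+//   input -> FakeQuant -> FullyConnected -> FakeQuant -> output
+//                           ^
+//   weights -> FakeQuant ---'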
+// Additional FakeQuants may be used if desired, especially in areas that may
+// suffer from large precision changes - such as between a Softmax and a
+// FullyConnected. Only by validating accuracy differences between float
+// inference with the FakeQuant ops simulating quantization and the actually
+// quantized graph can you be sure the appropriate FakeQuant ops are present.
+//
+// You can tell if you're missing some FakeQuants by looking for warnings from
+// quantize.cc about minmax ranges being determined by the contents of constant
+// arrays; relying on that fallback almost never yields a functional model.
+//
+// As this transformation may change the data types and ranges of input and
+// output arrays, downstream tools must be sure to parse the output model flags
+// to pick up the post-transformation values.
+//
+// This isn't a GraphTransformation in the traditional sense, as it affects ops
+// outside of the one under transformation. This is primarily so that we can
+// utilize the graph traversal and repeated pass system underlying the
+// transformation system to exhaustively find all FakeQuant ops. It also gets us
+// nice logging and integration with the graphviz video dumping mode.
+// In general you should not copy this style of transformation and stick to
+// local-only changes as seen in the other transformations.
+bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+ ArrayDataType quantized_data_type = ArrayDataType::kNone;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
+ LogName(*op), fakequant_op->num_bits);
+ return false;
+ }
+ const auto& final_minmax = *fakequant_op->minmax;
+
+ AddMessageF(
+ "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s",
+ LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max,
+ ArrayDataTypeName(quantized_data_type));
+
+ bool did_change = false;
+
+ // Propagate the FakeQuant information backward up the graph.
+ // This will possibly adjust input arrays or constant types (like Gather).
+ did_change |= RecursivelyBackwardPropagateDataType(
+ this, model, op, quantized_data_type, final_minmax);
+
+ // Propagate the FakeQuant information forward down the graph.
+ // This will possibly adjust output arrays.
+ did_change |=
+ RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
+
+ return did_change;
+}
+
+} // namespace toco