diff options
author | 2018-04-17 11:53:29 -0700 | |
---|---|---|
committer | 2018-04-17 11:55:26 -0700 | |
commit | d7b6cb66c0fc346cf55020042931c07208713c60 (patch) | |
tree | 9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | |
parent | 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff) |
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | 307 |
1 file changed, 307 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc new file mode 100644 index 0000000000..0bce183c18 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -0,0 +1,307 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +void ChangeArrayDataType(GraphTransformation* transformation, Array* array, + ArrayDataType new_data_type, + const MinMax* new_minmax) { + // Ensure the array ends up in the new type (if it hasn't yet been quantized). + array->final_data_type = new_data_type; + + if (array->minmax && array->quantization_params) { + // The array is already quantized and has min/max info. + // As we are changing the data type we need to fix up the existing min/max + // to the new data type range. 
+ + double old_quantized_min, old_quantized_max; + CHECK(GetQuantizedDataTypeNumericalRange( + array->data_type, &old_quantized_min, &old_quantized_max)) + << "Existing data type is not quantized: " + << ArrayDataTypeName(array->data_type); + double new_quantized_min, new_quantized_max; + CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min, + &new_quantized_max)) + << "New data type is not quantized: " + << ArrayDataTypeName(new_data_type); + + // Compute new minmax values. + double min = (old_quantized_min - array->quantization_params->zero_point) * + array->quantization_params->scale; + double max = + (old_quantized_max + 1 - array->quantization_params->zero_point) * + array->quantization_params->scale; + max = max - 1.0 / (new_quantized_max + 1); + + auto& array_minmax = array->GetOrCreateMinMax(); + transformation->AddMessageF( + "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min, + array_minmax.max, ArrayDataTypeName(array->data_type), min, max, + ArrayDataTypeName(new_data_type)); + + array_minmax.min = min; + array_minmax.max = max; + GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>( + array_minmax, array->quantization_params.get()); + + // Directly change the type as the array was already quantized. + array->data_type = new_data_type; + } else { + // Array has not yet been quantized so we can just set the final data type + // and assign the new min/max value (if provided). + CHECK(!array->quantization_params); + + if (!array->minmax && new_minmax) { + transformation->AddMessageF("Forcing new minmax to %g,%g (%s)", + new_minmax->min, new_minmax->max, + ArrayDataTypeName(new_data_type)); + auto& array_minmax = array->GetOrCreateMinMax(); + array_minmax.min = new_minmax->min; + array_minmax.max = new_minmax->max; + } + } +} + +// Returns true if the op blocks our backward recursive data type propagation. 
+bool DoesOpBlockBackwardPropagation(const Operator& op) { + switch (op.type) { + case OperatorType::kConcatenation: + case OperatorType::kTensorFlowConcat: + case OperatorType::kTensorFlowConcatV2: + // Concat shouldn't block propagation, but we do expect that all inputs + // have the same range. + return false; + case OperatorType::kDequantize: + // Dequantize ops are inserted between the value we care about and the + // FakeQuant so make sure we move across them. + case OperatorType::kGather: + // Gathers need their parameters changed to the appropriate data type. + case OperatorType::kTensorFlowReshape: + case OperatorType::kTranspose: + // Reshapes and transposes don't change values. + return false; + default: + return true; + } +} + +// Returns true if the input of an op blocks our backward recursive data type +// propagation. +bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) { + switch (op.type) { + case OperatorType::kGather: + // Ignore gather indices. + return input_index != 0; + break; + case OperatorType::kTensorFlowReshape: + case OperatorType::kTranspose: + // Ignore reshape/transpose shapes/dimensions. + return input_index != 0; + default: + return false; + } +} + +// Propagates the data type up into the input arrays if they are model inputs +// that may need their type changed. May act recursively if the inputs are +// produced by ops that we can move over (such as Dequantize). +bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation, + Model* model, Operator* op, + ArrayDataType new_data_type, + const MinMax& new_minmax) { + bool did_change = false; + for (int input_index = 0; input_index < op->inputs.size(); ++input_index) { + const auto& input = op->inputs[input_index]; + auto& input_array = model->GetArray(input); + if (input_array.final_data_type == new_data_type) { + // Final data type is already - skip. 
+ continue; + } + + // Prevent moving into constant param args that we don't want to modify. + if (DoesOpInputBlockBackwardPropagation(*op, input_index)) { + continue; + } + + if (input_array.final_data_type != new_data_type) { + transformation->AddMessageF( + "Adjusting input final data type of array %s from %s to %s", input, + ArrayDataTypeName(input_array.final_data_type), + ArrayDataTypeName(new_data_type)); + did_change = true; + ChangeArrayDataType(transformation, &input_array, new_data_type, + &new_minmax); + + // Walk up into all ops producing the inputs to this op. + for (auto& producing_op : model->operators) { + if (!DoesOpBlockBackwardPropagation(*producing_op)) { + for (const auto& output : producing_op->outputs) { + if (input == output) { + did_change |= RecursivelyBackwardPropagateDataType( + transformation, model, producing_op.get(), new_data_type, + new_minmax); + } + } + } + } + } + } + return did_change; +} + +// Returns true if the op blocks our forward recursive data type propagation. +bool DoesOpBlockForwardPropagation(const Operator& op) { + switch (op.type) { + case OperatorType::kFakeQuant: + // Always stop at another FakeQuant, as it will likely have different + // parameters. + return true; + default: + return false; + } +} + +// Recurses down the graph setting the data type of all arrays until an operator +// that blocks propagation (like another FakeQuant) or a final_data_type is +// already specified. +bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation, + Model* model, Operator* op, + ArrayDataType new_data_type) { + bool did_change = false; + for (const auto& output : op->outputs) { + auto& output_array = model->GetArray(output); + if (output_array.final_data_type == new_data_type) { + // Final data type is already - skip. 
+ continue; + } + + if (output_array.final_data_type == ArrayDataType::kNone || + output_array.final_data_type != new_data_type) { + transformation->AddMessageF( + "Adjusting output final data type of array %s from %s to %s", output, + ArrayDataTypeName(output_array.final_data_type), + ArrayDataTypeName(new_data_type)); + did_change = true; + ChangeArrayDataType(transformation, &output_array, new_data_type, + nullptr); + + // Walk down into all ops consuming the output of this op. + for (auto& consuming_op : model->operators) { + if (!DoesOpBlockForwardPropagation(*consuming_op)) { + for (const auto& input : consuming_op->inputs) { + if (input == output) { + did_change |= RecursivelyForwardPropagateDataType( + transformation, model, consuming_op.get(), new_data_type); + } + } + } + } + } + } + return did_change; +} + +} // namespace + +// Propagates the num_bits on a FakeQuant operator into the final data types +// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know +// the output must be int16 and assume all inputs up until the preceding op are +// also 16. +// +// This can be thought of as a bidirectional flood-fill of the num_bits implied +// final_data_type that terminates at other FakeQuant ops (and a few others as +// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation). +// Once all FakeQuant ops have been visted the arrays should all have +// appropriate final_data_types if the source graph was annotated with the +// proper FakeQuant ops. 
+// +// Annotating a graph requires following a few hard rules: +// - every input MUST have a FakeQuant immediately following it +// - every output MUST have a FakeQuant immediately preceding it +// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant +// immediately following it +// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc) +// MUST have FakeQuants between them and the consuming op +// Additional FakeQuants may be used if desired, especially in areas that may +// suffer from large precision changes - such as between a Softmax and a +// FullyConnected. Only by validating accuracy differences between float +// inference with the FakeQuant ops simulating quantization and the actually +// quantized graph can you be sure the appropriate FakeQuant ops are present. +// +// You can tell if you're missing some FakeQuants by looking for warnings from +// quantize.cc about minmax ranges being determined by the contents of constant +// arrays. This will almost never produce functional models during inference. +// +// As this op may change the data types and ranges of input and output arrays +// downstream tools must also be sure to parse the output model flags to get the +// post-Transform values that may have changed due to this transformation. +// +// This isn't a GraphTransformation in the traditional respect as it affects ops +// outside of the one under transformation. This is primarily so that we can +// utilize the graph traversal and repeated pass system underlying the +// transformation system to exhaustively find all FakeQuant ops. It also gets us +// nice logging and integration with the graphviz video dumping mode. +// In general you should not copy this style of transformation and stick to +// local-only changes as seen in the other transformations. 
+bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) { + auto it = model->operators.begin() + op_index; + auto* op = it->get(); + if (op->type != OperatorType::kFakeQuant) { + return false; + } + auto* fakequant_op = static_cast<FakeQuantOperator*>(op); + + ArrayDataType quantized_data_type = ArrayDataType::kNone; + if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op, + &quantized_data_type)) { + AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring", + LogName(*op), fakequant_op->num_bits); + return false; + } + const auto& final_minmax = *fakequant_op->minmax; + + AddMessageF( + "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s", + LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max, + ArrayDataTypeName(quantized_data_type)); + + bool did_change = false; + + // Propagate the FakeQuant information backward up the graph. + // This will possibly adjust input arrays or constant types (like Gather). + did_change |= RecursivelyBackwardPropagateDataType( + this, model, op, quantized_data_type, final_minmax); + + // Propagate the FakeQuant information forward down the graph. + // This will possibly adjust output arrays. + did_change |= + RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type); + + return did_change; +} + +} // namespace toco |