/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/util/command_line_flags.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // What depth of inputs the op can read in.
  DataType input_bit_depth;
  // The depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process.
  std::set<int> unquantized_inputs;
  // How the min/max inputs are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};

// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
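// For example, the "Conv2D" entry below reads roughly as:
//
//   {"Conv2D",                         // float_name
//    {"strides", "padding"},           // attrs_to_copy
//    {{"Tinput", DT_QUINT8},           // dtypes_to_set
//     {"Tfilter", DT_QUINT8},
//     {"out_type", DT_QINT32}},
//    DT_QUINT8,                        // input_bit_depth
//    DT_QINT32,                        // output_bit_depth
//    {},                               // unquantized_inputs
//    QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
//
// meaning a float Conv2D becomes a QuantizedConv2D that consumes quint8
// tensors and produces qint32 results, with strides and padding copied over
// unchanged.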
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add", {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8, DT_QINT32, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool", {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd", {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8, DT_QINT32, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Concat", {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {0}, QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D", {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8, DT_QINT32, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul", {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8, DT_QINT32, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool", {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul", {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8, DT_QINT32, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu", {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"ResizeBilinear", {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {1}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6", {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Reshape", {},
       {{"T", DT_QUINT8}},
       DT_QUINT8, DT_QUINT8, {1}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}

namespace {

// Replaces invalid characters in input names to get a unique node name.
string UniqueNodeNameFromInput(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  string result;
  if (prefix == "^") {
    result += "__hat__";
  }
  result += node_name;
  if (!suffix.empty()) {
    result += "__port__" + suffix.substr(1, suffix.size() - 1);
  }
  return result;
}

// Pulls two float values from the named parameters, with a lot of checking.
Status ExtractRangeFromParams(const TransformFuncContext& context,
                              const string& min_name, const string& max_name,
                              float* min_value, float* max_value,
                              bool* has_range) {
  // See if we've been given quantized inputs with a known range.
  const bool has_min = (context.params.count(min_name) != 0);
  const bool has_max = (context.params.count(max_name) != 0);
  *has_range = (has_min || has_max);
  if (!*has_range) {
    return Status::OK();
  }
  if (!has_min || !has_max) {
    return errors::InvalidArgument("You must pass both ", min_name, " and ",
                                   max_name, " into quantize_nodes");
  }
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
  return Status::OK();
}

}  // namespace

// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but
// can also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that
// avoids the need for any O(n^2) algorithms when identifying duplicates.
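// For example, if two Const nodes "weights_a" and "weights_b" hold identical
// tensors, their NodeDefs hash to the same value once their names are
// cleared, so the second one can be dropped and every reference to
// "weights_b:*" rewritten to point at "weights_a". Nodes named as graph
// inputs or outputs, and stateful ops, are never merged away.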
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) &&
          !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const std::pair<uint64, std::vector<const NodeDef*>>&
             hashed_node_info : hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;
  return Status::OK();
}

// Looks for the patterns that indicate there are two eight-bit ops feeding
// into each other, separated by a conversion up to float and back again.
// These occur during the initial conversion of ops to their quantized forms.
// Because we're only looking at an individual op in that phase and don't know
// if its inputs and outputs are eight-bit-capable, we start by converting the
// actual op into quantized form, but add float conversions before and after.
// This pass gets rid of those conversions if it turns out we do have adjacent
// ops capable of eight-bit processing.
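// Concretely, wherever a QuantizeV2 is fed by a Dequantize (plus its Min/Max
// range calculation), every consumer of the QuantizeV2's three outputs is
// rewired to read the Dequantize's own inputs, i.e. the quantized tensor and
// min/max produced by the preceding eight-bit op, skipping the round trip
// through float. The Dequantize itself is only kept if some other part of the
// graph, or a graph output, still needs the float version.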
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);
        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }
        return Status::OK();
      },
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}

// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      NodeDef quantized_placeholder;
      quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      const string rename_suffix = "__RENAMED_PLACEHOLDER__";

      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return Status::OK();
}

// During training, FakeQuantWithMinMaxVars ops capture a good min/max range
// for an activation layer. To use these during inference, this pass converts
// those ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}

// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        NodeDef requantize_node;
        requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}

// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu or MaxPool, several steps away from the Conv2D or
// BiasAdd op that we want to apply the trained constant conversions to. This
// pass tries to move FakeQuant ops up the input chain, so they're as close as
// possible to the 32-bit conversion, and so can be easily merged into the
// automatic dynamic Requantizes.
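// For example, a FakeQuantWithMinMaxVars that sits after a
// BiasAdd -> Relu -> MaxPool chain can be moved so it reads the BiasAdd
// output directly (hoisting across up to three intermediate ops), while the
// op that originally fed the FakeQuant takes over its name, so downstream
// consumers are unaffected.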
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          NodeDef new_fake_quant_node;
          new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);
          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);
          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }
          return Status::OK();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;
  return Status::OK();
}

// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the
// lower-precision format.
Status QuantizeNodes(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Loop through all of the quantizable op types, and replace any occurrences
  // with equivalent sub-graphs with quantized ops at their core. For example
  // this one-input operation:
  //
  //     Input(float)
  //          |
  //          v
  //      Operation
  //          |
  //          v
  //       (float)
  //
  // will be turned into its quantized equivalent:
  //
  //   Input(float) --------------------------------+
  //        |                                       |
  //        v                                       |
  //     Reshape   (using ReshapeDims)              |
  //        |                                       |
  //        +--> Min (using ReductionDims) ---+     |
  //        |                                 |     |
  //        +--> Max (using ReductionDims) ---+     |
  //                                          |     |
  //                                          v     v
  //                                         Quantize
  //                                             |
  //                                             v
  //                                   QuantizedOperation
  //                                        |    |    |
  //                                        v    v    v
  //                                        Dequantize
  //                                             |
  //                                             v
  //                                          (float)
  //
  // This keeps the inputs and outputs visible to the rest of the graph in
  // float, and converts them down to quantized buffers internally for the
  // computation. The result will end up with a lot of redundant
  // dequantize/quantize pairs between adjacent quantized ops, but a later
  // pass removes these where it can.

  std::set<string> ops_to_ignore;
  if (context.params.count("ignore_op") > 0) {
    for (const string& name : context.params.at("ignore_op")) {
      ops_to_ignore.insert(name);
    }
  }

  const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
  string op_pattern;
  bool is_first = true;
  std::map<string, QuantizedOpInfo> op_map;
  for (const QuantizedOpInfo& op_info : op_list) {
    if (ops_to_ignore.count(op_info.float_name) == 0) {
      strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
                         op_info.float_name);
      op_map.insert({op_info.float_name, op_info});
      is_first = false;
    }
  }

  // If input_min and input_max have been passed in, then we convert all float
  // Placeholder nodes into quantized versions, with the supplied values as
  // their range.
  GraphDef placeholder_graph_def;
  TF_RETURN_IF_ERROR(
      QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));

  // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
  // operations like Relu or MaxPool, move them up so that they're as close as
  // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
  GraphDef hoisted_graph_def;
  TF_RETURN_IF_ERROR(
      HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));

  // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
  // activation layers, into Requantize ops with those ranges instead. This
  // makes it easier to replace the dynamic range calculations that are used
  // by default.
  GraphDef converted_graph_def;
  TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
                                                   &converted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));

  // If fallback_min and fallback_max are set, then we'll use hardwired ranges
  // for all the 32-bit to 8-bit requantizations.
  float fallback_min;
  float fallback_max;
  bool has_fallback_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(
      context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
      &has_fallback_range));

  // Replace all occurrences of the current float op with its quantized
  // equivalent.
  GraphDef quantized_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      converted_graph_def, {op_pattern},
      [&op_map, fallback_min, fallback_max, has_fallback_range](
          const NodeMatch& match, const std::set<string>& input_nodes,
          const std::set<string>& output_nodes,
          std::vector<NodeDef>* new_nodes) {
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return Status::OK();
        }

        string namespace_prefix = float_node.name() + "_eightbit";

        // Quantize all of the inputs.
        std::vector<string> quantized_input_names;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          const string& input_name = float_node.input(i);
          string unique_input_name =
              namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);

          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
          reshape_dims.set_op("Const");
          reshape_dims.set_name(unique_input_name + "/reshape_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
          SetNodeAttr("dtype", DT_INT32, &reshape_dims);
          Tensor reshape_dims_tensor(DT_INT32, {1});
          reshape_dims_tensor.flat<int32>()(0) = -1;
          SetNodeTensorAttr<int32>("value", reshape_dims_tensor,
                                   &reshape_dims);
          new_nodes->push_back(reshape_dims);

          NodeDef reduction_dims;
          reduction_dims.set_op("Const");
          reduction_dims.set_name(unique_input_name + "/reduction_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
          SetNodeAttr("dtype", DT_INT32, &reduction_dims);
          Tensor reduction_dims_tensor(DT_INT32, {1});
          reduction_dims_tensor.flat<int32>()(0) = 0;
          SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
                                   &reduction_dims);
          new_nodes->push_back(reduction_dims);

          NodeDef reshape_node;
          reshape_node.set_op("Reshape");
          reshape_node.set_name(unique_input_name + "/reshape");
          SetNodeAttr("T", DT_FLOAT, &reshape_node);
          AddNodeInput(input_name, &reshape_node);
          AddNodeInput(reshape_dims.name(), &reshape_node);
          new_nodes->push_back(reshape_node);

          NodeDef min_node;
          min_node.set_op("Min");
          min_node.set_name(unique_input_name + "/min");
          SetNodeAttr("T", DT_FLOAT, &min_node);
          SetNodeAttr("keep_dims", false, &min_node);
          AddNodeInput(reshape_node.name(), &min_node);
          AddNodeInput(reduction_dims.name(), &min_node);
          new_nodes->push_back(min_node);

          NodeDef max_node;
          max_node.set_op("Max");
          max_node.set_name(unique_input_name + "/max");
          SetNodeAttr("T", DT_FLOAT, &max_node);
          SetNodeAttr("keep_dims", false, &max_node);
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);
          new_nodes->push_back(max_node);

          NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(unique_input_name + "/quantize");
          SetNodeAttr("T", DT_QUINT8, &quantize_node);
          SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
          AddNodeInput(input_name, &quantize_node);
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);
          new_nodes->push_back(quantize_node);
          quantized_input_names.push_back(quantize_node.name());
        }

        // Set up the quantized version of the current op.
        NodeDef quantized_main_node;
        quantized_main_node.set_op("Quantized" + float_node.op());
        quantized_main_node.set_name(float_node.name() + "/eightbit");
        for (const string& attr_to_copy : op_info.attrs_to_copy) {
          CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
                       &quantized_main_node);
        }
        for (const std::pair<string, DataType>& dtype_to_set :
             op_info.dtypes_to_set) {
          SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
                      &quantized_main_node);
        }
        int quantized_input_index = 0;
        for (int i = 0; i < float_node.input_size(); ++i) {
          if (op_info.unquantized_inputs.count(i)) {
            AddNodeInput(float_node.input(i), &quantized_main_node);
          } else {
            const string& quantized_input_name =
                quantized_input_names[quantized_input_index];
            AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
            ++quantized_input_index;
          }
        }
        if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        } else {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
          }
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        }
        new_nodes->push_back(quantized_main_node);

        string eight_bit_node_name;
        if (op_info.output_bit_depth == DT_QINT32) {
          // Shrink the range of the output down from 32 bits to 8.
          string requantize_min_input;
          string requantize_max_input;
          if (has_fallback_range) {
            // Use constant values for the min/max range if they were given.
            NodeDef fallback_min_node;
            fallback_min_node.set_op("Const");
            fallback_min_node.set_name(quantized_main_node.name() +
                                       "/fallback_min");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
            Tensor fallback_min_tensor(DT_FLOAT, {});
            fallback_min_tensor.flat<float>()(0) = fallback_min;
            SetNodeTensorAttr<float>("value", fallback_min_tensor,
                                     &fallback_min_node);
            new_nodes->push_back(fallback_min_node);

            NodeDef fallback_max_node;
            fallback_max_node.set_op("Const");
            fallback_max_node.set_name(quantized_main_node.name() +
                                       "/fallback_max");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
            Tensor fallback_max_tensor(DT_FLOAT, {});
            fallback_max_tensor.flat<float>()(0) = fallback_max;
            SetNodeTensorAttr<float>("value", fallback_max_tensor,
                                     &fallback_max_node);
            new_nodes->push_back(fallback_max_node);

            requantize_min_input = fallback_min_node.name();
            requantize_max_input = fallback_max_node.name();
          } else {
            // Otherwise dynamically measure the range each time.
            NodeDef requant_range_node;
            requant_range_node.set_op("RequantizationRange");
            requant_range_node.set_name(quantized_main_node.name() +
                                        "/requant_range");
            SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":0",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":1",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":2",
                         &requant_range_node);
            new_nodes->push_back(requant_range_node);

            requantize_min_input = requant_range_node.name() + ":0";
            requantize_max_input = requant_range_node.name() + ":1";
          }

          NodeDef requantize_node;
          requantize_node.set_op("Requantize");
          requantize_node.set_name(quantized_main_node.name() + "/requantize");
          SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
          SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
          AddNodeInput(requantize_min_input, &requantize_node);
          AddNodeInput(requantize_max_input, &requantize_node);
          new_nodes->push_back(requantize_node);

          eight_bit_node_name = requantize_node.name();
        } else {
          eight_bit_node_name = quantized_main_node.name();
        }

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(float_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, &quantized_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));

  // If we've ended up with two Requantize ops in a row (for example if there
  // was a Conv2D feeding into a FakeQuantWithMinMaxVars), merge them together,
  // using the trained range from the second op.
  GraphDef merged_graph_def;
  TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
                                              &merged_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));

  // There can be duplicate quantize nodes if multiple ops pull from a single
  // input, which makes it harder to remove redundant ones, so strip them out.
  GraphDef deduped_graph_def;
  TF_RETURN_IF_ERROR(
      MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));

  // Look for Dequantizes that immediately go into Quantizes, and remove them
  // since the two together cancel each other out. This allows us to keep the
  // data flow in eight bit where two adjacent ops are in eight bit, but still
  // keep interoperability with float ops.
  TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
                                                  output_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);

REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);

}  // namespace graph_transforms
}  // namespace tensorflow
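
// Example invocation (a sketch, assuming the graph_transforms transform_graph
// tool and a frozen GraphDef; the paths and node names here are illustrative):
//
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=frozen_graph.pb \
//     --out_graph=quantized_graph.pb \
//     --inputs=input \
//     --outputs=softmax \
//     --transforms='quantize_nodes(input_min=-1, input_max=1,
//                                  fallback_min=-10, fallback_max=10)'
//
// The input_min/input_max and fallback_min/fallback_max parameters are
// optional and map onto the ranges read by QuantizePlaceholders and the
// fallback requantization range above; ignore_op can be passed to skip
// quantizing particular op types.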