diff options
-rw-r--r-- | tensorflow/contrib/lite/toco/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/args.h | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 18 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc | 86 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 16 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_flags.proto | 8 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/toco_tooling.cc | 35 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 22 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.h | 2 |
9 files changed, 155 insertions, 35 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 398978b145..f696f4b845 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -238,6 +238,7 @@ cc_library( "graph_transformations/merge_reshape_into_preceding_transpose.cc", "graph_transformations/propagate_activation_function_into_constants.cc", "graph_transformations/propagate_array_data_types.cc", + "graph_transformations/propagate_default_min_max.cc", "graph_transformations/propagate_fake_quant_num_bits.cc", "graph_transformations/propagate_fixed_sizes.cc", "graph_transformations/quantization_util.cc", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index 71e7318ac3..c9662d05ce 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -227,6 +227,8 @@ struct ParsedTocoFlags { // TODO(aselle): command_line_flags doesn't support doubles Arg<float> default_ranges_min = Arg<float>(0.); Arg<float> default_ranges_max = Arg<float>(0.); + Arg<float> default_int16_ranges_min = Arg<float>(0.); + Arg<float> default_int16_ranges_max = Arg<float>(0.); Arg<string> inference_type; Arg<string> inference_input_type; Arg<bool> drop_fake_quant = Arg<bool>(false); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 56b3dec5c4..8075d0205d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -190,6 +190,24 @@ DECLARE_GRAPH_TRANSFORMATION(Dequantize) DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup) DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights) +class PropagateDefaultMinMax : public GraphTransformation { + public: + bool Run(Model* model, std::size_t op_index) override; + const char* Name() const override { return "PropagateDefaultMinMax"; } + + bool has_any_ranges_defined() const { return !type_ranges_.empty(); } + void DefineTypeRange(ArrayDataType data_type, double min, double max) { + MinMax minmax; + minmax.min = min; + minmax.max = max; + type_ranges_.emplace_back(data_type, minmax); + } + + private: + bool SetArrayMinMax(const string& array_name, Array* array); + std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_; +}; + class ResolveReshapeAttributes : public GraphTransformation { public: bool Run(Model* model, std::size_t op_index) override; diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc new file mode 100644 index 0000000000..50b90e7c2b --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// Propagates default min/max values to any operator input/output array that +// is missing them. +// +// When provided a set of min/max values for uint8 arrays this will rescale +// the values for other data types as required and preserving the floating point +// range within the new type. +bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* op = it->get(); + + bool did_change = false; + + for (const auto& input : op->inputs) { + auto& input_array = model->GetArray(input); + if (!input_array.minmax && !input_array.buffer) { + did_change |= SetArrayMinMax(input, &input_array); + } + } + + for (const auto& output : op->outputs) { + auto& output_array = model->GetArray(output); + if (!output_array.minmax && !output_array.buffer) { + did_change |= SetArrayMinMax(output, &output_array); + } + } + + return did_change; +} + +// Sets the min/max on the given array, adjusting the reference_minmax for the +// final data type of the array if it is already specified. +bool PropagateDefaultMinMax::SetArrayMinMax(const string& array_name, + Array* array) { + CHECK(!array->minmax); + + ArrayDataType quantized_data_type = + GetQuantizedDataType(*array, ArrayDataType::kUint8); + for (const auto& type_range : type_ranges_) { + if (type_range.first == quantized_data_type) { + array->GetOrCreateMinMax() = type_range.second; + break; + } + } + if (!array->minmax) { + AddMessageF( + "No defaults specified for quantized data type %s of array %s, " + "skipping", + ArrayDataTypeName(quantized_data_type), array_name); + return false; + } + + AddMessageF("Adding default minmax %g,%g to array %s when quantized as %s", + array->GetMinMax().min, array->GetMinMax().max, array_name, + ArrayDataTypeName(quantized_data_type)); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index d1d68b6b47..74f98c8452 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -61,11 +61,21 @@ bool ParseTocoFlagsFromCommandLineFlags( Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(), parsed_flags.default_ranges_min.default_value(), "If defined, will be used as the default value for the min bound " - "of min/max ranges used for quantization."), + "of min/max ranges used for quantization of uint8 arrays."), Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(), parsed_flags.default_ranges_max.default_value(), "If defined, will be used as the default value for the max bound " - "of min/max ranges used for quantization."), + "of min/max ranges used for quantization of uint8 arrays."), + Flag("default_int16_ranges_min", + parsed_flags.default_int16_ranges_min.bind(), + parsed_flags.default_int16_ranges_min.default_value(), + "If defined, will be used as the default value for the min bound " + "of min/max ranges used for quantization of int16 arrays."), + Flag("default_int16_ranges_max", + parsed_flags.default_int16_ranges_max.bind(), + parsed_flags.default_int16_ranges_max.default_value(), + "If defined, will be used as the default value for the max bound " + "of min/max ranges used for quantization of int16 arrays."), Flag("inference_type", parsed_flags.inference_type.bind(), parsed_flags.inference_type.default_value(), "Target data type of arrays in the output file (for input_arrays, " @@ -212,6 +222,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags, PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone); READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone); READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone); + READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone); + READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone); READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone); READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone); READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone); diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 751aca948c..869c512d93 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -37,7 +37,7 @@ enum FileFormat { // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. // -// Next ID to use: 15. +// Next ID to use: 17. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -103,8 +103,14 @@ message TocoFlags { // for experimentation purposes only and should not be used in production: // they make it easy to quantize models, but the resulting quantized model // will be inaccurate. + // + // These values only apply to arrays quantized with the kUint8 data type. optional float default_ranges_min = 5; optional float default_ranges_max = 6; + // Equivalent versions of default_ranges_min/_max for arrays quantized with + // the kInt16 data type. + optional float default_int16_ranges_min = 15; + optional float default_int16_ranges_max = 16; // Ignore and discard FakeQuant nodes. For instance, that can be used to // generate plain float code without fake-quantization from a quantized diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index b69852453c..89cb2f85f8 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <set> +#include "absl/memory/memory.h" #include "absl/strings/str_join.h" #include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" @@ -270,10 +271,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) { RunGraphTransformations(model, "general graph transformations", transformations); - // Fix any issues with IO edges. This must happen after any transform that - // may modify the structure of the edges. - FixEdgeArrays(model); - if (quantize_output) { if (toco_flags.propagate_fake_quant_num_bits()) { RunGraphTransformations(model, @@ -287,16 +284,38 @@ void Transform(const TocoFlags& toco_flags, Model* model) { }); } + // Fix any issues with IO edges. This must happen after any transform that + // may modify the structure of the edges. + FixEdgeArrays(model); + if (quantize_output) { + // If the user specified default min/max ranges we need to set all arrays + // that didn't either have a min/max specified or get one set via + // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running + // HardcodeMinMax to move changes through the graph as we make changes. + auto propagate_default_min_max = + absl::make_unique<PropagateDefaultMinMax>(); if (toco_flags.has_default_ranges_min() && toco_flags.has_default_ranges_max()) { - UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(), - toco_flags.default_ranges_max()); - // The new MinMax info may need to be propagated a bit. + propagate_default_min_max->DefineTypeRange( + ArrayDataType::kUint8, toco_flags.default_ranges_min(), + toco_flags.default_ranges_max()); + } + if (toco_flags.has_default_int16_ranges_min() && + toco_flags.has_default_int16_ranges_max()) { + propagate_default_min_max->DefineTypeRange( + ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(), + toco_flags.default_int16_ranges_max()); + } + if (propagate_default_min_max->has_any_ranges_defined()) { RunGraphTransformations( model, "default min-max range propagation graph transformations", - {new HardcodeMinMax}); + { + propagate_default_min_max.release(), + new HardcodeMinMax, + }); } + CheckIsReadyForQuantization(*model); RunGraphTransformations(model, "quantization graph transformations", { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index ecac0c28a5..cf2cbeedc7 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -1474,28 +1474,6 @@ void CheckIsReadyForQuantization(const Model& model) { } } -void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, - double default_ranges_max) { - for (const auto& op : model->operators) { - for (const auto& input : op->inputs) { - auto& input_array = model->GetArray(input); - if (!input_array.minmax && !input_array.buffer) { - auto& minmax = input_array.GetOrCreateMinMax(); - minmax.min = default_ranges_min; - minmax.max = default_ranges_max; - } - } - for (const auto& output : op->outputs) { - auto& output_array = model->GetArray(output); - if (!output_array.minmax && !output_array.buffer) { - auto& minmax = output_array.GetOrCreateMinMax(); - minmax.min = default_ranges_min; - minmax.max = default_ranges_max; - } - } - } -} - int ElementSize(ArrayDataType data_type) { switch (data_type) { case ArrayDataType::kBool: diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 4c705f4e5f..5cc15fa57b 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -188,8 +188,6 @@ T ConvertOperator(Operator* o, OperatorType type) { } void CheckIsReadyForQuantization(const Model& model); -void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min, - double default_ranges_max); bool ReshapeIsEquivalentToTranspose(const Model& model, const TensorFlowReshapeOperator* op, |