aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/contrib/lite/toco/BUILD 1
-rw-r--r-- tensorflow/contrib/lite/toco/args.h 2
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h 18
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc 86
-rw-r--r-- tensorflow/contrib/lite/toco/toco_cmdline_flags.cc 16
-rw-r--r-- tensorflow/contrib/lite/toco/toco_flags.proto 8
-rw-r--r-- tensorflow/contrib/lite/toco/toco_tooling.cc 35
-rw-r--r-- tensorflow/contrib/lite/toco/tooling_util.cc 22
-rw-r--r-- tensorflow/contrib/lite/toco/tooling_util.h 2
9 files changed, 155 insertions, 35 deletions
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 398978b145..f696f4b845 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -238,6 +238,7 @@ cc_library(
"graph_transformations/merge_reshape_into_preceding_transpose.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
+ "graph_transformations/propagate_default_min_max.cc",
"graph_transformations/propagate_fake_quant_num_bits.cc",
"graph_transformations/propagate_fixed_sizes.cc",
"graph_transformations/quantization_util.cc",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 71e7318ac3..c9662d05ce 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -227,6 +227,8 @@ struct ParsedTocoFlags {
// TODO(aselle): command_line_flags doesn't support doubles
Arg<float> default_ranges_min = Arg<float>(0.);
Arg<float> default_ranges_max = Arg<float>(0.);
+ Arg<float> default_int16_ranges_min = Arg<float>(0.);
+ Arg<float> default_int16_ranges_max = Arg<float>(0.);
Arg<string> inference_type;
Arg<string> inference_input_type;
Arg<bool> drop_fake_quant = Arg<bool>(false);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index 56b3dec5c4..8075d0205d 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -190,6 +190,24 @@ DECLARE_GRAPH_TRANSFORMATION(Dequantize)
DECLARE_GRAPH_TRANSFORMATION(UnpartitionEmbeddingLookup)
DECLARE_GRAPH_TRANSFORMATION(ExperimentalShuffleFCWeights)
+class PropagateDefaultMinMax : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "PropagateDefaultMinMax"; }
+
+ bool has_any_ranges_defined() const { return !type_ranges_.empty(); }
+ void DefineTypeRange(ArrayDataType data_type, double min, double max) {
+ MinMax minmax;
+ minmax.min = min;
+ minmax.max = max;
+ type_ranges_.emplace_back(data_type, minmax);
+ }
+
+ private:
+ bool SetArrayMinMax(const string& array_name, Array* array);
+ std::vector<std::pair<ArrayDataType, MinMax>> type_ranges_;
+};
+
class ResolveReshapeAttributes : public GraphTransformation {
public:
bool Run(Model* model, std::size_t op_index) override;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
new file mode 100644
index 0000000000..50b90e7c2b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_default_min_max.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// Propagates default min/max values to any operator input/output array that
+// is missing them.
+//
+// The default applied to an array is the range registered (via
+// DefineTypeRange) for the data type the array will be quantized to; arrays
+// whose quantized data type has no registered range are left unchanged.
+bool PropagateDefaultMinMax::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+
+ bool did_change = false;
+
+ for (const auto& input : op->inputs) {
+ auto& input_array = model->GetArray(input);
+ if (!input_array.minmax && !input_array.buffer) {
+ did_change |= SetArrayMinMax(input, &input_array);
+ }
+ }
+
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (!output_array.minmax && !output_array.buffer) {
+ did_change |= SetArrayMinMax(output, &output_array);
+ }
+ }
+
+ return did_change;
+}
+
+// Sets the min/max on the given array from the default range registered for
+// its quantized data type; returns false if no such range was defined.
+bool PropagateDefaultMinMax::SetArrayMinMax(const string& array_name,
+ Array* array) {
+ CHECK(!array->minmax);
+
+ ArrayDataType quantized_data_type =
+ GetQuantizedDataType(*array, ArrayDataType::kUint8);
+ for (const auto& type_range : type_ranges_) {
+ if (type_range.first == quantized_data_type) {
+ array->GetOrCreateMinMax() = type_range.second;
+ break;
+ }
+ }
+ if (!array->minmax) {
+ AddMessageF(
+ "No defaults specified for quantized data type %s of array %s, "
+ "skipping",
+ ArrayDataTypeName(quantized_data_type), array_name);
+ return false;
+ }
+
+ AddMessageF("Adding default minmax %g,%g to array %s when quantized as %s",
+ array->GetMinMax().min, array->GetMinMax().max, array_name,
+ ArrayDataTypeName(quantized_data_type));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index d1d68b6b47..74f98c8452 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -61,11 +61,21 @@ bool ParseTocoFlagsFromCommandLineFlags(
Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
parsed_flags.default_ranges_min.default_value(),
"If defined, will be used as the default value for the min bound "
- "of min/max ranges used for quantization."),
+ "of min/max ranges used for quantization of uint8 arrays."),
Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
parsed_flags.default_ranges_max.default_value(),
"If defined, will be used as the default value for the max bound "
- "of min/max ranges used for quantization."),
+ "of min/max ranges used for quantization of uint8 arrays."),
+ Flag("default_int16_ranges_min",
+ parsed_flags.default_int16_ranges_min.bind(),
+ parsed_flags.default_int16_ranges_min.default_value(),
+ "If defined, will be used as the default value for the min bound "
+ "of min/max ranges used for quantization of int16 arrays."),
+ Flag("default_int16_ranges_max",
+ parsed_flags.default_int16_ranges_max.bind(),
+ parsed_flags.default_int16_ranges_max.default_value(),
+ "If defined, will be used as the default value for the max bound "
+ "of min/max ranges used for quantization of int16 arrays."),
Flag("inference_type", parsed_flags.inference_type.bind(),
parsed_flags.inference_type.default_value(),
"Target data type of arrays in the output file (for input_arrays, "
@@ -212,6 +222,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
+ READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
+ READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 751aca948c..869c512d93 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 15.
+// Next ID to use: 17.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -103,8 +103,14 @@ message TocoFlags {
// for experimentation purposes only and should not be used in production:
// they make it easy to quantize models, but the resulting quantized model
// will be inaccurate.
+ //
+ // These values only apply to arrays quantized with the kUint8 data type.
optional float default_ranges_min = 5;
optional float default_ranges_max = 6;
+ // Equivalent versions of default_ranges_min/_max for arrays quantized with
+ // the kInt16 data type.
+ optional float default_int16_ranges_min = 15;
+ optional float default_int16_ranges_max = 16;
// Ignore and discard FakeQuant nodes. For instance, that can be used to
// generate plain float code without fake-quantization from a quantized
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index b69852453c..89cb2f85f8 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <set>
+#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "tensorflow/contrib/lite/toco/allocate_transient_arrays.h"
#include "tensorflow/contrib/lite/toco/dump_graphviz.h"
@@ -270,10 +271,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
- // Fix any issues with IO edges. This must happen after any transform that
- // may modify the structure of the edges.
- FixEdgeArrays(model);
-
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
RunGraphTransformations(model,
@@ -287,16 +284,38 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
});
}
+ // Fix any issues with IO edges. This must happen after any transform that
+ // may modify the structure of the edges.
+ FixEdgeArrays(model);
+
if (quantize_output) {
+ // If the user specified default min/max ranges we need to apply them to all
+ // arrays that neither had a min/max specified nor received one via
+ // HardcodeMinMax or PropagateFakeQuantNumBits. This may require running
+ // HardcodeMinMax to move changes through the graph as we make changes.
+ auto propagate_default_min_max =
+ absl::make_unique<PropagateDefaultMinMax>();
if (toco_flags.has_default_ranges_min() &&
toco_flags.has_default_ranges_max()) {
- UseDefaultMinMaxRangeValues(model, toco_flags.default_ranges_min(),
- toco_flags.default_ranges_max());
- // The new MinMax info may need to be propagated a bit.
+ propagate_default_min_max->DefineTypeRange(
+ ArrayDataType::kUint8, toco_flags.default_ranges_min(),
+ toco_flags.default_ranges_max());
+ }
+ if (toco_flags.has_default_int16_ranges_min() &&
+ toco_flags.has_default_int16_ranges_max()) {
+ propagate_default_min_max->DefineTypeRange(
+ ArrayDataType::kInt16, toco_flags.default_int16_ranges_min(),
+ toco_flags.default_int16_ranges_max());
+ }
+ if (propagate_default_min_max->has_any_ranges_defined()) {
RunGraphTransformations(
model, "default min-max range propagation graph transformations",
- {new HardcodeMinMax});
+ {
+ propagate_default_min_max.release(),
+ new HardcodeMinMax,
+ });
}
+
CheckIsReadyForQuantization(*model);
RunGraphTransformations(model, "quantization graph transformations",
{
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index ecac0c28a5..cf2cbeedc7 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -1474,28 +1474,6 @@ void CheckIsReadyForQuantization(const Model& model) {
}
}
-void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
- double default_ranges_max) {
- for (const auto& op : model->operators) {
- for (const auto& input : op->inputs) {
- auto& input_array = model->GetArray(input);
- if (!input_array.minmax && !input_array.buffer) {
- auto& minmax = input_array.GetOrCreateMinMax();
- minmax.min = default_ranges_min;
- minmax.max = default_ranges_max;
- }
- }
- for (const auto& output : op->outputs) {
- auto& output_array = model->GetArray(output);
- if (!output_array.minmax && !output_array.buffer) {
- auto& minmax = output_array.GetOrCreateMinMax();
- minmax.min = default_ranges_min;
- minmax.max = default_ranges_max;
- }
- }
- }
-}
-
int ElementSize(ArrayDataType data_type) {
switch (data_type) {
case ArrayDataType::kBool:
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index 4c705f4e5f..5cc15fa57b 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -188,8 +188,6 @@ T ConvertOperator(Operator* o, OperatorType type) {
}
void CheckIsReadyForQuantization(const Model& model);
-void UseDefaultMinMaxRangeValues(Model* model, double default_ranges_min,
- double default_ranges_max);
bool ReshapeIsEquivalentToTranspose(const Model& model,
const TensorFlowReshapeOperator* op,