author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-17 11:53:29 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-17 11:55:26 -0700
commit     d7b6cb66c0fc346cf55020042931c07208713c60
tree       9024111ebf15d12a631ffd7e176b9da7459dd5a0
parent     1192c1662c5c98f55805450b4619ac2bc9c6908c
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
 tensorflow/contrib/lite/toco/BUILD | 5
 tensorflow/contrib/lite/toco/args.h | 1
 tensorflow/contrib/lite/toco/dump_graphviz.cc | 12
 tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc | 2
 tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h | 20
 tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc | 1
 tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | 307
 tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc | 88
 tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h | 25
 tensorflow/contrib/lite/toco/graph_transformations/quantize.cc | 139
 tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc | 86
 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc | 25
 tensorflow/contrib/lite/toco/toco_cmdline_flags.cc | 7
 tensorflow/contrib/lite/toco/toco_flags.proto | 11
 tensorflow/contrib/lite/toco/toco_tooling.cc | 26
 tensorflow/contrib/lite/toco/tooling_util.cc | 73
 tensorflow/contrib/lite/toco/tooling_util.h | 18
17 files changed, 702 insertions, 144 deletions
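
For orientation, the new PropagateFakeQuantNumBits pass keys everything off the FakeQuant num_bits attribute: up to 8 bits maps to uint8 storage, up to 16 bits to int16. The standalone sketch below (not part of this commit) mirrors the mapping that InferQuantizedDataTypeFromFakeQuant implements in quantization_util.cc further down.

    #include <cstdio>
    #include <initializer_list>

    // Illustration only -- not toco code. Mirrors the num_bits -> quantized
    // storage type decision added as InferQuantizedDataTypeFromFakeQuant.
    enum class QuantizedType { kNone, kUint8, kInt16 };

    QuantizedType TypeForNumBits(int num_bits) {
      if (num_bits <= 8) return QuantizedType::kUint8;   // 1..8 bits -> uint8
      if (num_bits <= 16) return QuantizedType::kInt16;  // 9..16 bits -> int16
      return QuantizedType::kNone;                       // wider widths are rejected
    }

    int main() {
      for (int bits : {8, 12, 16, 32}) {
        const char* name = "none (unsupported)";
        switch (TypeForNumBits(bits)) {
          case QuantizedType::kUint8: name = "uint8"; break;
          case QuantizedType::kInt16: name = "int16"; break;
          default: break;
        }
        std::printf("num_bits=%d -> %s\n", bits, name);
      }
      return 0;
    }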
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 5b86e4e5ae..398978b145 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -238,6 +238,7 @@ cc_library(
"graph_transformations/merge_reshape_into_preceding_transpose.cc",
"graph_transformations/propagate_activation_function_into_constants.cc",
"graph_transformations/propagate_array_data_types.cc",
+ "graph_transformations/propagate_fake_quant_num_bits.cc",
"graph_transformations/propagate_fixed_sizes.cc",
"graph_transformations/quantization_util.cc",
"graph_transformations/quantization_util.h",
@@ -249,6 +250,7 @@ cc_library(
"graph_transformations/remove_trivial_binary.cc",
"graph_transformations/remove_trivial_concatenation.cc",
"graph_transformations/remove_trivial_concatenation_input.cc",
+ "graph_transformations/remove_trivial_fake_quant.cc",
"graph_transformations/remove_trivial_passthrough.cc",
"graph_transformations/remove_trivial_passthrough.h",
"graph_transformations/remove_trivial_quantized_activation_func.cc",
@@ -303,7 +305,7 @@ cc_library(
":runtime",
":toco_port",
":tooling_util",
- ":types_proto_cc",
+ "//tensorflow/contrib/lite/kernels/internal:quantization_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -378,7 +380,6 @@ cc_library(
":toco_graphviz_dump_options",
":toco_port",
":types_proto_cc",
- "//tensorflow/contrib/lite/kernels/internal:quantization_util",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@protobuf_archive//:protobuf_headers",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 7a7059e357..71e7318ac3 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -237,6 +237,7 @@ struct ParsedTocoFlags {
Arg<string> input_types;
Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
Arg<bool> drop_control_dependency = Arg<bool>(false);
+ Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
};
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index c8352741b4..c289ddcd92 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -95,10 +95,8 @@ Color GetColorForArray(const Model& model, const string& array_name) {
array_name == dump_options.graphviz_last_array) {
return Color(0x9E, 0x9E, 0x9E);
}
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return Color(0x9E, 0x9E, 0x9E);
- }
+ if (IsOutputArray(model, array_name)) {
+ return Color(0x9E, 0x9E, 0x9E);
}
// Remaining arrays are intermediate activation arrays.
// Lighter tone of the same grey as for input/output arrays:
@@ -119,6 +117,12 @@ void AppendArrayVal(string* string, Array const& array, int index) {
return;
}
AppendF(string, "%d", data[index]);
+ } else if (array.buffer->type == ArrayDataType::kInt16) {
+ const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
+ if (index >= data.size()) {
+ return;
+ }
+ AppendF(string, "%d", data[index]);
} else if (array.buffer->type == ArrayDataType::kInt32) {
const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
if (index >= data.size()) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
index badefeca88..708ecf6e0a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -47,7 +47,7 @@ bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
op->type == OperatorType::kDepthwiseConv ||
op->type == OperatorType::kFullyConnected) {
if (ProcessLinearOperator(model, op)) {
- AddMessageF("Added bias vector to %s", LogName(*op));
+ AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
return true;
}
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index dbf029a853..56b3dec5c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -135,6 +135,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits)
DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
DECLARE_GRAPH_TRANSFORMATION(Quantize)
@@ -144,6 +145,7 @@ DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedMinMax)
@@ -163,7 +165,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
-DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
@@ -210,6 +211,23 @@ class RemoveTrivialReshape : public GraphTransformation {
bool treat_expand_dims_as_trivial_ = false;
};
+class ResolveConstantFakeQuant : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "ResolveConstantFakeQuant"; }
+
+ // True if the num_bits should adjust the final data type.
+ bool propagate_fake_quant_num_bits() const {
+ return propagate_fake_quant_num_bits_;
+ }
+ void set_propagate_fake_quant_num_bits(bool val) {
+ propagate_fake_quant_num_bits_ = val;
+ }
+
+ private:
+ bool propagate_fake_quant_num_bits_ = false;
+};
+
#undef DECLARE_GRAPH_TRANSFORMATION
} // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 183b3d3f2e..45d9f73a1e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
new file mode 100644
index 0000000000..0bce183c18
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -0,0 +1,307 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
+ ArrayDataType new_data_type,
+ const MinMax* new_minmax) {
+ // Ensure the array ends up in the new type (if it hasn't yet been quantized).
+ array->final_data_type = new_data_type;
+
+ if (array->minmax && array->quantization_params) {
+ // The array is already quantized and has min/max info.
+ // As we are changing the data type we need to fix up the existing min/max
+ // to the new data type range.
+
+ double old_quantized_min, old_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(
+ array->data_type, &old_quantized_min, &old_quantized_max))
+ << "Existing data type is not quantized: "
+ << ArrayDataTypeName(array->data_type);
+ double new_quantized_min, new_quantized_max;
+ CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min,
+ &new_quantized_max))
+ << "New data type is not quantized: "
+ << ArrayDataTypeName(new_data_type);
+
+ // Compute new minmax values.
+ double min = (old_quantized_min - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ double max =
+ (old_quantized_max + 1 - array->quantization_params->zero_point) *
+ array->quantization_params->scale;
+ max = max - 1.0 / (new_quantized_max + 1);
+
+ auto& array_minmax = array->GetOrCreateMinMax();
+ transformation->AddMessageF(
+ "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
+ array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
+ ArrayDataTypeName(new_data_type));
+
+ array_minmax.min = min;
+ array_minmax.max = max;
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+ array_minmax, array->quantization_params.get());
+
+ // Directly change the type as the array was already quantized.
+ array->data_type = new_data_type;
+ } else {
+ // Array has not yet been quantized so we can just set the final data type
+ // and assign the new min/max value (if provided).
+ CHECK(!array->quantization_params);
+
+ if (!array->minmax && new_minmax) {
+ transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
+ new_minmax->min, new_minmax->max,
+ ArrayDataTypeName(new_data_type));
+ auto& array_minmax = array->GetOrCreateMinMax();
+ array_minmax.min = new_minmax->min;
+ array_minmax.max = new_minmax->max;
+ }
+ }
+}
+
+// Returns true if the op blocks our backward recursive data type propagation.
+bool DoesOpBlockBackwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kConcatenation:
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Concat shouldn't block propagation, but we do expect that all inputs
+ // have the same range.
+ return false;
+ case OperatorType::kDequantize:
+ // Dequantize ops are inserted between the value we care about and the
+ // FakeQuant so make sure we move across them.
+ case OperatorType::kGather:
+ // Gathers need their parameters changed to the appropriate data type.
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Reshapes and transposes don't change values.
+ return false;
+ default:
+ return true;
+ }
+}
+
+// Returns true if the input of an op blocks our backward recursive data type
+// propagation.
+bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
+ switch (op.type) {
+ case OperatorType::kGather:
+ // Ignore gather indices.
+ return input_index != 0;
+ break;
+ case OperatorType::kTensorFlowReshape:
+ case OperatorType::kTranspose:
+ // Ignore reshape/transpose shapes/dimensions.
+ return input_index != 0;
+ default:
+ return false;
+ }
+}
+
+// Propagates the data type up into the input arrays if they are model inputs
+// that may need their type changed. May act recursively if the inputs are
+// produced by ops that we can move over (such as Dequantize).
+bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type,
+ const MinMax& new_minmax) {
+ bool did_change = false;
+ for (int input_index = 0; input_index < op->inputs.size(); ++input_index) {
+ const auto& input = op->inputs[input_index];
+ auto& input_array = model->GetArray(input);
+ if (input_array.final_data_type == new_data_type) {
+ // Final data type is already the requested type - skip.
+ continue;
+ }
+
+ // Prevent moving into constant param args that we don't want to modify.
+ if (DoesOpInputBlockBackwardPropagation(*op, input_index)) {
+ continue;
+ }
+
+ if (input_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting input final data type of array %s from %s to %s", input,
+ ArrayDataTypeName(input_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &input_array, new_data_type,
+ &new_minmax);
+
+ // Walk up into all ops producing the inputs to this op.
+ for (auto& producing_op : model->operators) {
+ if (!DoesOpBlockBackwardPropagation(*producing_op)) {
+ for (const auto& output : producing_op->outputs) {
+ if (input == output) {
+ did_change |= RecursivelyBackwardPropagateDataType(
+ transformation, model, producing_op.get(), new_data_type,
+ new_minmax);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+// Returns true if the op blocks our forward recursive data type propagation.
+bool DoesOpBlockForwardPropagation(const Operator& op) {
+ switch (op.type) {
+ case OperatorType::kFakeQuant:
+ // Always stop at another FakeQuant, as it will likely have different
+ // parameters.
+ return true;
+ default:
+ return false;
+ }
+}
+
+// Recurses down the graph, setting the data type of all arrays, until it
+// reaches an operator that blocks propagation (like another FakeQuant) or an
+// array whose final_data_type is already specified.
+bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
+ Model* model, Operator* op,
+ ArrayDataType new_data_type) {
+ bool did_change = false;
+ for (const auto& output : op->outputs) {
+ auto& output_array = model->GetArray(output);
+ if (output_array.final_data_type == new_data_type) {
+ // Final data type is already the requested type - skip.
+ continue;
+ }
+
+ if (output_array.final_data_type == ArrayDataType::kNone ||
+ output_array.final_data_type != new_data_type) {
+ transformation->AddMessageF(
+ "Adjusting output final data type of array %s from %s to %s", output,
+ ArrayDataTypeName(output_array.final_data_type),
+ ArrayDataTypeName(new_data_type));
+ did_change = true;
+ ChangeArrayDataType(transformation, &output_array, new_data_type,
+ nullptr);
+
+ // Walk down into all ops consuming the output of this op.
+ for (auto& consuming_op : model->operators) {
+ if (!DoesOpBlockForwardPropagation(*consuming_op)) {
+ for (const auto& input : consuming_op->inputs) {
+ if (input == output) {
+ did_change |= RecursivelyForwardPropagateDataType(
+ transformation, model, consuming_op.get(), new_data_type);
+ }
+ }
+ }
+ }
+ }
+ }
+ return did_change;
+}
+
+} // namespace
+
+// Propagates the num_bits on a FakeQuant operator into the final data types
+// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know
+// the output must be int16 and assume all inputs up until the preceding op are
+// also int16.
+//
+// This can be thought of as a bidirectional flood-fill of the num_bits implied
+// final_data_type that terminates at other FakeQuant ops (and a few others as
+// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation).
+// Once all FakeQuant ops have been visited, the arrays should all have
+// appropriate final_data_types if the source graph was annotated with the
+// proper FakeQuant ops.
+//
+// Annotating a graph requires following a few hard rules:
+// - every input MUST have a FakeQuant immediately following it
+// - every output MUST have a FakeQuant immediately preceding it
+// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant
+// immediately following them
+// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc)
+// MUST have FakeQuants between them and the consuming op
+// Additional FakeQuants may be used if desired, especially in areas that may
+// suffer from large precision changes - such as between a Softmax and a
+// FullyConnected. Only by validating accuracy differences between float
+// inference with the FakeQuant ops simulating quantization and the actually
+// quantized graph can you be sure the appropriate FakeQuant ops are present.
+//
+// You can tell if you're missing some FakeQuants by looking for warnings from
+// quantize.cc about minmax ranges being determined by the contents of constant
+// arrays. This will almost never produce functional models during inference.
+//
+// As this op may change the data types and ranges of input and output arrays,
+// downstream tools must also be sure to parse the output model flags to get the
+// post-Transform values that may have changed due to this transformation.
+//
+// This isn't a GraphTransformation in the traditional sense, as it affects ops
+// outside of the one under transformation. This is primarily so that we can
+// utilize the graph traversal and repeated pass system underlying the
+// transformation system to exhaustively find all FakeQuant ops. It also gets us
+// nice logging and integration with the graphviz video dumping mode.
+// In general you should not copy this style of transformation; stick to
+// local-only changes as seen in the other transformations.
+bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+ ArrayDataType quantized_data_type = ArrayDataType::kNone;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
+ LogName(*op), fakequant_op->num_bits);
+ return false;
+ }
+ const auto& final_minmax = *fakequant_op->minmax;
+
+ AddMessageF(
+ "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s",
+ LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max,
+ ArrayDataTypeName(quantized_data_type));
+
+ bool did_change = false;
+
+ // Propagate the FakeQuant information backward up the graph.
+ // This will possibly adjust input arrays or constant types (like Gather).
+ did_change |= RecursivelyBackwardPropagateDataType(
+ this, model, op, quantized_data_type, final_minmax);
+
+ // Propagate the FakeQuant information forward down the graph.
+ // This will possibly adjust output arrays.
+ did_change |=
+ RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
+
+ return did_change;
+}
+
+} // namespace toco
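
A note on ChangeArrayDataType above: when an array already carries quantization params, its float min/max is recomputed from the old quantized range before the storage type switches. The standalone sketch below (not part of this commit) walks that arithmetic with made-up uint8 parameters being widened to int16.

    #include <cstdio>

    // Illustration only -- not toco code. Follows the min/max rescaling in
    // ChangeArrayDataType; scale and zero_point are hypothetical example values.
    int main() {
      const double scale = 0.05;
      const double zero_point = 128.0;
      const double old_qmin = 0.0, old_qmax = 255.0;  // uint8 numerical range
      const double new_qmax = 32767.0;                // int16 upper bound

      double min = (old_qmin - zero_point) * scale;
      double max = (old_qmax + 1 - zero_point) * scale;
      max -= 1.0 / (new_qmax + 1);  // keep max just inside the widened top bucket

      std::printf("rescaled float range: [%g, %g]\n", min, max);
      return 0;
    }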
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
index e080df4bed..d74cad9a62 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
@@ -22,6 +22,20 @@ limitations under the License.
namespace toco {
+bool InferQuantizedDataTypeFromFakeQuant(
+ const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
+ if (op.num_bits <= 8) {
+ *out_quantized_data_type = ArrayDataType::kUint8;
+ return true;
+ } else if (op.num_bits <= 16) {
+ *out_quantized_data_type = ArrayDataType::kInt16;
+ return true;
+ } else {
+ *out_quantized_data_type = ArrayDataType::kNone;
+ return false;
+ }
+}
+
bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
double* out_min_value,
double* out_max_value) {
@@ -103,6 +117,80 @@ void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
}
}
+namespace {
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+ const GenericBuffer& buffer,
+ const QuantizationParams& quantization_params) {
+ const auto inverse_scale = 1. / quantization_params.scale;
+ CHECK(buffer.type == ArrayDataType::kFloat);
+ const auto& float_buffer =
+ static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+ auto* quantized_buffer = new Buffer<A>;
+ quantized_buffer->data.resize(float_buffer.data.size());
+ for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+ const float src_val = float_buffer.data[i];
+ double scaled_val; // Astonishingly, using 'float' degrades accuracy just
+ // enough to make a few tests fail!
+ if (quantization_params.scale == 0) {
+ CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+ << "so all its values should be 0.";
+ scaled_val = quantization_params.zero_point;
+ } else {
+ scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+ }
+ quantized_buffer->data[i] =
+ tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ }
+ return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name,
+ const QuantizationParams& quantization_params) {
+ auto& array = model->GetArray(name);
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ CHECK(!array.quantization_params);
+ array.GetOrCreateQuantizationParams() = quantization_params;
+ if (array.buffer) {
+ array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ }
+ array.data_type = A;
+ array.final_data_type = A;
+ transformation->AddMessageF(
+ "Quantized array %s to %s zero_point=%g, scale=%g", name,
+ ArrayDataTypeName(array.data_type), quantization_params.zero_point,
+ quantization_params.scale);
+}
+
+} // namespace
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params) {
+ ArrayDataType adjusted_data_type = quantized_data_type;
+ auto& array = model->GetArray(name);
+ if (array.final_data_type == ArrayDataType::kInt16) {
+ adjusted_data_type = array.final_data_type;
+ }
+
+ switch (adjusted_data_type) {
+ case ArrayDataType::kUint8:
+ return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt16:
+ return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt32:
+ return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+ quantization_params);
+ default:
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
const Array& array, double clamp_min,
double clamp_max) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
index 35fb310777..79a2ce7e50 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
@@ -15,11 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
namespace toco {
+// Gets the target quantized data type of an array based on the fake quant op.
+// For example, if the num_bits is 8 the data type will be kUint8.
+bool InferQuantizedDataTypeFromFakeQuant(
+ const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type);
+
// Gets the min/max numerical range for the given quantized data type.
// For example, kUint8 will return [0,255].
// Returns true if the ranges were set and false if the type is not quantized.
@@ -32,11 +38,28 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
ArrayDataType GetQuantizedDataType(const Array& array,
ArrayDataType default_type);
-// Gets the quantization params for the array with the given data type and
+// Returns the quantization params for the array with the given data type and
// minmax.
void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
QuantizationParams* quantization_params);
+// Returns the quantization params for the data type and minmax values.
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ using Integer = DataType<A>;
+ const double rmin = minmax.min;
+ const double rmax = minmax.max;
+ *quantization_params =
+ ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
+}
+
+// Quantizes an array by setting its data type and (if constant) quantizing
+// all values in the array.
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params);
+
// Returns true if the given array, when quantized, contains only values between
// the provided clamp min/max.
// Either clamp_min or clamp_max may be +/-infinity to indicate that the value
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index d6cae3cdbf..fa46e6bc38 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -57,72 +57,6 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kTranspose || type == OperatorType::kMean;
}
-template <ArrayDataType A>
-std::unique_ptr<GenericBuffer> QuantizeBuffer(
- const GenericBuffer& buffer,
- const QuantizationParams& quantization_params) {
- const auto inverse_scale = 1. / quantization_params.scale;
- CHECK(buffer.type == ArrayDataType::kFloat);
- const auto& float_buffer =
- static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
- auto* quantized_buffer = new Buffer<A>;
- quantized_buffer->data.resize(float_buffer.data.size());
- for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
- const float src_val = float_buffer.data[i];
- double scaled_val; // Astonishingly, using 'float' degrades accuracy just
- // enough to make a few tests fail!
- if (quantization_params.scale == 0) {
- CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
- << "so all its values should be 0.";
- scaled_val = quantization_params.zero_point;
- } else {
- scaled_val = quantization_params.zero_point + inverse_scale * src_val;
- }
- quantized_buffer->data[i] =
- tflite::SafeCast<DataType<A>>(std::round(scaled_val));
- }
- return std::unique_ptr<GenericBuffer>(quantized_buffer);
-}
-
-template <ArrayDataType A>
-void QuantizeArray(GraphTransformation* transformation, Model* model,
- const string& name,
- const QuantizationParams& quantization_params) {
- auto& array = model->GetArray(name);
- CHECK(array.data_type == ArrayDataType::kFloat);
- CHECK(!array.quantization_params);
- array.GetOrCreateQuantizationParams() = quantization_params;
- if (array.buffer) {
- array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
- }
- array.data_type = A;
- transformation->AddMessageF("Quantized array %s", name);
-}
-
-void QuantizeArray(GraphTransformation* transformation, Model* model,
- const string& name, ArrayDataType quantized_data_type,
- const QuantizationParams& quantization_params) {
- ArrayDataType adjusted_data_type = quantized_data_type;
- auto& array = model->GetArray(name);
- if (array.final_data_type == ArrayDataType::kInt16) {
- adjusted_data_type = array.final_data_type;
- }
-
- switch (adjusted_data_type) {
- case ArrayDataType::kUint8:
- return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
- quantization_params);
- case ArrayDataType::kInt16:
- return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
- quantization_params);
- case ArrayDataType::kInt32:
- return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
- quantization_params);
- default:
- LOG(FATAL) << "Unhandled case.";
- }
-}
-
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
auto& array = model->GetArray(array_name);
// Normally we should have a MinMax recorded on this Array,
@@ -245,6 +179,8 @@ bool ChooseQuantizationForOperatorInput(
const auto& input_weights = model->GetArray(op.inputs[weights_input_index]);
if (!input_activations.quantization_params ||
!input_weights.quantization_params) {
+ transformation->AddMessageF(
+ "Input array %s is a bias vector but has no qparams", input);
return false;
}
const auto input_activations_scale =
@@ -366,6 +302,9 @@ bool ChooseQuantizationForOperatorOutput(
const auto& output = op.outputs[output_index];
auto& array = model->GetArray(output);
if (array.data_type != ArrayDataType::kFloat) {
+ transformation->AddMessageF("Array data type already set to %s, final=%s",
+ ArrayDataTypeName(array.data_type),
+ ArrayDataTypeName(array.final_data_type));
return false;
}
*quantized_data_type = model->GetArray(op.inputs[0]).data_type;
@@ -427,29 +366,22 @@ bool ChooseQuantizationForOperatorOutput(
// Fixes array minmax info to match the quantization parameters.
// This is required for when quantization parameters change for an array during
// quantization (such as ChooseQuantizationForOperatorOutput).
-void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
+void FixMinMaxPostQuantization(GraphTransformation* transformation,
+ ArrayDataType quantized_data_type,
const QuantizationParams& quantization_params,
MinMax* minmax) {
- double qmin, qmax;
- switch (quantized_data_type) {
- case ArrayDataType::kUint8:
- qmin = 0;
- qmax = 255;
- break;
- case ArrayDataType::kInt16:
- qmin = -32768;
- qmax = 32767;
- break;
- default:
- // No update required.
- return;
+ double quantized_min, quantized_max;
+ if (!GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
+ &quantized_max)) {
+ // Not quantized - no update required.
+ return;
}
// Compute new minmax values.
- double min =
- (qmin - quantization_params.zero_point) * quantization_params.scale;
- double max =
- (qmax - quantization_params.zero_point) * quantization_params.scale;
+ double min = (quantized_min - quantization_params.zero_point) *
+ quantization_params.scale;
+ double max = (quantized_max - quantization_params.zero_point) *
+ quantization_params.scale;
// If we are close to the existing minmax values don't bother changing them.
// This prevents propagating small floating point precision errors.
@@ -457,6 +389,9 @@ void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
const double width = max - min;
if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
std::abs(max - minmax->max) > kMinMaxThreshold * width) {
+ transformation->AddMessageF(
+ "Adjusting min/max from %g,%g to %g,%g to match quantization params",
+ minmax->min, minmax->max, min, max);
minmax->min = min;
minmax->max = max;
}
@@ -566,10 +501,33 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
// input instead.
for (int i = 0; i < model->flags.output_arrays_size(); i++) {
if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
- model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ // TODO(b/78013785): never rename output arrays.
+ if (IsInputArray(*model, dequantize_op->inputs[0])) {
+ // The op input is an input array and the output is an output
+ // array and we can't have an array be both. Insert a copy
+ // op to ensure the two arrays stay separate.
+ AddMessageF(
+ "Tried to rename output array %d while removing dequant "
+ "op %s but array is also an input; inserting copy %s "
+ "-> %s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ InsertCopyOperator(model, dequantize_op->inputs[0],
+ dequantize_op->outputs[0]);
+ } else {
+ // Op output is strictly used as an output array, so we can
+ // just rename the array and directly bypass the op.
+ AddMessageF(
+ "Renaming output array %d after removing dequant op %s: "
+ "%s -> %s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantize_op->inputs[0]);
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ model->EraseArray(dequantize_op->outputs[0]);
+ }
+ break;
}
}
- model->EraseArray(dequantize_op->outputs[0]);
model->operators.erase(dequantize_it);
}
changed = true;
@@ -615,7 +573,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
CHECK(output_array.minmax)
<< "Output array named " << output << " lacks minmax";
auto& output_minmax = output_array.GetMinMax();
- FixMinMaxPostQuantization(quantized_data_type, quantization_params,
+ FixMinMaxPostQuantization(this, quantized_data_type, quantization_params,
&output_minmax);
QuantizeArray(this, model, output, quantized_data_type,
@@ -626,6 +584,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
auto& dequantized_output_array =
model->GetOrCreateArray(dequantized_output);
dequantized_output_array.data_type = ArrayDataType::kFloat;
+ dequantized_output_array.final_data_type = output_array.data_type;
auto& dequantized_output_minmax =
dequantized_output_array.GetOrCreateMinMax();
dequantized_output_minmax.min = output_minmax.min;
@@ -642,6 +601,12 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
dequantize_op->outputs = {dequantized_output};
for (int i = 0; i < model->flags.output_arrays_size(); i++) {
if (model->flags.output_arrays(i) == output) {
+ // TODO(b/78013785): never rename output arrays.
+ AddMessageF(
+ "Renaming output array %d after inserting dequant op %s: %s -> "
+ "%s",
+ i, LogName(*dequantize_op), model->flags.output_arrays(i),
+ dequantized_output);
model->flags.set_output_arrays(i, dequantized_output);
}
}
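
The QuantizeBuffer/QuantizeArray helpers removed here now live in quantization_util.cc (see the hunk above); the arithmetic itself is unchanged: each float becomes zero_point + value / scale, rounded and saturated to the target type. A standalone sketch (not part of this commit), using std::clamp where the real code uses tflite::SafeCast:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Illustration only -- not toco code. uint8 quantization of a float buffer,
    // following the arithmetic in QuantizeBuffer.
    std::vector<uint8_t> QuantizeToUint8(const std::vector<float>& values,
                                         double scale, int zero_point) {
      std::vector<uint8_t> out(values.size());
      const double inverse_scale = 1.0 / scale;
      for (std::size_t i = 0; i < values.size(); ++i) {
        const double scaled = zero_point + inverse_scale * values[i];
        // SafeCast saturates to the target range; clamp plays that role here.
        out[i] = static_cast<uint8_t>(std::clamp(std::round(scaled), 0.0, 255.0));
      }
      return out;
    }

    int main() {
      const double scale = 1.0 / 127;  // example params covering roughly [-1, 1]
      for (uint8_t q : QuantizeToUint8({-1.0f, 0.0f, 0.5f, 1.0f}, scale, 128)) {
        std::printf("%d ", static_cast<int>(q));
      }
      std::printf("\n");
      return 0;
    }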
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
new file mode 100644
index 0000000000..2c8d04440f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
+ const FakeQuantOperator& fakequant_op) {
+ CHECK(fakequant_op.type == OperatorType::kFakeQuant);
+
+ if (!fakequant_op.minmax) {
+ // Require ReadFakeQuantMinMax to have run.
+ return false;
+ }
+
+ // FakeQuants are trivial if they are taking input from another identical
+ // FakeQuant op.
+ auto* producing_op = GetOpWithOutput(model, fakequant_op.inputs[0]);
+ if (!producing_op || producing_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ const auto& producing_fakequant_op =
+ *static_cast<FakeQuantOperator*>(producing_op);
+ if (!producing_fakequant_op.minmax) {
+ // Require ReadFakeQuantMinMax to have run.
+ return false;
+ }
+
+ if (*fakequant_op.minmax == *producing_fakequant_op.minmax &&
+ fakequant_op.num_bits == producing_fakequant_op.num_bits) {
+ transformation->AddMessageF(
+ "%s is trivial because it is preceded by an identical FakeQuant %s",
+ LogName(fakequant_op), LogName(producing_fakequant_op));
+ return true;
+ }
+
+ return false;
+}
+
+} // namespace
+
+// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
+bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+ if (op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+ if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
+ AddMessageF("%s is not trivial", LogName(*fakequant_op));
+ return false;
+ }
+
+ AddMessageF("Removing trivial %s", LogName(*fakequant_op));
+
+ CHECK_EQ(fakequant_op->inputs.size(), 1);
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
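
IsFakeQuantTrivial above reduces to a field-by-field comparison between a FakeQuant op and the op producing its input. The standalone sketch below (not part of this commit) shows the shape of that check:

    #include <cstdio>

    // Illustration only -- not toco code. A FakeQuant is redundant when it is
    // fed by another FakeQuant with identical min/max and num_bits.
    struct FakeQuantParams {
      double min;
      double max;
      int num_bits;
    };

    bool IsTrivial(const FakeQuantParams& op, const FakeQuantParams* producer) {
      return producer != nullptr && producer->min == op.min &&
             producer->max == op.max && producer->num_bits == op.num_bits;
    }

    int main() {
      const FakeQuantParams a{-1.0, 1.0, 8};
      const FakeQuantParams b{-1.0, 1.0, 8};   // identical to a -> trivial
      const FakeQuantParams c{-6.0, 6.0, 8};   // different range -> kept
      std::printf("b after a: %s\n", IsTrivial(b, &a) ? "trivial" : "kept");
      std::printf("c after a: %s\n", IsTrivial(c, &a) ? "trivial" : "kept");
      return 0;
    }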
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 625d90205a..efb7bb2184 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -45,9 +46,29 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
}
const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ CHECK(input_array.data_type == ArrayDataType::kFloat);
+
+ // Determine the final data type in the same way as PropagateFakeQuantNumBits.
+ ArrayDataType quantized_data_type = input_array.final_data_type;
+ if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+ &quantized_data_type)) {
+ AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
+ return false;
+ }
+
+ AddMessageF("Resolving constant %s", LogName(*fakequant_op));
+
auto& output_array = model->GetArray(fakequant_op->outputs[0]);
CHECK(input_array.data_type == ArrayDataType::kFloat);
output_array.data_type = ArrayDataType::kFloat;
+
+ // We'll set the final data type to what the fake quant indicates we should
+ // have (as it would have been set had this op remained until
+ // PropagateFakeQuantNumBits ran).
+ if (propagate_fake_quant_num_bits()) {
+ output_array.final_data_type = quantized_data_type;
+ }
+
CHECK(!output_array.buffer);
const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
@@ -66,7 +87,9 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
output_buffer.data[i] = dst_val;
}
- if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+
+ if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
+ CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
model->EraseArray(fakequant_op->inputs[0]);
}
model->operators.erase(fakequant_it);
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index cc7803dd86..d1d68b6b47 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -126,6 +126,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
"If true, disable fusion of known identifiable cell subgraphs into "
"cells. This includes, for example, specific forms of LSTM cell."),
+ Flag("propagate_fake_quant_num_bits",
+ parsed_flags.propagate_fake_quant_num_bits.bind(),
+ parsed_flags.propagate_fake_quant_num_bits.default_value(),
+ "If true, use FakeQuant* operator num_bits attributes to adjust "
+ "array data_types."),
};
bool asked_for_help =
*argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -211,6 +216,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
+ READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
+ READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
// Deprecated flag handling.
if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 3237147a73..751aca948c 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
// of as properties of models, instead describing how models are to be
// processed in the context of the present tooling job.
//
-// Next ID to use: 14.
+// Next ID to use: 15.
message TocoFlags {
// Input file format
optional FileFormat input_format = 1;
@@ -141,4 +141,13 @@ message TocoFlags {
// Disables transformations that fuse subgraphs such as known LSTMs (not all
// LSTMs are identified).
optional bool debug_disable_recurrent_cell_fusion = 13;
+
+ // Uses the FakeQuantWithMinMaxArgs.num_bits attribute to adjust quantized
+ // array data types throughout the graph. The graph must be properly annotated
+ // with FakeQuant* ops on at least the edges and may contain additional ops on
+ // the interior of the graph to widen/narrow as desired.
+ //
+ // Input and output array data types may change because of this propagation
+ // and users must be sure to query the final data_type values.
+ optional bool propagate_fake_quant_num_bits = 14;
}
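
For callers that build TocoFlags in C++ rather than on the command line, the new field is reachable through the usual protobuf-generated accessors. A minimal sketch, assuming the generated toco_flags.pb.h header and the standard setter naming for an optional bool field (not part of this commit):

    #include <iostream>

    #include "tensorflow/contrib/lite/toco/toco_flags.pb.h"

    int main() {
      toco::TocoFlags toco_flags;
      // protoc generates this setter for
      // `optional bool propagate_fake_quant_num_bits = 14`.
      toco_flags.set_propagate_fake_quant_num_bits(true);
      std::cout << "propagate_fake_quant_num_bits="
                << toco_flags.propagate_fake_quant_num_bits() << std::endl;
      return 0;
    }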
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 5ba093a830..b69852453c 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -66,6 +66,7 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new RemoveTensorFlowIdentity);
transformations->Add(new RemoveTrivialConcatenation);
transformations->Add(new RemoveTrivialConcatenationInput);
+ transformations->Add(new RemoveTrivialFakeQuant);
transformations->Add(new RemoveTrivialSlice);
transformations->Add(new RemoveUnusedOp);
transformations->Add(new EnsureBiasVectors);
@@ -109,7 +110,6 @@ void MakeGeneralGraphTransformationsSet(
transformations->Add(new ResolveMeanAttributes);
transformations->Add(new ResolveConstantShapeOrRank);
transformations->Add(new MakeInitialDequantizeOperator);
- transformations->Add(new ResolveConstantFakeQuant);
transformations->Add(new UnpartitionEmbeddingLookup);
}
@@ -233,6 +233,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
MakeGeneralGraphTransformationsSet(&transformations);
auto* remove_trivial_reshape = new RemoveTrivialReshape;
transformations.Add(remove_trivial_reshape);
+ auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
+ if (quantize_output) {
+ resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
+ toco_flags.propagate_fake_quant_num_bits());
+ }
+ transformations.Add(resolve_constant_fake_quant);
if (SupportsFusedActivationFunction(output_format)) {
transformations.Add(new FuseActivationFunctions);
} else {
@@ -264,9 +270,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
RunGraphTransformations(model, "general graph transformations",
transformations);
+ // Fix any issues with IO edges. This must happen after any transform that
+ // may modify the structure of the edges.
+ FixEdgeArrays(model);
+
if (quantize_output) {
+ if (toco_flags.propagate_fake_quant_num_bits()) {
+ RunGraphTransformations(model,
+ "fake quant propagation graph transformations",
+ {new PropagateFakeQuantNumBits});
+ }
RunGraphTransformations(model, "pre-quantization graph transformations",
- {new HardcodeMinMax, new DropFakeQuant});
+ {
+ new HardcodeMinMax,
+ new DropFakeQuant,
+ });
}
if (quantize_output) {
@@ -303,10 +321,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
}
- // Fix any issues with IO edges. This must happen after any transform that
- // may modify the structure of the edges.
- FixEdgeArrays(model);
-
LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 224df9973e..ecac0c28a5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,9 +93,18 @@ string ArrayDataTypeName(ArrayDataType data_type) {
}
}
-bool IsInputArray(const Model& model, const string& name) {
+bool IsInputArray(const Model& model, const string& array_name) {
for (const auto& input_array : model.flags.input_arrays()) {
- if (input_array.name() == name) {
+ if (array_name == input_array.name()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsOutputArray(const Model& model, const string& array_name) {
+ for (const auto& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
return true;
}
}
@@ -106,10 +115,8 @@ bool IsArrayConsumed(const Model& model, const string& name) {
if (GetOpWithInput(model, name)) {
return true;
}
- for (const string& model_output : model.flags.output_arrays()) {
- if (model_output == name) {
- return true;
- }
+ if (IsOutputArray(model, name)) {
+ return true;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (rnn_state.back_edge_source_array() == name) {
@@ -379,6 +386,7 @@ string HelpfulOperatorTypeName(const Operator& op) {
bool OperatorSupportsFusedActivation(OperatorType type) {
switch (type) {
case OperatorType::kConcatenation:
+ case OperatorType::kFakeQuant:
case OperatorType::kGather:
case OperatorType::kSlice:
case OperatorType::kSqueeze:
@@ -1064,16 +1072,38 @@ void FixEdgeArrays(Model* model) {
}
}
+namespace {
+void CopyArrayAttribs(const Array& source_array, Array* target_array) {
+ target_array->data_type = source_array.data_type;
+ target_array->final_data_type = source_array.final_data_type;
+ target_array->copy_shape(source_array.shape());
+
+ if (source_array.minmax) {
+ target_array->GetOrCreateMinMax() = source_array.GetMinMax();
+ } else {
+ target_array->minmax.reset();
+ }
+
+ if (source_array.quantization_params) {
+ target_array->GetOrCreateQuantizationParams() =
+ source_array.GetQuantizationParams();
+ } else {
+ target_array->quantization_params.reset();
+ }
+}
+} // namespace
+
void InsertCopyOperator(Model* model, const string& source_array_name,
const string& target_array_name) {
+ // Reshape to the same size. This should be a no-op.
+ const Array& source_array = model->GetArray(source_array_name);
+ std::vector<int> shape = source_array.shape().dims();
+
// Drop constant data from the target array as the copy will be done at
// runtime.
Array& target_array = model->GetOrCreateArray(target_array_name);
target_array.buffer.reset();
-
- // Reshape to the same size. This should be a no-op.
- const Array& source_array = model->GetArray(source_array_name);
- std::vector<int> shape = source_array.shape().dims();
+ CopyArrayAttribs(source_array, &target_array);
// Insert copy operator.
auto* copy_op = new TensorFlowReshapeOperator;
@@ -1089,6 +1119,7 @@ void CloneArray(Model* model, const string& source_array_name,
CHECK(!model->HasArray(target_array_name));
const Array& source_array = model->GetArray(source_array_name);
Array& target_array = model->GetOrCreateArray(target_array_name);
+ CopyArrayAttribs(source_array, &target_array);
if (source_array.minmax) {
const auto& smm = source_array.GetMinMax();
@@ -1513,14 +1544,9 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
if (model.IsOptionalArray(array_name)) return false;
// The model's input and output arrays are externally allocated.
// They are not transient arrays.
- if (IsInputArray(model, array_name)) {
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
return false;
}
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
- }
const auto& array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
@@ -1898,15 +1924,8 @@ int AxesCount(AxesOrder axes_order) {
}
bool IsDiscardableArray(const Model& model, const string& array_name) {
- for (const auto& input_array : model.flags.input_arrays()) {
- if (array_name == input_array.name()) {
- return false;
- }
- }
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
+ return false;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
@@ -1960,8 +1979,8 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
CHECK(array.final_data_type == array.data_type)
<< "Array \"" << array_entry.first
<< "\" has mis-matching actual and final data types ("
- << static_cast<int>(array.data_type) << ","
- << static_cast<int>(array.final_data_type) << ").";
+ << ArrayDataTypeName(array.data_type) << ","
+ << ArrayDataTypeName(array.final_data_type) << ").";
}
}
}
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index ed0ecd4d0f..4c705f4e5f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,7 +28,6 @@ limitations under the License.
#if TOCO_SUPPORT_PORTABLE_PROTOS
#include "third_party/protobuf/src/google/protobuf/text_format.h"
#endif // TOCO_SUPPORT_PORTABLE_PROTOS
-#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
#include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -57,7 +56,11 @@ string LogName(const Operator& op);
string ArrayDataTypeName(ArrayDataType data_type);
-bool IsInputArray(const Model& model, const string& name);
+// Returns true if the given array is specified as a model input array.
+bool IsInputArray(const Model& model, const string& array_name);
+// Returns true if the given array is specified as a model output array.
+bool IsOutputArray(const Model& model, const string& array_name);
+
bool IsArrayConsumed(const Model& model, const string& name);
int CountTrueOutputs(const Model& model, const Operator& op);
@@ -175,17 +178,6 @@ void CloneArray(Model* model, const string& source_array_name,
void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
-template <ArrayDataType A>
-void GetQuantizationParamsFromMinMax(const MinMax& minmax,
- QuantizationParams* quantization_params) {
- using Integer = DataType<A>;
- const double rmin = minmax.min;
- const double rmax = minmax.max;
-
- *quantization_params =
- ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
-}
-
template <typename T>
T ConvertOperator(Operator* o, OperatorType type) {
if (o != nullptr && o->type == type) {