author     2018-04-17 11:53:29 -0700
committer  2018-04-17 11:55:26 -0700
commit     d7b6cb66c0fc346cf55020042931c07208713c60 (patch)
tree       9024111ebf15d12a631ffd7e176b9da7459dd5a0
parent     1192c1662c5c98f55805450b4619ac2bc9c6908c (diff)
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
17 files changed, 702 insertions, 144 deletions
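The core rule of this change is small: a FakeQuant op's num_bits attribute selects the narrowest quantized data type that can represent it (8 bits or fewer map to uint8, 9-16 bits map to int16), and the new PropagateFakeQuantNumBits transformation floods that type bidirectionally through the graph. A minimal standalone sketch of just the selection rule - simplified names, not the actual toco types; the real implementation is InferQuantizedDataTypeFromFakeQuant in quantization_util.cc below:

    #include <cstdio>

    // Simplified mirror of InferQuantizedDataTypeFromFakeQuant: num_bits
    // picks the narrowest quantized type that can represent it.
    enum class ArrayDataType { kNone, kUint8, kInt16 };

    static bool InferTypeFromNumBits(int num_bits, ArrayDataType* out_type) {
      if (num_bits <= 8) {
        *out_type = ArrayDataType::kUint8;
        return true;
      }
      if (num_bits <= 16) {
        *out_type = ArrayDataType::kInt16;
        return true;
      }
      *out_type = ArrayDataType::kNone;  // More than 16 bits is unsupported.
      return false;
    }

    int main() {
      ArrayDataType type;
      std::printf("num_bits=8  supported=%d\n", InferTypeFromNumBits(8, &type));
      std::printf("num_bits=16 supported=%d\n", InferTypeFromNumBits(16, &type));
      std::printf("num_bits=32 supported=%d\n", InferTypeFromNumBits(32, &type));
      return 0;
    }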
diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD
index 5b86e4e5ae..398978b145 100644
--- a/tensorflow/contrib/lite/toco/BUILD
+++ b/tensorflow/contrib/lite/toco/BUILD
@@ -238,6 +238,7 @@ cc_library(
         "graph_transformations/merge_reshape_into_preceding_transpose.cc",
         "graph_transformations/propagate_activation_function_into_constants.cc",
         "graph_transformations/propagate_array_data_types.cc",
+        "graph_transformations/propagate_fake_quant_num_bits.cc",
         "graph_transformations/propagate_fixed_sizes.cc",
         "graph_transformations/quantization_util.cc",
         "graph_transformations/quantization_util.h",
@@ -249,6 +250,7 @@ cc_library(
         "graph_transformations/remove_trivial_binary.cc",
         "graph_transformations/remove_trivial_concatenation.cc",
         "graph_transformations/remove_trivial_concatenation_input.cc",
+        "graph_transformations/remove_trivial_fake_quant.cc",
         "graph_transformations/remove_trivial_passthrough.cc",
         "graph_transformations/remove_trivial_passthrough.h",
         "graph_transformations/remove_trivial_quantized_activation_func.cc",
@@ -303,7 +305,7 @@ cc_library(
         ":runtime",
         ":toco_port",
         ":tooling_util",
-        ":types_proto_cc",
+        "//tensorflow/contrib/lite/kernels/internal:quantization_util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
@@ -378,7 +380,6 @@ cc_library(
         ":toco_graphviz_dump_options",
        ":toco_port",
         ":types_proto_cc",
-        "//tensorflow/contrib/lite/kernels/internal:quantization_util",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/strings",
         "@protobuf_archive//:protobuf_headers",
diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h
index 7a7059e357..71e7318ac3 100644
--- a/tensorflow/contrib/lite/toco/args.h
+++ b/tensorflow/contrib/lite/toco/args.h
@@ -237,6 +237,7 @@ struct ParsedTocoFlags {
   Arg<string> input_types;
   Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
   Arg<bool> drop_control_dependency = Arg<bool>(false);
+  Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
 };
 
 }  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index c8352741b4..c289ddcd92 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -95,10 +95,8 @@ Color GetColorForArray(const Model& model, const string& array_name) {
       array_name == dump_options.graphviz_last_array) {
     return Color(0x9E, 0x9E, 0x9E);
   }
-  for (const string& output_array : model.flags.output_arrays()) {
-    if (array_name == output_array) {
-      return Color(0x9E, 0x9E, 0x9E);
-    }
+  if (IsOutputArray(model, array_name)) {
+    return Color(0x9E, 0x9E, 0x9E);
   }
   // Remaining arrays are intermediate activation arrays.
   // Lighter tone of the same grey as for input/output arrays:
@@ -119,6 +117,12 @@ void AppendArrayVal(string* string, Array const& array, int index) {
       return;
     }
     AppendF(string, "%d", data[index]);
+  } else if (array.buffer->type == ArrayDataType::kInt16) {
+    const auto& data = array.GetBuffer<ArrayDataType::kInt16>().data;
+    if (index >= data.size()) {
+      return;
+    }
+    AppendF(string, "%d", data[index]);
   } else if (array.buffer->type == ArrayDataType::kInt32) {
     const auto& data = array.GetBuffer<ArrayDataType::kInt32>().data;
     if (index >= data.size()) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
index badefeca88..708ecf6e0a 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -47,7 +47,7 @@ bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
       op->type == OperatorType::kDepthwiseConv ||
       op->type == OperatorType::kFullyConnected) {
     if (ProcessLinearOperator(model, op)) {
-      AddMessageF("Added bias vector to %s", LogName(*op));
+      AddMessageF("Added bias vector to %s as %s", LogName(*op), op->inputs[2]);
       return true;
     }
   }
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
index dbf029a853..56b3dec5c4 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -135,6 +135,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv)
 DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
 DECLARE_GRAPH_TRANSFORMATION(PropagateActivationFunctionIntoConstants)
 DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFakeQuantNumBits);
 DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
 DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
 DECLARE_GRAPH_TRANSFORMATION(Quantize)
@@ -144,6 +145,7 @@ DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialFakeQuant)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialSlice)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
 DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedMinMax)
@@ -163,7 +165,6 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
 DECLARE_GRAPH_TRANSFORMATION(ResolveSqueezeAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
 DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
-DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantReshape)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTranspose)
@@ -210,6 +211,23 @@ class RemoveTrivialReshape : public GraphTransformation {
   bool treat_expand_dims_as_trivial_ = false;
 };
 
+class ResolveConstantFakeQuant : public GraphTransformation {
+ public:
+  bool Run(Model* model, std::size_t op_index) override;
+  const char* Name() const override { return "ResolveConstantFakeQuant"; }
+
+  // True if the num_bits should adjust the final data type.
+  bool propagate_fake_quant_num_bits() const {
+    return propagate_fake_quant_num_bits_;
+  }
+  void set_propagate_fake_quant_num_bits(bool val) {
+    propagate_fake_quant_num_bits_ = val;
+  }
+
+ private:
+  bool propagate_fake_quant_num_bits_ = false;
+};
+
 #undef DECLARE_GRAPH_TRANSFORMATION
 
 }  // end namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 183b3d3f2e..45d9f73a1e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
new file mode 100644
index 0000000000..0bce183c18
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -0,0 +1,307 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ChangeArrayDataType(GraphTransformation* transformation, Array* array,
+                         ArrayDataType new_data_type,
+                         const MinMax* new_minmax) {
+  // Ensure the array ends up in the new type
+  // (if it hasn't yet been quantized).
+  array->final_data_type = new_data_type;
+
+  if (array->minmax && array->quantization_params) {
+    // The array is already quantized and has min/max info.
+    // As we are changing the data type we need to fix up the existing min/max
+    // to the new data type range.
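A worked instance of the rescale implemented just below, with illustrative numbers rather than values taken from the patch: a uint8 array with scale = 0.1 and zero_point = 128 recovers min = (0 - 128) * 0.1 = -12.8 and max = (255 + 1 - 128) * 0.1 = 12.8; retargeting to int16 then nudges the upper bound down by 1.0 / (32767 + 1) ≈ 3.05e-5 to about 12.79997 before the quantization params are recomputed from the new range.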
+
+    double old_quantized_min, old_quantized_max;
+    CHECK(GetQuantizedDataTypeNumericalRange(
+        array->data_type, &old_quantized_min, &old_quantized_max))
+        << "Existing data type is not quantized: "
+        << ArrayDataTypeName(array->data_type);
+    double new_quantized_min, new_quantized_max;
+    CHECK(GetQuantizedDataTypeNumericalRange(new_data_type, &new_quantized_min,
+                                             &new_quantized_max))
+        << "New data type is not quantized: "
+        << ArrayDataTypeName(new_data_type);
+
+    // Compute new minmax values.
+    double min = (old_quantized_min - array->quantization_params->zero_point) *
+                 array->quantization_params->scale;
+    double max =
+        (old_quantized_max + 1 - array->quantization_params->zero_point) *
+        array->quantization_params->scale;
+    max = max - 1.0 / (new_quantized_max + 1);
+
+    auto& array_minmax = array->GetOrCreateMinMax();
+    transformation->AddMessageF(
+        "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
+        array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
+        ArrayDataTypeName(new_data_type));
+
+    array_minmax.min = min;
+    array_minmax.max = max;
+    GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
+        array_minmax, array->quantization_params.get());
+
+    // Directly change the type as the array was already quantized.
+    array->data_type = new_data_type;
+  } else {
+    // Array has not yet been quantized so we can just set the final data type
+    // and assign the new min/max value (if provided).
+    CHECK(!array->quantization_params);
+
+    if (!array->minmax && new_minmax) {
+      transformation->AddMessageF("Forcing new minmax to %g,%g (%s)",
+                                  new_minmax->min, new_minmax->max,
+                                  ArrayDataTypeName(new_data_type));
+      auto& array_minmax = array->GetOrCreateMinMax();
+      array_minmax.min = new_minmax->min;
+      array_minmax.max = new_minmax->max;
+    }
+  }
+}
+
+// Returns true if the op blocks our backward recursive data type propagation.
+bool DoesOpBlockBackwardPropagation(const Operator& op) {
+  switch (op.type) {
+    case OperatorType::kConcatenation:
+    case OperatorType::kTensorFlowConcat:
+    case OperatorType::kTensorFlowConcatV2:
+      // Concat shouldn't block propagation, but we do expect that all inputs
+      // have the same range.
+      return false;
+    case OperatorType::kDequantize:
+      // Dequantize ops are inserted between the value we care about and the
+      // FakeQuant so make sure we move across them.
+    case OperatorType::kGather:
+      // Gathers need their parameters changed to the appropriate data type.
+    case OperatorType::kTensorFlowReshape:
+    case OperatorType::kTranspose:
+      // Reshapes and transposes don't change values.
+      return false;
+    default:
+      return true;
+  }
+}
+
+// Returns true if the input of an op blocks our backward recursive data type
+// propagation.
+bool DoesOpInputBlockBackwardPropagation(const Operator& op, int input_index) {
+  switch (op.type) {
+    case OperatorType::kGather:
+      // Ignore gather indices.
+      return input_index != 0;
+      break;
+    case OperatorType::kTensorFlowReshape:
+    case OperatorType::kTranspose:
+      // Ignore reshape/transpose shapes/dimensions.
+      return input_index != 0;
+    default:
+      return false;
+  }
+}
+
+// Propagates the data type up into the input arrays if they are model inputs
+// that may need their type changed. May act recursively if the inputs are
+// produced by ops that we can move over (such as Dequantize).
+bool RecursivelyBackwardPropagateDataType(GraphTransformation* transformation,
+                                          Model* model, Operator* op,
+                                          ArrayDataType new_data_type,
+                                          const MinMax& new_minmax) {
+  bool did_change = false;
+  for (int input_index = 0; input_index < op->inputs.size(); ++input_index) {
+    const auto& input = op->inputs[input_index];
+    auto& input_array = model->GetArray(input);
+    if (input_array.final_data_type == new_data_type) {
+      // Final data type is already set - skip.
+      continue;
+    }
+
+    // Prevent moving into constant param args that we don't want to modify.
+    if (DoesOpInputBlockBackwardPropagation(*op, input_index)) {
+      continue;
+    }
+
+    if (input_array.final_data_type != new_data_type) {
+      transformation->AddMessageF(
+          "Adjusting input final data type of array %s from %s to %s", input,
+          ArrayDataTypeName(input_array.final_data_type),
+          ArrayDataTypeName(new_data_type));
+      did_change = true;
+      ChangeArrayDataType(transformation, &input_array, new_data_type,
+                          &new_minmax);
+
+      // Walk up into all ops producing the inputs to this op.
+      for (auto& producing_op : model->operators) {
+        if (!DoesOpBlockBackwardPropagation(*producing_op)) {
+          for (const auto& output : producing_op->outputs) {
+            if (input == output) {
+              did_change |= RecursivelyBackwardPropagateDataType(
+                  transformation, model, producing_op.get(), new_data_type,
+                  new_minmax);
+            }
+          }
+        }
+      }
+    }
+  }
+  return did_change;
+}
+
+// Returns true if the op blocks our forward recursive data type propagation.
+bool DoesOpBlockForwardPropagation(const Operator& op) {
+  switch (op.type) {
+    case OperatorType::kFakeQuant:
+      // Always stop at another FakeQuant, as it will likely have different
+      // parameters.
+      return true;
+    default:
+      return false;
+  }
+}
+
+// Recurses down the graph, setting the data type of all arrays until it
+// reaches an operator that blocks propagation (like another FakeQuant) or an
+// array whose final_data_type is already specified.
+bool RecursivelyForwardPropagateDataType(GraphTransformation* transformation,
+                                         Model* model, Operator* op,
+                                         ArrayDataType new_data_type) {
+  bool did_change = false;
+  for (const auto& output : op->outputs) {
+    auto& output_array = model->GetArray(output);
+    if (output_array.final_data_type == new_data_type) {
+      // Final data type is already set - skip.
+      continue;
+    }
+
+    if (output_array.final_data_type == ArrayDataType::kNone ||
+        output_array.final_data_type != new_data_type) {
+      transformation->AddMessageF(
+          "Adjusting output final data type of array %s from %s to %s", output,
+          ArrayDataTypeName(output_array.final_data_type),
+          ArrayDataTypeName(new_data_type));
+      did_change = true;
+      ChangeArrayDataType(transformation, &output_array, new_data_type,
+                          nullptr);
+
+      // Walk down into all ops consuming the output of this op.
+      for (auto& consuming_op : model->operators) {
+        if (!DoesOpBlockForwardPropagation(*consuming_op)) {
+          for (const auto& input : consuming_op->inputs) {
+            if (input == output) {
+              did_change |= RecursivelyForwardPropagateDataType(
+                  transformation, model, consuming_op.get(), new_data_type);
+            }
+          }
+        }
+      }
+    }
+  }
+  return did_change;
+}
+
+}  // namespace
+
+// Propagates the num_bits on a FakeQuant operator into the final data types
+// of inputs and outputs. For example, if FakeQuant.num_bits==16 then we know
+// the output must be int16 and assume all inputs up until the preceding op are
+// also 16.
+//
+// This can be thought of as a bidirectional flood-fill of the num_bits-implied
+// final_data_type that terminates at other FakeQuant ops (and a few others as
+// determined by DoesOpBlockBackwardPropagation/DoesOpBlockForwardPropagation).
+// Once all FakeQuant ops have been visited the arrays should all have
+// appropriate final_data_types if the source graph was annotated with the
+// proper FakeQuant ops.
+//
+// Annotating a graph requires following a few hard rules:
+// - every input MUST have a FakeQuant immediately following it
+// - every output MUST have a FakeQuant immediately preceding it
+// - important arithmetic ops (such as FullyConnected) SHOULD have a FakeQuant
+//   immediately following it
+// - all trained weights (RHS of FullyConnected ops, params on Gather ops, etc)
+//   MUST have FakeQuants between them and the consuming op
+// Additional FakeQuants may be used if desired, especially in areas that may
+// suffer from large precision changes - such as between a Softmax and a
+// FullyConnected. Only by validating accuracy differences between float
+// inference with the FakeQuant ops simulating quantization and the actually
+// quantized graph can you be sure the appropriate FakeQuant ops are present.
+//
+// You can tell if you're missing some FakeQuants by looking for warnings from
+// quantize.cc about minmax ranges being determined by the contents of constant
+// arrays. This will almost never produce functional models during inference.
+//
+// As this op may change the data types and ranges of input and output arrays,
+// downstream tools must also be sure to parse the output model flags to get
+// the post-Transform values that may have changed due to this transformation.
+//
+// This isn't a GraphTransformation in the traditional sense as it affects ops
+// outside of the one under transformation. This is primarily so that we can
+// utilize the graph traversal and repeated pass system underlying the
+// transformation system to exhaustively find all FakeQuant ops. It also gets
+// us nice logging and integration with the graphviz video dumping mode.
+// In general you should not copy this style of transformation and stick to
+// local-only changes as seen in the other transformations.
+bool PropagateFakeQuantNumBits::Run(Model* model, std::size_t op_index) {
+  auto it = model->operators.begin() + op_index;
+  auto* op = it->get();
+  if (op->type != OperatorType::kFakeQuant) {
+    return false;
+  }
+  auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+  ArrayDataType quantized_data_type = ArrayDataType::kNone;
+  if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+                                           &quantized_data_type)) {
+    AddMessageF("FakeQuant op %s num_bits=%d is out of range, ignoring",
+                LogName(*op), fakequant_op->num_bits);
+    return false;
+  }
+  const auto& final_minmax = *fakequant_op->minmax;
+
+  AddMessageF(
+      "Beginning propagation of fake quant %s num_bits=%d min=%g max=%g to %s",
+      LogName(*op), fakequant_op->num_bits, final_minmax.min, final_minmax.max,
+      ArrayDataTypeName(quantized_data_type));
+
+  bool did_change = false;
+
+  // Propagate the FakeQuant information backward up the graph.
+  // This will possibly adjust input arrays or constant types (like Gather).
+  did_change |= RecursivelyBackwardPropagateDataType(
+      this, model, op, quantized_data_type, final_minmax);
+
+  // Propagate the FakeQuant information forward down the graph.
+  // This will possibly adjust output arrays.
+  did_change |=
+      RecursivelyForwardPropagateDataType(this, model, op, quantized_data_type);
+
+  return did_change;
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
index e080df4bed..d74cad9a62 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
@@ -22,6 +22,20 @@ limitations under the License.
 
 namespace toco {
 
+bool InferQuantizedDataTypeFromFakeQuant(
+    const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type) {
+  if (op.num_bits <= 8) {
+    *out_quantized_data_type = ArrayDataType::kUint8;
+    return true;
+  } else if (op.num_bits <= 16) {
+    *out_quantized_data_type = ArrayDataType::kInt16;
+    return true;
+  } else {
+    *out_quantized_data_type = ArrayDataType::kNone;
+    return false;
+  }
+}
+
 bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
                                         double* out_min_value,
                                         double* out_max_value) {
@@ -103,6 +117,80 @@ void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
   }
 }
 
+namespace {
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+    const GenericBuffer& buffer,
+    const QuantizationParams& quantization_params) {
+  const auto inverse_scale = 1. / quantization_params.scale;
+  CHECK(buffer.type == ArrayDataType::kFloat);
+  const auto& float_buffer =
+      static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+  auto* quantized_buffer = new Buffer<A>;
+  quantized_buffer->data.resize(float_buffer.data.size());
+  for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+    const float src_val = float_buffer.data[i];
+    double scaled_val;  // Astonishingly, using 'float' degrades accuracy just
+                        // enough to make a few tests fail!
+    if (quantization_params.scale == 0) {
+      CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+                           << "so all its values should be 0.";
+      scaled_val = quantization_params.zero_point;
+    } else {
+      scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+    }
+    quantized_buffer->data[i] =
+        tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+  }
+  return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+                   const string& name,
+                   const QuantizationParams& quantization_params) {
+  auto& array = model->GetArray(name);
+  CHECK(array.data_type == ArrayDataType::kFloat);
+  CHECK(!array.quantization_params);
+  array.GetOrCreateQuantizationParams() = quantization_params;
+  if (array.buffer) {
+    array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+  }
+  array.data_type = A;
+  array.final_data_type = A;
+  transformation->AddMessageF(
+      "Quantized array %s to %s zero_point=%g, scale=%g", name,
+      ArrayDataTypeName(array.data_type), quantization_params.zero_point,
+      quantization_params.scale);
+}
+
+}  // namespace
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+                   const string& name, ArrayDataType quantized_data_type,
+                   const QuantizationParams& quantization_params) {
+  ArrayDataType adjusted_data_type = quantized_data_type;
+  auto& array = model->GetArray(name);
+  if (array.final_data_type == ArrayDataType::kInt16) {
+    adjusted_data_type = array.final_data_type;
+  }
+
+  switch (adjusted_data_type) {
+    case ArrayDataType::kUint8:
+      return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+                                                  quantization_params);
+    case ArrayDataType::kInt16:
+      return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
+                                                  quantization_params);
+    case ArrayDataType::kInt32:
+      return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+                                                  quantization_params);
+    default:
+      LOG(FATAL) << "Unhandled case.";
+  }
+}
+
 bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
                                  const Array& array, double clamp_min,
                                  double clamp_max) {
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
index 35fb310777..79a2ce7e50 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
@@ -15,11 +15,17 @@ limitations under the License.
 #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
 #define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
 
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 
 namespace toco {
 
+// Gets the target quantized data type of an array based on the fake quant op.
+// For example, if the num_bits is 8 the data type will be kUint8.
+bool InferQuantizedDataTypeFromFakeQuant(
+    const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type);
+
 // Gets the min/max numerical range for the given quantized data type.
 // For example, kUint8 will return [0,255].
 // Returns true if the ranges were set and false if the type is not quantized.
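A quick illustration of how the two helpers declared above compose - hypothetical caller code that assumes toco's headers and namespace, not a snippet from the patch:

    // Pick the quantized type implied by a FakeQuant op, then look up the
    // numerical range of that type.
    FakeQuantOperator fq;
    fq.num_bits = 16;
    ArrayDataType quantized_type;
    if (InferQuantizedDataTypeFromFakeQuant(fq, &quantized_type)) {
      // quantized_type is kInt16 here; the range query fills in
      // qmin = -32768 and qmax = 32767.
      double qmin, qmax;
      CHECK(GetQuantizedDataTypeNumericalRange(quantized_type, &qmin, &qmax));
    }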
@@ -32,11 +38,28 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
 ArrayDataType GetQuantizedDataType(const Array& array,
                                    ArrayDataType default_type);
 
-// Gets the quantization params for the array with the given data type and
+// Returns the quantization params for the array with the given data type and
 // minmax.
 void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
                            QuantizationParams* quantization_params);
 
+// Returns the quantization params for the data type and minmax values.
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const MinMax& minmax,
+                                     QuantizationParams* quantization_params) {
+  using Integer = DataType<A>;
+  const double rmin = minmax.min;
+  const double rmax = minmax.max;
+  *quantization_params =
+      ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
+}
+
+// Quantizes an array by setting its data type and (if constant) quantizing
+// all values in the array.
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+                   const string& name, ArrayDataType quantized_data_type,
+                   const QuantizationParams& quantization_params);
+
 // Returns true if the given array, when quantized, contains only values between
 // the provided clamp min/max.
 // Either clamp_min or clamp_max may be +/-infinity to indicate that the value
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index d6cae3cdbf..fa46e6bc38 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -57,72 +57,6 @@ bool SupportsQuantization(const Operator& op) {
          type == OperatorType::kTranspose || type == OperatorType::kMean;
 }
 
-template <ArrayDataType A>
-std::unique_ptr<GenericBuffer> QuantizeBuffer(
-    const GenericBuffer& buffer,
-    const QuantizationParams& quantization_params) {
-  const auto inverse_scale = 1. / quantization_params.scale;
-  CHECK(buffer.type == ArrayDataType::kFloat);
-  const auto& float_buffer =
-      static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
-  auto* quantized_buffer = new Buffer<A>;
-  quantized_buffer->data.resize(float_buffer.data.size());
-  for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
-    const float src_val = float_buffer.data[i];
-    double scaled_val;  // Astonishingly, using 'float' degrades accuracy just
-                        // enough to make a few tests fail!
-    if (quantization_params.scale == 0) {
-      CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
-                           << "so all its values should be 0.";
-      scaled_val = quantization_params.zero_point;
-    } else {
-      scaled_val = quantization_params.zero_point + inverse_scale * src_val;
-    }
-    quantized_buffer->data[i] =
-        tflite::SafeCast<DataType<A>>(std::round(scaled_val));
-  }
-  return std::unique_ptr<GenericBuffer>(quantized_buffer);
-}
-
-template <ArrayDataType A>
-void QuantizeArray(GraphTransformation* transformation, Model* model,
-                   const string& name,
-                   const QuantizationParams& quantization_params) {
-  auto& array = model->GetArray(name);
-  CHECK(array.data_type == ArrayDataType::kFloat);
-  CHECK(!array.quantization_params);
-  array.GetOrCreateQuantizationParams() = quantization_params;
-  if (array.buffer) {
-    array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
-  }
-  array.data_type = A;
-  transformation->AddMessageF("Quantized array %s", name);
-}
-
-void QuantizeArray(GraphTransformation* transformation, Model* model,
-                   const string& name, ArrayDataType quantized_data_type,
-                   const QuantizationParams& quantization_params) {
-  ArrayDataType adjusted_data_type = quantized_data_type;
-  auto& array = model->GetArray(name);
-  if (array.final_data_type == ArrayDataType::kInt16) {
-    adjusted_data_type = array.final_data_type;
-  }
-
-  switch (adjusted_data_type) {
-    case ArrayDataType::kUint8:
-      return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
-                                                  quantization_params);
-    case ArrayDataType::kInt16:
-      return QuantizeArray<ArrayDataType::kInt16>(transformation, model, name,
-                                                  quantization_params);
-    case ArrayDataType::kInt32:
-      return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
-                                                  quantization_params);
-    default:
-      LOG(FATAL) << "Unhandled case.";
-  }
-}
-
 const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
   auto& array = model->GetArray(array_name);
   // Normally we should have a MinMax recorded on this Array,
@@ -245,6 +179,8 @@ bool ChooseQuantizationForOperatorInput(
     const auto& input_weights = model->GetArray(op.inputs[weights_input_index]);
     if (!input_activations.quantization_params ||
         !input_weights.quantization_params) {
+      transformation->AddMessageF(
+          "Input array %s is a bias vector but has no qparams", input);
       return false;
     }
     const auto input_activations_scale =
@@ -366,6 +302,9 @@ bool ChooseQuantizationForOperatorOutput(
   const auto& output = op.outputs[output_index];
   auto& array = model->GetArray(output);
   if (array.data_type != ArrayDataType::kFloat) {
+    transformation->AddMessageF("Array data type already set to %s, final=%s",
+                                ArrayDataTypeName(array.data_type),
+                                ArrayDataTypeName(array.final_data_type));
     return false;
   }
   *quantized_data_type = model->GetArray(op.inputs[0]).data_type;
@@ -427,29 +366,22 @@ bool ChooseQuantizationForOperatorOutput(
 // Fixes array minmax info to match the quantization parameters.
 // This is required for when quantization parameters change for an array during
 // quantization (such as ChooseQuantizationForOperatorOutput).
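Concretely, with illustrative numbers (not from the patch): for uint8 quantization params scale = 2.0 / 255 and zero_point = 0, the rewritten function below recomputes min = (0 - 0) * 2.0 / 255 = 0 and max = (255 - 0) * 2.0 / 255 = 2, and it only writes the new values back when they differ from the stored minmax by more than kMinMaxThreshold times the range width, so tiny floating point drift is left alone.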
-void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
+void FixMinMaxPostQuantization(GraphTransformation* transformation,
+                               ArrayDataType quantized_data_type,
                                const QuantizationParams& quantization_params,
                                MinMax* minmax) {
-  double qmin, qmax;
-  switch (quantized_data_type) {
-    case ArrayDataType::kUint8:
-      qmin = 0;
-      qmax = 255;
-      break;
-    case ArrayDataType::kInt16:
-      qmin = -32768;
-      qmax = 32767;
-      break;
-    default:
-      // No update required.
-      return;
+  double quantized_min, quantized_max;
+  if (!GetQuantizedDataTypeNumericalRange(quantized_data_type, &quantized_min,
+                                          &quantized_max)) {
+    // Not quantized - no update required.
+    return;
   }
 
   // Compute new minmax values.
-  double min =
-      (qmin - quantization_params.zero_point) * quantization_params.scale;
-  double max =
-      (qmax - quantization_params.zero_point) * quantization_params.scale;
+  double min = (quantized_min - quantization_params.zero_point) *
+               quantization_params.scale;
+  double max = (quantized_max - quantization_params.zero_point) *
+               quantization_params.scale;
 
   // If we are close to the existing minmax values don't bother changing them.
   // This prevents propagating small floating point precision errors.
@@ -457,6 +389,9 @@ void FixMinMaxPostQuantization(ArrayDataType quantized_data_type,
   const double width = max - min;
   if (std::abs(min - minmax->min) > kMinMaxThreshold * width ||
       std::abs(max - minmax->max) > kMinMaxThreshold * width) {
+    transformation->AddMessageF(
+        "Adjusting min/max from %g,%g to %g,%g to match quantization params",
+        minmax->min, minmax->max, min, max);
     minmax->min = min;
     minmax->max = max;
   }
@@ -566,10 +501,33 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
           // input instead.
           for (int i = 0; i < model->flags.output_arrays_size(); i++) {
             if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
-              model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+              // TODO(b/78013785): never rename output arrays.
+              if (IsInputArray(*model, dequantize_op->inputs[0])) {
+                // The op input is an input array and the output is an output
+                // array and we can't have an array be both. Insert a copy
+                // op to ensure the two arrays stay separate.
+                AddMessageF(
+                    "Tried to rename output array %d while removing dequant "
+                    "op %s but array is also an input; inserting copy %s "
+                    "-> %s",
+                    i, LogName(*dequantize_op), model->flags.output_arrays(i),
+                    dequantize_op->inputs[0]);
+                InsertCopyOperator(model, dequantize_op->inputs[0],
+                                   dequantize_op->outputs[0]);
+              } else {
+                // Op output is strictly used as an output array, so we can
+                // just rename the array and directly bypass the op.
+                AddMessageF(
+                    "Renaming output array %d after removing dequant op %s: "
+                    "%s -> %s",
+                    i, LogName(*dequantize_op), model->flags.output_arrays(i),
+                    dequantize_op->inputs[0]);
+                model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+                model->EraseArray(dequantize_op->outputs[0]);
+              }
+              break;
             }
           }
-          model->EraseArray(dequantize_op->outputs[0]);
           model->operators.erase(dequantize_it);
         }
         changed = true;
@@ -615,7 +573,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
         CHECK(output_array.minmax)
             << "Output array named " << output << " lacks minmax";
         auto& output_minmax = output_array.GetMinMax();
-        FixMinMaxPostQuantization(quantized_data_type, quantization_params,
+        FixMinMaxPostQuantization(this, quantized_data_type, quantization_params,
                                   &output_minmax);
 
         QuantizeArray(this, model, output, quantized_data_type,
@@ -626,6 +584,7 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
           auto& dequantized_output_array =
               model->GetOrCreateArray(dequantized_output);
           dequantized_output_array.data_type = ArrayDataType::kFloat;
+          dequantized_output_array.final_data_type = output_array.data_type;
           auto& dequantized_output_minmax =
               dequantized_output_array.GetOrCreateMinMax();
           dequantized_output_minmax.min = output_minmax.min;
@@ -642,6 +601,12 @@ bool Quantize::Run(Model* model, std::size_t op_index) {
           dequantize_op->outputs = {dequantized_output};
           for (int i = 0; i < model->flags.output_arrays_size(); i++) {
             if (model->flags.output_arrays(i) == output) {
+              // TODO(b/78013785): never rename output arrays.
+              AddMessageF(
+                  "Renaming output array %d after inserting dequant op %s: %s -> "
+                  "%s",
+                  i, LogName(*dequantize_op), model->flags.output_arrays(i),
+                  dequantized_output);
               model->flags.set_output_arrays(i, dequantized_output);
             }
           }
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
new file mode 100644
index 0000000000..2c8d04440f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_fake_quant.cc
@@ -0,0 +1,86 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsFakeQuantTrivial(GraphTransformation* transformation, const Model& model,
+                        const FakeQuantOperator& fakequant_op) {
+  CHECK(fakequant_op.type == OperatorType::kFakeQuant);
+
+  if (!fakequant_op.minmax) {
+    // Require ReadFakeQuantMinMax to have run.
+    return false;
+  }
+
+  // FakeQuants are trivial if they are taking input from another identical
+  // FakeQuant op.
+  auto* producing_op = GetOpWithOutput(model, fakequant_op.inputs[0]);
+  if (!producing_op || producing_op->type != OperatorType::kFakeQuant) {
+    return false;
+  }
+  const auto& producing_fakequant_op =
+      *static_cast<FakeQuantOperator*>(producing_op);
+  if (!producing_fakequant_op.minmax) {
+    // Require ReadFakeQuantMinMax to have run.
+    return false;
+  }
+
+  if (*fakequant_op.minmax == *producing_fakequant_op.minmax &&
+      fakequant_op.num_bits == producing_fakequant_op.num_bits) {
+    transformation->AddMessageF(
+        "%s is trivial because it is preceded by an identical FakeQuant %s",
+        LogName(fakequant_op), LogName(producing_fakequant_op));
+    return true;
+  }
+
+  return false;
+}
+
+}  // namespace
+
+// Removes FakeQuant ops that are trivial (have no effect, are redundant, etc).
+bool RemoveTrivialFakeQuant::Run(Model* model, std::size_t op_index) {
+  const auto op_it = model->operators.begin() + op_index;
+  auto* op = op_it->get();
+  if (op->type != OperatorType::kFakeQuant) {
+    return false;
+  }
+  auto* fakequant_op = static_cast<FakeQuantOperator*>(op);
+
+  if (!IsFakeQuantTrivial(this, *model, *fakequant_op)) {
+    AddMessageF("%s is not trivial", LogName(*fakequant_op));
+    return false;
+  }
+
+  AddMessageF("Removing trivial %s", LogName(*fakequant_op));
+
+  CHECK_EQ(fakequant_op->inputs.size(), 1);
+  return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+}  // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index 625d90205a..efb7bb2184 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/tooling_util.h"
 #include "tensorflow/core/platform/logging.h"
@@ -45,9 +46,29 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
   }
 
   const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+  CHECK(input_array.data_type == ArrayDataType::kFloat);
+
+  // Determine the final data type in the same way as
+  // PropagateFakeQuantNumBits.
+  ArrayDataType quantized_data_type = input_array.final_data_type;
+  if (!InferQuantizedDataTypeFromFakeQuant(*fakequant_op,
+                                           &quantized_data_type)) {
+    AddMessageF("Unsupported FakeQuant num_bits=%d", fakequant_op->num_bits);
+    return false;
+  }
+
+  AddMessageF("Resolving constant %s", LogName(*fakequant_op));
+
   auto& output_array = model->GetArray(fakequant_op->outputs[0]);
   CHECK(input_array.data_type == ArrayDataType::kFloat);
   output_array.data_type = ArrayDataType::kFloat;
+
+  // We'll set the final data type to what the fake quant indicates we should
+  // have (and would have been set if this stayed around until
+  // PropagateFakeQuantNumBits).
+  if (propagate_fake_quant_num_bits()) {
+    output_array.final_data_type = quantized_data_type;
+  }
+
   CHECK(!output_array.buffer);
   const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
   output_array.GetOrCreateMinMax() = *fakequant_op->minmax;
@@ -66,7 +87,9 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
     const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
     output_buffer.data[i] = dst_val;
   }
-  if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+
+  if (IsDiscardableArray(*model, fakequant_op->inputs[0]) &&
+      CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
     model->EraseArray(fakequant_op->inputs[0]);
   }
   model->operators.erase(fakequant_it);
diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
index cc7803dd86..d1d68b6b47 100644
--- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc
@@ -126,6 +126,11 @@ bool ParseTocoFlagsFromCommandLineFlags(
            parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
            "If true, disable fusion of known identifiable cell subgraphs into "
            "cells. This includes, for example, specific forms of LSTM cell."),
+      Flag("propagate_fake_quant_num_bits",
+           parsed_flags.propagate_fake_quant_num_bits.bind(),
+           parsed_flags.propagate_fake_quant_num_bits.default_value(),
+           "If true, use FakeQuant* operator num_bits attributes to adjust "
+           "array data_types."),
   };
   bool asked_for_help =
       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
@@ -211,6 +216,8 @@ void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
   READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
   READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
   READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
+  READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
+  READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
 
   // Deprecated flag handling.
   if (parsed_toco_flags.input_type.specified()) {
diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto
index 3237147a73..751aca948c 100644
--- a/tensorflow/contrib/lite/toco/toco_flags.proto
+++ b/tensorflow/contrib/lite/toco/toco_flags.proto
@@ -37,7 +37,7 @@ enum FileFormat {
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 14.
+// Next ID to use: 15.
 message TocoFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -141,4 +141,13 @@ message TocoFlags {
   // Disables transformations that fuse subgraphs such as known LSTMs (not all
   // LSTMs are identified).
   optional bool debug_disable_recurrent_cell_fusion = 13;
+
+  // Uses the FakeQuantWithMinMaxArgs.num_bits attribute to adjust quantized
+  // array data types throughout the graph. The graph must be properly
+  // annotated with FakeQuant* ops on at least the edges and may contain
+  // additional ops on the interior of the graph to widen/narrow as desired.
+  //
+  // Input and output array data types may change because of this propagation
+  // and users must be sure to query the final data_type values.
+  optional bool propagate_fake_quant_num_bits = 14;
 }
diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc
index 5ba093a830..b69852453c 100644
--- a/tensorflow/contrib/lite/toco/toco_tooling.cc
+++ b/tensorflow/contrib/lite/toco/toco_tooling.cc
@@ -66,6 +66,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new RemoveTensorFlowIdentity);
   transformations->Add(new RemoveTrivialConcatenation);
   transformations->Add(new RemoveTrivialConcatenationInput);
+  transformations->Add(new RemoveTrivialFakeQuant);
   transformations->Add(new RemoveTrivialSlice);
   transformations->Add(new RemoveUnusedOp);
   transformations->Add(new EnsureBiasVectors);
@@ -109,7 +110,6 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new ResolveMeanAttributes);
   transformations->Add(new ResolveConstantShapeOrRank);
   transformations->Add(new MakeInitialDequantizeOperator);
-  transformations->Add(new ResolveConstantFakeQuant);
   transformations->Add(new UnpartitionEmbeddingLookup);
 }
 
@@ -233,6 +233,12 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   MakeGeneralGraphTransformationsSet(&transformations);
   auto* remove_trivial_reshape = new RemoveTrivialReshape;
   transformations.Add(remove_trivial_reshape);
+  auto* resolve_constant_fake_quant = new ResolveConstantFakeQuant;
+  if (quantize_output) {
+    resolve_constant_fake_quant->set_propagate_fake_quant_num_bits(
+        toco_flags.propagate_fake_quant_num_bits());
+  }
+  transformations.Add(resolve_constant_fake_quant);
   if (SupportsFusedActivationFunction(output_format)) {
     transformations.Add(new FuseActivationFunctions);
   } else {
@@ -264,9 +270,21 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
   RunGraphTransformations(model, "general graph transformations",
                           transformations);
 
+  // Fix any issues with IO edges. This must happen after any transform that
+  // may modify the structure of the edges.
+  FixEdgeArrays(model);
+
   if (quantize_output) {
+    if (toco_flags.propagate_fake_quant_num_bits()) {
+      RunGraphTransformations(model,
+                              "fake quant propagation graph transformations",
+                              {new PropagateFakeQuantNumBits});
+    }
     RunGraphTransformations(model, "pre-quantization graph transformations",
-                            {new HardcodeMinMax, new DropFakeQuant});
+                            {
+                                new HardcodeMinMax,
+                                new DropFakeQuant,
+                            });
   }
 
   if (quantize_output) {
@@ -303,10 +321,6 @@ void Transform(const TocoFlags& toco_flags, Model* model) {
     EncodeConstantArraysMinMaxByWrappingThemInFakeQuantNodes(model);
   }
 
-  // Fix any issues with IO edges. This must happen after any transform that
-  // may modify the structure of the edges.
-  FixEdgeArrays(model);
-
   LogDump(kLogLevelModelChanged, "AFTER TRANSFORMATIONS", *model);
 
   if (output_format != GRAPHVIZ_DOT && output_format != TFLITE) {
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 224df9973e..ecac0c28a5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,9 +93,18 @@ string ArrayDataTypeName(ArrayDataType data_type) {
   }
 }
 
-bool IsInputArray(const Model& model, const string& name) {
+bool IsInputArray(const Model& model, const string& array_name) {
   for (const auto& input_array : model.flags.input_arrays()) {
-    if (input_array.name() == name) {
+    if (array_name == input_array.name()) {
+      return true;
+    }
+  }
+  return false;
+}
+
+bool IsOutputArray(const Model& model, const string& array_name) {
+  for (const auto& output_array : model.flags.output_arrays()) {
+    if (array_name == output_array) {
       return true;
     }
   }
@@ -106,10 +115,8 @@ bool IsArrayConsumed(const Model& model, const string& name) {
   if (GetOpWithInput(model, name)) {
     return true;
   }
-  for (const string& model_output : model.flags.output_arrays()) {
-    if (model_output == name) {
-      return true;
-    }
+  if (IsOutputArray(model, name)) {
+    return true;
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
     if (rnn_state.back_edge_source_array() == name) {
@@ -379,6 +386,7 @@ string HelpfulOperatorTypeName(const Operator& op) {
 bool OperatorSupportsFusedActivation(OperatorType type) {
   switch (type) {
     case OperatorType::kConcatenation:
+    case OperatorType::kFakeQuant:
     case OperatorType::kGather:
     case OperatorType::kSlice:
     case OperatorType::kSqueeze:
@@ -1064,16 +1072,38 @@ void FixEdgeArrays(Model* model) {
   }
 }
 
+namespace {
+void CopyArrayAttribs(const Array& source_array, Array* target_array) {
+  target_array->data_type = source_array.data_type;
+  target_array->final_data_type = source_array.final_data_type;
+  target_array->copy_shape(source_array.shape());
+
+  if (source_array.minmax) {
+    target_array->GetOrCreateMinMax() = source_array.GetMinMax();
+  } else {
+    target_array->minmax.reset();
+  }
+
+  if (source_array.quantization_params) {
+    target_array->GetOrCreateQuantizationParams() =
+        source_array.GetQuantizationParams();
+  } else {
+    target_array->quantization_params.reset();
+  }
+}
+}  // namespace
+
 void InsertCopyOperator(Model* model, const string& source_array_name,
                         const string& target_array_name) {
+  // Reshape to the same size. This should be a no-op.
+  const Array& source_array = model->GetArray(source_array_name);
+  std::vector<int> shape = source_array.shape().dims();
+
   // Drop constant data from the target array as the copy will be done at
   // runtime.
   Array& target_array = model->GetOrCreateArray(target_array_name);
   target_array.buffer.reset();
-
-  // Reshape to the same size. This should be a no-op.
-  const Array& source_array = model->GetArray(source_array_name);
-  std::vector<int> shape = source_array.shape().dims();
+  CopyArrayAttribs(source_array, &target_array);
 
   // Insert copy operator.
   auto* copy_op = new TensorFlowReshapeOperator;
@@ -1089,6 +1119,7 @@ void CloneArray(Model* model, const string& source_array_name,
   CHECK(!model->HasArray(target_array_name));
   const Array& source_array = model->GetArray(source_array_name);
   Array& target_array = model->GetOrCreateArray(target_array_name);
+  CopyArrayAttribs(source_array, &target_array);
 
   if (source_array.minmax) {
     const auto& smm = source_array.GetMinMax();
@@ -1513,14 +1544,9 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
   if (model.IsOptionalArray(array_name)) return false;
   // The model's input and output arrays are externally allocated.
   // They are not transient arrays.
-  if (IsInputArray(model, array_name)) {
+  if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
     return false;
   }
-  for (const string& output_array : model.flags.output_arrays()) {
-    if (array_name == output_array) {
-      return false;
-    }
-  }
   const auto& array = &model.GetArray(array_name);
   // An array with a constant buffer isn't a transient array.
   if (!!array->buffer) {
@@ -1898,15 +1924,8 @@ int AxesCount(AxesOrder axes_order) {
 }
 
 bool IsDiscardableArray(const Model& model, const string& array_name) {
-  for (const auto& input_array : model.flags.input_arrays()) {
-    if (array_name == input_array.name()) {
-      return false;
-    }
-  }
-  for (const string& output_array : model.flags.output_arrays()) {
-    if (array_name == output_array) {
-      return false;
-    }
+  if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
+    return false;
   }
   for (const auto& rnn_state : model.flags.rnn_states()) {
     if (!rnn_state.discardable()) {
@@ -1960,8 +1979,8 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
       CHECK(array.final_data_type == array.data_type)
          << "Array \"" << array_entry.first
          << "\" has mis-matching actual and final data types ("
-         << static_cast<int>(array.data_type) << ","
-         << static_cast<int>(array.final_data_type) << ").";
+         << ArrayDataTypeName(array.data_type) << ","
+         << ArrayDataTypeName(array.final_data_type) << ").";
     }
   }
 }
diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h
index ed0ecd4d0f..4c705f4e5f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.h
+++ b/tensorflow/contrib/lite/toco/tooling_util.h
@@ -28,7 +28,6 @@ limitations under the License.
 #if TOCO_SUPPORT_PORTABLE_PROTOS
 #include "third_party/protobuf/src/google/protobuf/text_format.h"
 #endif  // TOCO_SUPPORT_PORTABLE_PROTOS
-#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
 #include "tensorflow/contrib/lite/toco/model.h"
 #include "tensorflow/contrib/lite/toco/model_flags.pb.h"
 #include "tensorflow/contrib/lite/toco/runtime/types.h"
@@ -57,7 +56,11 @@ string LogName(const Operator& op);
 
 string ArrayDataTypeName(ArrayDataType data_type);
 
-bool IsInputArray(const Model& model, const string& name);
+// Returns true if the given array is specified as a model input array.
+bool IsInputArray(const Model& model, const string& array_name);
+// Returns true if the given array is specified as a model output array.
+bool IsOutputArray(const Model& model, const string& array_name);
+
 bool IsArrayConsumed(const Model& model, const string& name);
 
 int CountTrueOutputs(const Model& model, const Operator& op);
 
@@ -175,17 +178,6 @@ void CloneArray(Model* model, const string& source_array_name,
 
 void ResolveModelFlags(const ModelFlags& model_flags, Model* model);
 
-template <ArrayDataType A>
-void GetQuantizationParamsFromMinMax(const MinMax& minmax,
-                                     QuantizationParams* quantization_params) {
-  using Integer = DataType<A>;
-  const double rmin = minmax.min;
-  const double rmax = minmax.max;
-
-  *quantization_params =
-      ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
-}
-
 template <typename T>
 T ConvertOperator(Operator* o, OperatorType type) {
   if (o != nullptr && o->type == type) {