diff options
author | 2018-04-17 11:53:29 -0700 | |
---|---|---|
committer | 2018-04-17 11:55:26 -0700 | |
commit | d7b6cb66c0fc346cf55020042931c07208713c60 (patch) | |
tree | 9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h | |
parent | 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff) |
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h index 35fb310777..79a2ce7e50 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h @@ -15,11 +15,17 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ #define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_ +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" #include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" #include "tensorflow/contrib/lite/toco/model.h" namespace toco { +// Gets the target quantized data type of an array based on the fake quant op. +// For example, if the num_bits is 8 the data type will be kUint8. +bool InferQuantizedDataTypeFromFakeQuant( + const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type); + // Gets the min/max numerical range for the given quantized data type. // For example, kUint8 will return [0,255]. // Returns true if the ranges were set and false if the type is not quantized. @@ -32,11 +38,28 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type, ArrayDataType GetQuantizedDataType(const Array& array, ArrayDataType default_type); -// Gets the quantization params for the array with the given data type and +// Returns the quantization params for the array with the given data type and // minmax. void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax, QuantizationParams* quantization_params); +// Returns the quantization params for the data type and minmax values. +template <ArrayDataType A> +void GetQuantizationParamsFromMinMax(const MinMax& minmax, + QuantizationParams* quantization_params) { + using Integer = DataType<A>; + const double rmin = minmax.min; + const double rmax = minmax.max; + *quantization_params = + ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax); +} + +// Quantizes an array by setting its data type and (if constant) quantizing +// all values in the array. +void QuantizeArray(GraphTransformation* transformation, Model* model, + const string& name, ArrayDataType quantized_data_type, + const QuantizationParams& quantization_params); + // Returns true if the given array, when quantized, contains only values between // the provided clamp min/max. // Either clamp_min or clamp_max may be +/-infinity to indicate that the value |