aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-17 11:53:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 11:55:26 -0700
commitd7b6cb66c0fc346cf55020042931c07208713c60 (patch)
tree9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
parent1192c1662c5c98f55805450b4619ac2bc9c6908c (diff)
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
index 35fb310777..79a2ce7e50 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.h
@@ -15,11 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
#define TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_QUANTIZATION_UTIL_H_
+#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h"
#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
namespace toco {
+// Gets the target quantized data type of an array based on the fake quant op.
+// For example, if the num_bits is 8 the data type will be kUint8.
+bool InferQuantizedDataTypeFromFakeQuant(
+ const FakeQuantOperator& op, ArrayDataType* out_quantized_data_type);
+
// Gets the min/max numerical range for the given quantized data type.
// For example, kUint8 will return [0,255].
// Returns true if the ranges were set and false if the type is not quantized.
@@ -32,11 +38,28 @@ bool GetQuantizedDataTypeNumericalRange(ArrayDataType data_type,
ArrayDataType GetQuantizedDataType(const Array& array,
ArrayDataType default_type);
-// Gets the quantization params for the array with the given data type and
+// Returns the quantization params for the array with the given data type and
// minmax.
void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
QuantizationParams* quantization_params);
+// Returns the quantization params for the data type and minmax values.
+template <ArrayDataType A>
+void GetQuantizationParamsFromMinMax(const MinMax& minmax,
+ QuantizationParams* quantization_params) {
+ using Integer = DataType<A>;
+ const double rmin = minmax.min;
+ const double rmax = minmax.max;
+ *quantization_params =
+ ::tflite::ChooseQuantizationParams<Integer>(rmin, rmax);
+}
+
+// Quantizes an array by setting its data type and (if constant) quantizing
+// all values in the array.
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params);
+
// Returns true if the given array, when quantized, contains only values between
// the provided clamp min/max.
// Either clamp_min or clamp_max may be +/-infinity to indicate that the value