about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc')
-rw-r--r-- tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc | 69
1 files changed, 42 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
index d74cad9a62..44733391f5 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantization_util.cc
@@ -74,46 +74,54 @@ ArrayDataType GetQuantizedDataType(const Array& array,
}
}
-void GetQuantizationParams(ArrayDataType data_type, const MinMax& minmax,
- QuantizationParams* quantization_params) {
- switch (data_type) {
+template <ArrayDataType A>
+void ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ const Array& array, QuantizationParams* quantization_params) {
+ *quantization_params = ::tflite::ChooseQuantizationParams<DataType<A>>(
+ array.minmax->min, array.minmax->max, array.narrow_range);
+}
+
+void ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ const Array& array, ArrayDataType quantized_data_type,
+ QuantizationParams* quantization_params) {
+ switch (quantized_data_type) {
case ArrayDataType::kInt8:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt8>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt8>(array, quantization_params);
break;
case ArrayDataType::kUint8:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint8>(array, quantization_params);
break;
case ArrayDataType::kInt16:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt16>(array, quantization_params);
break;
case ArrayDataType::kUint16:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint16>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint16>(array, quantization_params);
break;
case ArrayDataType::kInt32:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt32>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt32>(array, quantization_params);
break;
case ArrayDataType::kUint32:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint32>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint32>(array, quantization_params);
break;
case ArrayDataType::kInt64:
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt64>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kInt64>(array, quantization_params);
break;
case ArrayDataType::kUint64:
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint64>(
- minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType<
+ ArrayDataType::kUint64>(array, quantization_params);
break;
case ArrayDataType::kFloat:
case ArrayDataType::kNone:
default:
LOG(FATAL) << "Unhandled final quantization type "
- << static_cast<int>(data_type);
+ << static_cast<int>(quantized_data_type);
}
}
@@ -121,8 +129,8 @@ namespace {
template <ArrayDataType A>
std::unique_ptr<GenericBuffer> QuantizeBuffer(
- const GenericBuffer& buffer,
- const QuantizationParams& quantization_params) {
+ const Array& array, const QuantizationParams& quantization_params) {
+ const GenericBuffer& buffer = *array.buffer;
const auto inverse_scale = 1. / quantization_params.scale;
CHECK(buffer.type == ArrayDataType::kFloat);
const auto& float_buffer =
@@ -140,8 +148,15 @@ std::unique_ptr<GenericBuffer> QuantizeBuffer(
} else {
scaled_val = quantization_params.zero_point + inverse_scale * src_val;
}
- quantized_buffer->data[i] =
- tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ auto integer_val = tflite::SafeCast<DataType<A>>(std::round(scaled_val));
+ // In addition to its effect on the choice of quantization params upstream
+ // of here, narrow_range also means nudge the min quantized value by +1,
+ // so e.g. uint8 values get constrained to [1, 255].
+ if (integer_val == std::numeric_limits<DataType<A>>::min() &&
+ array.narrow_range) {
+ integer_val++;
+ }
+ quantized_buffer->data[i] = integer_val;
}
return std::unique_ptr<GenericBuffer>(quantized_buffer);
}
@@ -155,7 +170,7 @@ void QuantizeArray(GraphTransformation* transformation, Model* model,
CHECK(!array.quantization_params);
array.GetOrCreateQuantizationParams() = quantization_params;
if (array.buffer) {
- array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ array.buffer = QuantizeBuffer<A>(array, quantization_params);
}
array.data_type = A;
array.final_data_type = A;
@@ -210,8 +225,8 @@ bool IsArrayQuantizedRangeSubset(GraphTransformation* transformation,
} else {
// Work around cases where we are asking for this prior to the Quantize
// transformation having added the quantization_params.
- GetQuantizationParams(quantized_data_type, *array.minmax,
- &quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, quantized_data_type, &quantization_params);
transformation->AddMessageF(
"No quantization params - infering from data type %s with minmax "
"%g,%g as zero_point=%g, scale=%g",