aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc17
1 files changed, 5 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
index 0f2592d05f..3ad6b0ec6f 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc
@@ -30,15 +30,9 @@ namespace {
bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
ArrayDataType new_data_type,
const MinMax* new_minmax) {
- // The code below assumes kInt16, see
- // GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>
- if (new_data_type != ArrayDataType::kInt16) {
- return false;
- }
-
- bool changed = false;
// Ensure the array ends up in the new type (if it hasn't yet been quantized).
- if ((array->final_data_type != new_data_type)) {
+ bool changed = false;
+ if (array->final_data_type != new_data_type) {
array->final_data_type = new_data_type;
changed = true;
}
@@ -72,12 +66,10 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
"Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min,
array_minmax.max, ArrayDataTypeName(array->data_type), min, max,
ArrayDataTypeName(new_data_type));
-
array_minmax.min = min;
array_minmax.max = max;
- GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(
- array_minmax, array->quantization_params.get());
-
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ *array, new_data_type, array->quantization_params.get());
// Directly change the type as the array was already quantized.
array->data_type = new_data_type;
changed = true;
@@ -95,6 +87,7 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array,
changed = true;
}
}
+
return changed;
}