diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc | 17 |
1 files changed, 5 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc index 0f2592d05f..3ad6b0ec6f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fake_quant_num_bits.cc @@ -30,15 +30,9 @@ namespace { bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, ArrayDataType new_data_type, const MinMax* new_minmax) { - // The code below assumes kInt16, see - // GetQuantizationParamsFromMinMax<ArrayDataType::kInt16> - if (new_data_type != ArrayDataType::kInt16) { - return false; - } - - bool changed = false; // Ensure the array ends up in the new type (if it hasn't yet been quantized). - if ((array->final_data_type != new_data_type)) { + bool changed = false; + if (array->final_data_type != new_data_type) { array->final_data_type = new_data_type; changed = true; } @@ -72,12 +66,10 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, "Rescaling min/max from %g,%g (%s) to %g,%g (%s)", array_minmax.min, array_minmax.max, ArrayDataTypeName(array->data_type), min, max, ArrayDataTypeName(new_data_type)); - array_minmax.min = min; array_minmax.max = max; - GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>( - array_minmax, array->quantization_params.get()); - + ChooseQuantizationParamsForArrayAndQuantizedDataType( + *array, new_data_type, array->quantization_params.get()); // Directly change the type as the array was already quantized. array->data_type = new_data_type; changed = true; @@ -95,6 +87,7 @@ bool ChangeArrayDataType(GraphTransformation* transformation, Array* array, changed = true; } } + return changed; } |