diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc | 46 |
1 files changed, 42 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc index efb7bb2184..058f314b33 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc @@ -25,6 +25,37 @@ limitations under the License. namespace toco { +template <ArrayDataType A> +void GetBoundsForQuantizedDataType(double* min, double* max) { + using limits = std::numeric_limits<DataType<A>>; + *min = limits::min(); + *max = limits::max(); +} + +void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type, + double* min, double* max) { + switch (quantized_data_type) { + case ArrayDataType::kUint8: + return GetBoundsForQuantizedDataType<ArrayDataType::kUint8>(min, max); + case ArrayDataType::kInt8: + return GetBoundsForQuantizedDataType<ArrayDataType::kInt8>(min, max); + case ArrayDataType::kUint16: + return GetBoundsForQuantizedDataType<ArrayDataType::kUint16>(min, max); + case ArrayDataType::kInt16: + return GetBoundsForQuantizedDataType<ArrayDataType::kInt16>(min, max); + case ArrayDataType::kUint32: + return GetBoundsForQuantizedDataType<ArrayDataType::kUint32>(min, max); + case ArrayDataType::kInt32: + return GetBoundsForQuantizedDataType<ArrayDataType::kInt32>(min, max); + case ArrayDataType::kUint64: + return GetBoundsForQuantizedDataType<ArrayDataType::kUint64>(min, max); + case ArrayDataType::kInt64: + return GetBoundsForQuantizedDataType<ArrayDataType::kInt64>(min, max); + default: + LOG(FATAL) << "unhandled quantized data type"; + } +} + bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const auto fakequant_it = model->operators.begin() + op_index; const auto* fakequant_base_op = fakequant_it->get(); @@ -76,14 +107,21 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) { const int size = input_buffer.data.size(); output_buffer.data.resize(size); QuantizationParams qparams; - GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(*fakequant_op->minmax, - &qparams); + ChooseQuantizationParamsForArrayAndQuantizedDataType( + output_array, quantized_data_type, &qparams); + double quantized_min, quantized_max; + GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min, + &quantized_max); + if (fakequant_op->narrow_range) { + quantized_min++; + } + for (int i = 0; i < size; i++) { const double src_val = input_buffer.data[i]; const double unclamped_quantized_val = std::round(qparams.zero_point + src_val / qparams.scale); - const double quantized_val = - std::min(255., std::max(0., unclamped_quantized_val)); + const double quantized_val = std::min( + quantized_max, std::max(quantized_min, unclamped_quantized_val)); const double dst_val = qparams.scale * (quantized_val - qparams.zero_point); output_buffer.data[i] = dst_val; } |