aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc46
1 files changed, 42 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
index efb7bb2184..058f314b33 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -25,6 +25,37 @@ limitations under the License.
namespace toco {
+template <ArrayDataType A>
+void GetBoundsForQuantizedDataType(double* min, double* max) {
+ using limits = std::numeric_limits<DataType<A>>;
+ *min = limits::min();
+ *max = limits::max();
+}
+
+void GetBoundsForQuantizedDataType(ArrayDataType quantized_data_type,
+ double* min, double* max) {
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint8>(min, max);
+ case ArrayDataType::kInt8:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt8>(min, max);
+ case ArrayDataType::kUint16:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint16>(min, max);
+ case ArrayDataType::kInt16:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt16>(min, max);
+ case ArrayDataType::kUint32:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint32>(min, max);
+ case ArrayDataType::kInt32:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt32>(min, max);
+ case ArrayDataType::kUint64:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kUint64>(min, max);
+ case ArrayDataType::kInt64:
+ return GetBoundsForQuantizedDataType<ArrayDataType::kInt64>(min, max);
+ default:
+ LOG(FATAL) << "unhandled quantized data type";
+ }
+}
+
bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const auto fakequant_it = model->operators.begin() + op_index;
const auto* fakequant_base_op = fakequant_it->get();
@@ -76,14 +107,21 @@ bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
const int size = input_buffer.data.size();
output_buffer.data.resize(size);
QuantizationParams qparams;
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(*fakequant_op->minmax,
- &qparams);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ output_array, quantized_data_type, &qparams);
+ double quantized_min, quantized_max;
+ GetBoundsForQuantizedDataType(quantized_data_type, &quantized_min,
+ &quantized_max);
+ if (fakequant_op->narrow_range) {
+ quantized_min++;
+ }
+
for (int i = 0; i < size; i++) {
const double src_val = input_buffer.data[i];
const double unclamped_quantized_val =
std::round(qparams.zero_point + src_val / qparams.scale);
- const double quantized_val =
- std::min(255., std::max(0., unclamped_quantized_val));
+ const double quantized_val = std::min(
+ quantized_max, std::max(quantized_min, unclamped_quantized_val));
const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
output_buffer.data[i] = dst_val;
}