aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/quantize.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc29
1 files changed, 18 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 38699a62b5..f6ce3b3ecb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -50,6 +50,7 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kSqueeze || type == OperatorType::kPad ||
type == OperatorType::kPadV2 || type == OperatorType::kReshape ||
type == OperatorType::kTanh || type == OperatorType::kMul ||
+ type == OperatorType::kBatchToSpaceND ||
type == OperatorType::kSpaceToBatchND ||
type == OperatorType::kSpaceToDepth ||
type == OperatorType::kStridedSlice ||
@@ -59,7 +60,8 @@ bool SupportsQuantization(const Operator& op) {
type == OperatorType::kGreater ||
type == OperatorType::kGreaterEqual || type == OperatorType::kLess ||
type == OperatorType::kLessEqual || type == OperatorType::kSelect ||
- type == OperatorType::kArgMax;
+ type == OperatorType::kArgMax || type == OperatorType::kRelu ||
+ type == OperatorType::kRelu1 || type == OperatorType::kRelu6;
}
const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
@@ -211,13 +213,15 @@ bool ChooseQuantizationForOperatorInput(
if (op.type == OperatorType::kLstmCell) {
if (input_index == LstmCellOperator::PREV_STATE_INPUT) {
*quantized_data_type = ArrayDataType::kInt16;
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
return true;
}
}
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
transformation->AddMessageF(
"For input array %s with min=%g, max=%g, chose to quantize as %s (f=%s) "
"with zero_point=%d, scale=%g",
@@ -325,12 +329,13 @@ bool ChooseQuantizationForOperatorOutput(
output, OperatorTypeName(op.type));
return true;
}
- if ((op.type == OperatorType::kDepthToSpace) ||
- (op.type == OperatorType::kSpaceToDepth) ||
- (op.type == OperatorType::kReshape) ||
- (op.type == OperatorType::kSplit) ||
- (op.type == OperatorType::kConcatenation &&
- model->flags.change_concat_input_ranges())) {
+ if ((op.type == OperatorType::kConcatenation &&
+ model->flags.change_concat_input_ranges()) ||
+ op.type == OperatorType::kDepthToSpace ||
+ op.type == OperatorType::kSpaceToDepth ||
+ op.type == OperatorType::kReshape || op.type == OperatorType::kSplit ||
+ op.type == OperatorType::kRelu || op.type == OperatorType::kRelu1 ||
+ op.type == OperatorType::kRelu6) {
int data_input_index = 0;
if (op.type == OperatorType::kSplit) {
data_input_index = 1;
@@ -356,12 +361,14 @@ bool ChooseQuantizationForOperatorOutput(
if (output_index == LstmCellOperator::STATE_OUTPUT ||
output_index == LstmCellOperator::ACTIV_TEMP) {
*quantized_data_type = ArrayDataType::kInt16;
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
return true;
}
}
*quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8);
- GetQuantizationParams(*quantized_data_type, minmax, quantization_params);
+ ChooseQuantizationParamsForArrayAndQuantizedDataType(
+ array, *quantized_data_type, quantization_params);
transformation->AddMessageF(
"For output array %s with min=%g, max=%g"
", chose to quantize as %s with zero_point=%d"