diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-05 14:05:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-05 14:07:30 -0700 |
commit | f7d00f3d67c47ffc3656c4f2868032b72cd2122b (patch) | |
tree | 0a06caace5b82d4a1229d5fe2ace467af8c6b04e /tensorflow/contrib/lite/toco/tooling_util.cc | |
parent | 310249066320f1ddc7fe544b4c351aaf89ce3c9c (diff) |
quantized LSTM support improvements
PiperOrigin-RevId: 191794956
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 56fa8f4b69..61d08fa13f 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -1378,12 +1378,22 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { const float mean_value = input_array_proto.mean_value(); const float std_value = input_array_proto.std_value(); MinMax input_minmax; - input_minmax.min = (0.f - mean_value) / std_value; - input_minmax.max = (255.f - mean_value) / std_value; + float qmin = 0, qmax = 255; + if (input_array.data_type == ArrayDataType::kInt16) { + qmin = -32768; + qmax = 32767; + } + input_minmax.min = (qmin - mean_value) / std_value; + input_minmax.max = (qmax - mean_value) / std_value; if (input_array.minmax) { if (input_array_proto.has_mean_value() || input_array_proto.has_std_value()) { - CHECK(input_minmax == *input_array.minmax) + const double width = input_minmax.max - input_minmax.min; + const double kMinMaxAllowedDiff = 1e-6 * width; + CHECK(std::abs(input_minmax.min - input_array.minmax->min) < + kMinMaxAllowedDiff && + std::abs(input_minmax.max - input_array.minmax->max) < + kMinMaxAllowedDiff) << input_minmax.min << ", " << input_minmax.max << " != " << input_array.minmax->min << ", " << input_array.minmax->max; @@ -2000,7 +2010,7 @@ void FinishBuildingRNNStates(Model* model) { } } -void UseArraysExtraInfo(Model* model) { +void UseArraysExtraInfo(Model* model, bool quantize_output) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { if (!model->HasArray(entry.name())) { continue; @@ -2012,7 +2022,7 @@ void UseArraysExtraInfo(Model* model) { minmax.min = entry.min(); minmax.max = entry.max(); } - if (entry.has_data_type()) { + if (entry.has_data_type() && quantize_output) { array.final_data_type = ConvertIODataTypeToArrayDataType(entry.data_type()); } |