aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-05 14:05:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 14:07:30 -0700
commitf7d00f3d67c47ffc3656c4f2868032b72cd2122b (patch)
tree0a06caace5b82d4a1229d5fe2ace467af8c6b04e /tensorflow/contrib/lite/toco/tooling_util.cc
parent310249066320f1ddc7fe544b4c351aaf89ce3c9c (diff)
quantized LSTM support improvements
PiperOrigin-RevId: 191794956
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc20
1 files changed, 15 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 56fa8f4b69..61d08fa13f 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -1378,12 +1378,22 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
const float mean_value = input_array_proto.mean_value();
const float std_value = input_array_proto.std_value();
MinMax input_minmax;
- input_minmax.min = (0.f - mean_value) / std_value;
- input_minmax.max = (255.f - mean_value) / std_value;
+ float qmin = 0, qmax = 255;
+ if (input_array.data_type == ArrayDataType::kInt16) {
+ qmin = -32768;
+ qmax = 32767;
+ }
+ input_minmax.min = (qmin - mean_value) / std_value;
+ input_minmax.max = (qmax - mean_value) / std_value;
if (input_array.minmax) {
if (input_array_proto.has_mean_value() ||
input_array_proto.has_std_value()) {
- CHECK(input_minmax == *input_array.minmax)
+ const double width = input_minmax.max - input_minmax.min;
+ const double kMinMaxAllowedDiff = 1e-6 * width;
+ CHECK(std::abs(input_minmax.min - input_array.minmax->min) <
+ kMinMaxAllowedDiff &&
+ std::abs(input_minmax.max - input_array.minmax->max) <
+ kMinMaxAllowedDiff)
<< input_minmax.min << ", " << input_minmax.max
<< " != " << input_array.minmax->min << ", "
<< input_array.minmax->max;
@@ -2000,7 +2010,7 @@ void FinishBuildingRNNStates(Model* model) {
}
}
-void UseArraysExtraInfo(Model* model) {
+void UseArraysExtraInfo(Model* model, bool quantize_output) {
for (const auto& entry : model->flags.arrays_extra_info().entries()) {
if (!model->HasArray(entry.name())) {
continue;
@@ -2012,7 +2022,7 @@ void UseArraysExtraInfo(Model* model) {
minmax.min = entry.min();
minmax.max = entry.max();
}
- if (entry.has_data_type()) {
+ if (entry.has_data_type() && quantize_output) {
array.final_data_type =
ConvertIODataTypeToArrayDataType(entry.data_type());
}