aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-05 14:05:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 14:07:30 -0700
commitf7d00f3d67c47ffc3656c4f2868032b72cd2122b (patch)
tree0a06caace5b82d4a1229d5fe2ace467af8c6b04e /tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
parent310249066320f1ddc7fe544b4c351aaf89ce3c9c (diff)
quantized LSTM support improvements
PiperOrigin-RevId: 191794956
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc14
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
index 935da9f966..183b3d3f2e 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -78,15 +78,21 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
image_input_op->outputs = {dequantized_input_name};
model->operators.emplace(model->operators.begin(), image_input_op);
- CHECK(input_array.final_data_type == ArrayDataType::kUint8);
- input_array.data_type = ArrayDataType::kUint8;
dequantized_input_array.data_type = ArrayDataType::kFloat;
const auto& input_minmax = input_array.GetMinMax();
auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
dequantized_input_minmax = input_minmax;
auto& input_qparams = input_array.GetOrCreateQuantizationParams();
- GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
- &input_qparams);
+ input_array.data_type = input_array.final_data_type;
+ if (input_array.data_type == ArrayDataType::kUint8) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax,
+ &input_qparams);
+ } else if (input_array.data_type == ArrayDataType::kInt16) {
+ GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax,
+ &input_qparams);
+ } else {
+ LOG(FATAL) << "unhandled data type";
+ }
transformation->AddMessageF(
"Created %s"