diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-05 14:05:03 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-05 14:07:30 -0700 |
commit | f7d00f3d67c47ffc3656c4f2868032b72cd2122b (patch) | |
tree | 0a06caace5b82d4a1229d5fe2ace467af8c6b04e /tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc | |
parent | 310249066320f1ddc7fe544b4c351aaf89ce3c9c (diff) |
quantized LSTM support improvements
PiperOrigin-RevId: 191794956
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc index 935da9f966..183b3d3f2e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc @@ -78,15 +78,21 @@ bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op, image_input_op->outputs = {dequantized_input_name}; model->operators.emplace(model->operators.begin(), image_input_op); - CHECK(input_array.final_data_type == ArrayDataType::kUint8); - input_array.data_type = ArrayDataType::kUint8; dequantized_input_array.data_type = ArrayDataType::kFloat; const auto& input_minmax = input_array.GetMinMax(); auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax(); dequantized_input_minmax = input_minmax; auto& input_qparams = input_array.GetOrCreateQuantizationParams(); - GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax, - &input_qparams); + input_array.data_type = input_array.final_data_type; + if (input_array.data_type == ArrayDataType::kUint8) { + GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(input_minmax, + &input_qparams); + } else if (input_array.data_type == ArrayDataType::kInt16) { + GetQuantizationParamsFromMinMax<ArrayDataType::kInt16>(input_minmax, + &input_qparams); + } else { + LOG(FATAL) << "unhandled data type"; + } transformation->AddMessageF( "Created %s" |