diff options
author | 2018-08-31 14:44:49 -0700 | |
---|---|---|
committer | 2018-08-31 15:00:21 -0700 | |
commit | bd3ef4faf91aab31cbee3cbe586a5e1b277afbb6 (patch) | |
tree | a4e1114bf9029640078a192db72bd7eb6a02c500 /tensorflow/contrib/lite/tools | |
parent | e082d5208e56d3d8f69544781bebf830eae82de7 (diff) |
Hybrid operations need either all or none of their tensors quantized.
PiperOrigin-RevId: 211147312
Diffstat (limited to 'tensorflow/contrib/lite/tools')
-rw-r--r-- | tensorflow/contrib/lite/tools/optimize/quantize_weights.cc | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc index 10a66dd351..df8433dd9b 100644 --- a/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc +++ b/tensorflow/contrib/lite/tools/optimize/quantize_weights.cc @@ -168,6 +168,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, bool eval_hybrid = IsHybridEvaluationOp(op, op_code); + bool skipped_tensor = false; std::vector<int32_t> op_input_indices = GetWeightInputIndices(op_code); for (const int32_t op_input_idx : op_input_indices) { int32_t tensor_idx = op->inputs[op_input_idx]; @@ -177,6 +178,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, if (CountTensorConsumers(model, subgraph, tensor_idx) != 1) { LOG(INFO) << "Skipping quantization of tensor that is shared between " "multiple multiple operations."; + skipped_tensor = true; continue; } @@ -184,6 +186,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, if (tensor->type != TensorType_FLOAT32) { LOG(INFO) << "Skipping quantization of tensor that is not type float."; + skipped_tensor = true; continue; } @@ -191,6 +194,7 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, if (num_elements < kWeightsMinSize) { LOG(INFO) << "Skipping quantization of tensor because it has fewer than " << kWeightsMinSize << " elements (" << num_elements << ")."; + skipped_tensor = true; continue; } @@ -203,6 +207,12 @@ std::vector<TensorInfo> GetQuantizableTensorsFromOperator(const ModelT* model, tensor_infos.push_back(tensor_info); } + // For hybrid operations we either need to quantize all tensors or none. So + // if we skipped any tensors we need to return no quantized tensors. + if (eval_hybrid && skipped_tensor) { + return {}; + } + return tensor_infos; } |