diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-17 11:53:29 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-17 11:55:26 -0700 |
commit | d7b6cb66c0fc346cf55020042931c07208713c60 (patch) | |
tree | 9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/tooling_util.cc | |
parent | 1192c1662c5c98f55805450b4619ac2bc9c6908c (diff) |
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tooling_util.cc | 73 |
1 files changed, 46 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 224df9973e..ecac0c28a5 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -93,9 +93,18 @@ string ArrayDataTypeName(ArrayDataType data_type) { } } -bool IsInputArray(const Model& model, const string& name) { +bool IsInputArray(const Model& model, const string& array_name) { for (const auto& input_array : model.flags.input_arrays()) { - if (input_array.name() == name) { + if (array_name == input_array.name()) { + return true; + } + } + return false; +} + +bool IsOutputArray(const Model& model, const string& array_name) { + for (const auto& output_array : model.flags.output_arrays()) { + if (array_name == output_array) { return true; } } @@ -106,10 +115,8 @@ bool IsArrayConsumed(const Model& model, const string& name) { if (GetOpWithInput(model, name)) { return true; } - for (const string& model_output : model.flags.output_arrays()) { - if (model_output == name) { - return true; - } + if (IsOutputArray(model, name)) { + return true; } for (const auto& rnn_state : model.flags.rnn_states()) { if (rnn_state.back_edge_source_array() == name) { @@ -379,6 +386,7 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { case OperatorType::kConcatenation: + case OperatorType::kFakeQuant: case OperatorType::kGather: case OperatorType::kSlice: case OperatorType::kSqueeze: @@ -1064,16 +1072,38 @@ void FixEdgeArrays(Model* model) { } } +namespace { +void CopyArrayAttribs(const Array& source_array, Array* target_array) { + target_array->data_type = source_array.data_type; + target_array->final_data_type = source_array.final_data_type; + target_array->copy_shape(source_array.shape()); + + if (source_array.minmax) { + target_array->GetOrCreateMinMax() = source_array.GetMinMax(); + } else { + target_array->minmax.reset(); + } + + if (source_array.quantization_params) { + target_array->GetOrCreateQuantizationParams() = + source_array.GetQuantizationParams(); + } else { + target_array->quantization_params.reset(); + } +} +} // namespace + void InsertCopyOperator(Model* model, const string& source_array_name, const string& target_array_name) { + // Reshape to the same size. This should be a no-op. + const Array& source_array = model->GetArray(source_array_name); + std::vector<int> shape = source_array.shape().dims(); + // Drop constant data from the target array as the copy will be done at // runtime. Array& target_array = model->GetOrCreateArray(target_array_name); target_array.buffer.reset(); - - // Reshape to the same size. This should be a no-op. - const Array& source_array = model->GetArray(source_array_name); - std::vector<int> shape = source_array.shape().dims(); + CopyArrayAttribs(source_array, &target_array); // Insert copy operator. auto* copy_op = new TensorFlowReshapeOperator; @@ -1089,6 +1119,7 @@ void CloneArray(Model* model, const string& source_array_name, CHECK(!model->HasArray(target_array_name)); const Array& source_array = model->GetArray(source_array_name); Array& target_array = model->GetOrCreateArray(target_array_name); + CopyArrayAttribs(source_array, &target_array); if (source_array.minmax) { const auto& smm = source_array.GetMinMax(); @@ -1513,14 +1544,9 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { if (model.IsOptionalArray(array_name)) return false; // The model's input and output arrays are externally allocated. // They are not transient arrays. - if (IsInputArray(model, array_name)) { + if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { return false; } - for (const string& output_array : model.flags.output_arrays()) { - if (array_name == output_array) { - return false; - } - } const auto& array = &model.GetArray(array_name); // An array with a constant buffer isn't a transient array. if (!!array->buffer) { @@ -1898,15 +1924,8 @@ int AxesCount(AxesOrder axes_order) { } bool IsDiscardableArray(const Model& model, const string& array_name) { - for (const auto& input_array : model.flags.input_arrays()) { - if (array_name == input_array.name()) { - return false; - } - } - for (const string& output_array : model.flags.output_arrays()) { - if (array_name == output_array) { - return false; - } + if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) { + return false; } for (const auto& rnn_state : model.flags.rnn_states()) { if (!rnn_state.discardable()) { @@ -1960,8 +1979,8 @@ void CheckFinalDataTypesSatisfied(const Model& model) { CHECK(array.final_data_type == array.data_type) << "Array \"" << array_entry.first << "\" has mis-matching actual and final data types (" - << static_cast<int>(array.data_type) << "," - << static_cast<int>(array.final_data_type) << ")."; + << ArrayDataTypeName(array.data_type) << "," + << ArrayDataTypeName(array.final_data_type) << ")."; } } } |