aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tooling_util.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-17 11:53:29 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-17 11:55:26 -0700
commitd7b6cb66c0fc346cf55020042931c07208713c60 (patch)
tree9024111ebf15d12a631ffd7e176b9da7459dd5a0 /tensorflow/contrib/lite/toco/tooling_util.cc
parent1192c1662c5c98f55805450b4619ac2bc9c6908c (diff)
Fixes and cleanup to support more complex quantized models and adds PropagateFakeQuantNumBits.
PiperOrigin-RevId: 193232630
Diffstat (limited to 'tensorflow/contrib/lite/toco/tooling_util.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc73
1 files changed, 46 insertions, 27 deletions
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 224df9973e..ecac0c28a5 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -93,9 +93,18 @@ string ArrayDataTypeName(ArrayDataType data_type) {
}
}
-bool IsInputArray(const Model& model, const string& name) {
+bool IsInputArray(const Model& model, const string& array_name) {
for (const auto& input_array : model.flags.input_arrays()) {
- if (input_array.name() == name) {
+ if (array_name == input_array.name()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool IsOutputArray(const Model& model, const string& array_name) {
+ for (const auto& output_array : model.flags.output_arrays()) {
+ if (array_name == output_array) {
return true;
}
}
@@ -106,10 +115,8 @@ bool IsArrayConsumed(const Model& model, const string& name) {
if (GetOpWithInput(model, name)) {
return true;
}
- for (const string& model_output : model.flags.output_arrays()) {
- if (model_output == name) {
- return true;
- }
+ if (IsOutputArray(model, name)) {
+ return true;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (rnn_state.back_edge_source_array() == name) {
@@ -379,6 +386,7 @@ string HelpfulOperatorTypeName(const Operator& op) {
bool OperatorSupportsFusedActivation(OperatorType type) {
switch (type) {
case OperatorType::kConcatenation:
+ case OperatorType::kFakeQuant:
case OperatorType::kGather:
case OperatorType::kSlice:
case OperatorType::kSqueeze:
@@ -1064,16 +1072,38 @@ void FixEdgeArrays(Model* model) {
}
}
+namespace {
+void CopyArrayAttribs(const Array& source_array, Array* target_array) {
+ target_array->data_type = source_array.data_type;
+ target_array->final_data_type = source_array.final_data_type;
+ target_array->copy_shape(source_array.shape());
+
+ if (source_array.minmax) {
+ target_array->GetOrCreateMinMax() = source_array.GetMinMax();
+ } else {
+ target_array->minmax.reset();
+ }
+
+ if (source_array.quantization_params) {
+ target_array->GetOrCreateQuantizationParams() =
+ source_array.GetQuantizationParams();
+ } else {
+ target_array->quantization_params.reset();
+ }
+}
+} // namespace
+
void InsertCopyOperator(Model* model, const string& source_array_name,
const string& target_array_name) {
+ // Reshape to the same size. This should be a no-op.
+ const Array& source_array = model->GetArray(source_array_name);
+ std::vector<int> shape = source_array.shape().dims();
+
// Drop constant data from the target array as the copy will be done at
// runtime.
Array& target_array = model->GetOrCreateArray(target_array_name);
target_array.buffer.reset();
-
- // Reshape to the same size. This should be a no-op.
- const Array& source_array = model->GetArray(source_array_name);
- std::vector<int> shape = source_array.shape().dims();
+ CopyArrayAttribs(source_array, &target_array);
// Insert copy operator.
auto* copy_op = new TensorFlowReshapeOperator;
@@ -1089,6 +1119,7 @@ void CloneArray(Model* model, const string& source_array_name,
CHECK(!model->HasArray(target_array_name));
const Array& source_array = model->GetArray(source_array_name);
Array& target_array = model->GetOrCreateArray(target_array_name);
+ CopyArrayAttribs(source_array, &target_array);
if (source_array.minmax) {
const auto& smm = source_array.GetMinMax();
@@ -1513,14 +1544,9 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) {
if (model.IsOptionalArray(array_name)) return false;
// The model's input and output arrays are externally allocated.
// They are not transient arrays.
- if (IsInputArray(model, array_name)) {
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
return false;
}
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
- }
const auto& array = &model.GetArray(array_name);
// An array with a constant buffer isn't a transient array.
if (!!array->buffer) {
@@ -1898,15 +1924,8 @@ int AxesCount(AxesOrder axes_order) {
}
bool IsDiscardableArray(const Model& model, const string& array_name) {
- for (const auto& input_array : model.flags.input_arrays()) {
- if (array_name == input_array.name()) {
- return false;
- }
- }
- for (const string& output_array : model.flags.output_arrays()) {
- if (array_name == output_array) {
- return false;
- }
+ if (IsInputArray(model, array_name) || IsOutputArray(model, array_name)) {
+ return false;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
@@ -1960,8 +1979,8 @@ void CheckFinalDataTypesSatisfied(const Model& model) {
CHECK(array.final_data_type == array.data_type)
<< "Array \"" << array_entry.first
<< "\" has mis-matching actual and final data types ("
- << static_cast<int>(array.data_type) << ","
- << static_cast<int>(array.final_data_type) << ").";
+ << ArrayDataTypeName(array.data_type) << ","
+ << ArrayDataTypeName(array.final_data_type) << ").";
}
}
}