aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc78
1 files changed, 78 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
new file mode 100644
index 0000000000..5b41c49bfa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -0,0 +1,78 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ApplyAttrsToArray(GraphTransformation* transformation, Model* model,
+ const FakeQuantOperator& fq_op,
+ const string& array_name) {
+ bool changed = false;
+ auto& annotated_array = model->GetArray(array_name);
+ if (!annotated_array.minmax) {
+ const MinMax& minmax = *fq_op.minmax;
+ annotated_array.GetOrCreateMinMax() = minmax;
+ transformation->AddMessageF(
+ "Read min/max annotation for array %s: min=%g, max=%g", array_name,
+ minmax.min, minmax.max);
+ changed = true;
+ }
+ if (fq_op.narrow_range && !annotated_array.narrow_range) {
+ annotated_array.narrow_range = true;
+ transformation->AddMessageF("Read narrow_range annotation for array %s",
+ array_name);
+ changed = true;
+ }
+ return changed;
+}
+
+} // end namespace
+
+bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
+ std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fq_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (!fq_op->minmax) {
+ // Need to be resolved first by ResolveFakeQuantArgsFromVars.
+ return false;
+ }
+
+ // At this point, this FakeQuantOperator should have a MinMax
+ // attached to it, and should only have 1 input (it should not have
+ // 2nd and 3rd input arrays giving min and max anymore).
+ CHECK(fq_op->minmax);
+ CHECK_EQ(1, fq_op->inputs.size());
+
+ return ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]) ||
+ ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+}
+
+} // namespace toco