Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc  98
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc  69
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc  223
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc  56
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc  42
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc  57
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc  98
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc  300
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc  326
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc  108
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h  186
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc  229
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc  170
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc  106
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc  396
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc  103
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc  120
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc  142
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc  1129
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/quantize.cc  467
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc  105
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc  59
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc  60
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc  38
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc  113
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc  40
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc  68
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc  107
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h  55
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc  87
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc  92
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc  122
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc  135
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc  247
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc  196
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc  76
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc  62
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc  175
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc  51
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc  55
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc  93
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc  49
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc  52
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc  62
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc  86
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc  106
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc  63
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc  54
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc  123
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc  97
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD  31
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc  221
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc  73
53 files changed, 7478 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
new file mode 100644
index 0000000000..bf454c40c7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/convert_pure_conv_to_depthwise.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ConvertPureConvToDepthwise::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ const auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->stride_width != conv_op->stride_height) {
+ return false;
+ }
+ auto& weights_array = model->GetArray(conv_op->inputs[1]);
+ if (!weights_array.buffer) {
+ // Yield until the weights are resolved as a constant array.
+ return false;
+ }
+ if (weights_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (weights_array.shape().dims(3) != 1) {
+ // Not a pure convolution: Conv does accumulation across the depth
+ // dimension.
+ return false;
+ }
+ // At this point we know we have a pure conv. Rewrite it as DepthwiseConv.
+ AddMessageF(
+ "%s is purely convolutional (input/weights depth is 1), replacing it by "
+ "a DepthwiseConv.",
+ LogName(*conv_op));
+ auto* depthwiseconv_op = new DepthwiseConvOperator;
+ // Conv and DepthwiseConv take the same inputs
+ depthwiseconv_op->inputs = conv_op->inputs;
+ // Conv may have a 2nd output for im2col
+ depthwiseconv_op->outputs = {conv_op->outputs[0]};
+ if (conv_op->outputs.size() > 1) {
+ // delete the im2col array.
+ model->arrays.erase(conv_op->outputs[1]);
+ }
+ depthwiseconv_op->fused_activation_function =
+ conv_op->fused_activation_function;
+ // Let PropagateFixedSizes recompute fixed padding, just in case some day it
+ // may be different for Conv vs DepthwiseConv.
+ depthwiseconv_op->padding.type = conv_op->padding.type;
+ depthwiseconv_op->stride_height = conv_op->stride_height;
+ depthwiseconv_op->stride_width = conv_op->stride_width;
+ depthwiseconv_op->depth_multiplier = weights_array.shape().dims(0);
+ // Replace the operator in the graph.
+ const auto depthwiseconv_it =
+ model->operators.emplace(conv_it, depthwiseconv_op);
+ conv_it = depthwiseconv_it + 1;
+ CHECK_EQ(conv_it->get(), conv_op);
+ model->operators.erase(conv_it);
+ // Shuffle the weights.
+ const auto& weights_shape = weights_array.shape();
+ auto& weights_buffer =
+ weights_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const std::vector<float>& conv_weights_data = weights_buffer.data;
+ std::vector<float> depthwise_conv_weights_data(conv_weights_data.size());
+ const int depth = weights_shape.dims(0);
+ const int width = weights_shape.dims(1);
+ const int height = weights_shape.dims(2);
+ const int width_height = width * height;
+ for (int c = 0; c < depth; c++) {
+ for (int xy = 0; xy < width_height; xy++) {
+ depthwise_conv_weights_data[c + depth * xy] =
+ conv_weights_data[xy + width_height * c];
+ }
+ }
+ *weights_array.mutable_shape()->mutable_dims() = {1, width, height, depth};
+ weights_buffer.data = depthwise_conv_weights_data;
+ return true;
+}
+
+} // namespace toco
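
Note: the loop above relabels Conv weights stored as [depth, width, height, 1] into DepthwiseConv weights stored as [1, width, height, depth]; for each flattened spatial position xy, element [c][xy] moves to [xy][c]. A minimal standalone sketch of that index remapping (toy sizes and data, not TOCO types):

    #include <cstdio>
    #include <vector>

    int main() {
      const int depth = 2, width_height = 3;         // toy sizes
      std::vector<float> conv = {0, 1, 2, 3, 4, 5};  // [c][xy] layout
      std::vector<float> dw(conv.size());            // [xy][c] layout
      for (int c = 0; c < depth; ++c)
        for (int xy = 0; xy < width_height; ++xy)
          dw[c + depth * xy] = conv[xy + width_height * c];
      for (float v : dw) std::printf("%g ", v);  // prints: 0 3 1 4 2 5
      std::printf("\n");
    }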
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
new file mode 100644
index 0000000000..1735b51e5b
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/create_im2col_arrays.cc
@@ -0,0 +1,69 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool CreateIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() == 2) {
+ // We already have an im2col array
+ return false;
+ }
+ const auto& weights_array = *model->arrays[conv_op->inputs[1]];
+ if (!weights_array.has_shape()) {
+ // We need to yield until weights dims have been resolved, because
+ // from the weights dims we determine whether an im2col array is
+ // needed.
+ return false;
+ }
+ const auto& weights_shape = weights_array.shape();
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ if (kwidth == 1 && kheight == 1 && conv_op->stride_width == 1 &&
+ conv_op->stride_height == 1) {
+ // 1x1 unstrided conv does not need an im2col array.
+ return false;
+ }
+
+ // Create the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 1);
+ const string& im2col_array_name =
+ AvailableArrayName(*model, conv_op->inputs[0] + "_im2col");
+ model->GetOrCreateArray(im2col_array_name);
+ conv_op->outputs.push_back(im2col_array_name);
+ AddMessageF(
+ "Created an im2col array for %s, with %dx%d kernel and stride_width=%d, "
+ "stride_height=%d",
+ LogName(*conv_op), kwidth, kheight, conv_op->stride_width,
+ conv_op->stride_height);
+
+ return true;
+}
+
+} // namespace toco
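
Note: for background, im2col turns a convolution into a matrix multiply by copying each kernel-sized input patch into one row of a matrix; for a 1x1 kernel at stride 1 that copy is the identity, which is why the pass above skips creating the array in that case. A minimal single-channel sketch under those assumptions (hypothetical helper, not the TF Lite kernel):

    #include <vector>

    // Gathers kh*kw patches of an h*w single-channel image into rows of a
    // matrix with one row per output position (VALID padding, stride 1).
    std::vector<float> Im2col(const std::vector<float>& img, int h, int w,
                              int kh, int kw) {
      const int out_h = h - kh + 1, out_w = w - kw + 1;
      std::vector<float> cols(out_h * out_w * kh * kw);
      int r = 0;
      for (int y = 0; y < out_h; ++y)
        for (int x = 0; x < out_w; ++x)
          for (int ky = 0; ky < kh; ++ky)
            for (int kx = 0; kx < kw; ++kx)
              cols[r++] = img[(y + ky) * w + (x + kx)];
      return cols;  // convolution is now an (out_h*out_w) x (kh*kw) matmul
    }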
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
new file mode 100644
index 0000000000..b89e3f5310
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/dequantize.cc
@@ -0,0 +1,223 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <ArrayDataType A>
+void DequantizeBuffer(Array* array) {
+ const auto old_data = array->GetBuffer<A>().data;
+ array->buffer = nullptr;
+ array->data_type = ArrayDataType::kFloat;
+ auto& new_data = array->GetMutableBuffer<ArrayDataType::kFloat>().data;
+ new_data.resize(old_data.size());
+ const auto& qparams = array->GetQuantizationParams();
+ for (int i = 0; i < old_data.size(); i++) {
+ new_data[i] = qparams.scale * (old_data[i] - qparams.zero_point);
+ }
+}
+
+std::vector<std::unique_ptr<Operator>>::iterator FindFirstOpWithInput(
+ Model* model, const string& array_name) {
+ for (auto it = model->operators.begin(); it != model->operators.end(); ++it) {
+ for (const auto& input : it->get()->inputs) {
+ if (input == array_name) {
+ return it;
+ }
+ }
+ }
+ return model->operators.end();
+}
+
+void ClearArrayQuantizationParams(const string& array_name, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ CHECK(array->quantization_params);
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == array_name) {
+ auto& qparams = *array->quantization_params;
+ const double new_std_value = 1. / qparams.scale;
+ const double new_mean_value = qparams.zero_point;
+ if (input_array.has_std_value()) {
+ CHECK_LE(std::abs(new_std_value - input_array.std_value()), 0.001);
+ } else {
+ input_array.set_std_value(new_std_value);
+ }
+ if (input_array.has_mean_value()) {
+ CHECK_LE(std::abs(new_mean_value - input_array.mean_value()), 0.001);
+ } else {
+ input_array.set_mean_value(new_mean_value);
+ }
+ }
+ }
+ array->quantization_params = nullptr;
+}
+
+bool DequantizeArray(const string& array_name,
+ GraphTransformation* transformation, Model* model) {
+ auto* array = model->arrays.at(array_name).get();
+ if (!array->quantization_params) {
+ return false;
+ }
+ transformation->AddMessageF("Dequantizing array: %s", array_name);
+
+ // Dequantize any buffer
+ if (array->buffer) {
+ if (array->data_type == ArrayDataType::kUint8) {
+ DequantizeBuffer<ArrayDataType::kUint8>(array);
+ } else if (array->data_type == ArrayDataType::kInt32) {
+ DequantizeBuffer<ArrayDataType::kInt32>(array);
+ } else {
+ LOG(FATAL) << "Unhandled data type";
+ }
+ CHECK(array->data_type == ArrayDataType::kFloat);
+ CHECK(array->buffer->type == ArrayDataType::kFloat);
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+ return true;
+ } else {
+ array->data_type = ArrayDataType::kFloat;
+ }
+
+ // Clear quantization params, officially makes this a non-quantized array.
+ ClearArrayQuantizationParams(array_name, model);
+
+ if (array->buffer) {
+ return true;
+ }
+
+ auto* op_outputting_array = GetOpWithOutput(*model, array_name);
+ if (op_outputting_array) {
+ if (op_outputting_array->type == OperatorType::kTensorFlowReshape) {
+ return true;
+ }
+ }
+
+ // If there was no minmax info, we can return now. Indeed,
+ // the below only serves to create a FakeQuant node, but some arrays are
+ // quantized without MinMax (see the CHECK above) and that corresponds to
+ // places where a FakeQuant node is actually not wanted, because the
+ // quantization params are meant to be inferred in another way (e.g. bias
+ // vector for a Conv op, see their special-casing in quantize.cc).
+ if (!array->minmax) {
+ return true;
+ }
+
+ // Determine whether to insert a FakeQuant before or after
+ // this array.
+ bool must_insert_fakequant_before = false;
+ bool must_insert_fakequant_after = false;
+ if (IsInputArray(*model, array_name)) {
+ must_insert_fakequant_after = true;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (array_name == output_array) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (array_name == rnn_state.state_array()) {
+ must_insert_fakequant_after = true;
+ }
+ if (array_name == rnn_state.back_edge_source_array()) {
+ must_insert_fakequant_before = true;
+ }
+ }
+ CHECK(!(must_insert_fakequant_before && must_insert_fakequant_after));
+
+ // Create and insert the FakeQuant node
+ auto* fakequant_op = new FakeQuantOperator;
+ model->operators.emplace(FindFirstOpWithInput(model, array_name),
+ fakequant_op);
+ const string& new_array_name = AvailableArrayName(*model, array_name);
+ auto& new_array = model->GetOrCreateArray(new_array_name);
+ new_array.data_type = ArrayDataType::kFloat;
+ new_array.copy_shape(array->shape());
+ new_array.GetOrCreateMinMax() = array->GetMinMax();
+ fakequant_op->minmax.reset(new MinMax);
+ *fakequant_op->minmax = array->GetMinMax();
+ if (must_insert_fakequant_before) {
+ for (const auto& op : model->operators) {
+ for (string& output : op->outputs) {
+ if (output == array_name) {
+ output = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {new_array_name};
+ fakequant_op->outputs = {array_name};
+ } else {
+ for (const auto& op : model->operators) {
+ for (string& input : op->inputs) {
+ if (input == array_name) {
+ input = new_array_name;
+ }
+ }
+ }
+ fakequant_op->inputs = {array_name};
+ fakequant_op->outputs = {new_array_name};
+ }
+ return true;
+}
+
+} // namespace
+
+bool Dequantize::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ auto* op = op_it->get();
+
+ if (op->type == OperatorType::kDequantize) {
+ auto& input_array = model->GetArray(op->inputs[0]);
+ if (input_array.data_type == ArrayDataType::kFloat) {
+ return false;
+ }
+ if (input_array.final_data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ input_array.data_type = ArrayDataType::kFloat;
+ input_array.quantization_params = nullptr;
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.data_type = ArrayDataType::kFloat;
+ output_array.quantization_params = nullptr;
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+
+ std::vector<string> arrays;
+ for (const string& input : op->inputs) {
+ arrays.push_back(input);
+ }
+ for (const string& output : op->outputs) {
+ arrays.push_back(output);
+ }
+ bool changed = false;
+ for (const string& array : arrays) {
+ changed |= DequantizeArray(array, this, model);
+ }
+
+ return changed;
+}
+
+} // namespace toco
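
Note: DequantizeBuffer above applies the standard affine dequantization rule, real = scale * (quantized - zero_point); ClearArrayQuantizationParams then mirrors the same parameters into the input flags as std_value = 1/scale and mean_value = zero_point. A standalone sketch of the rule with toy values:

    #include <cstdint>
    #include <cstdio>

    int main() {
      // real = scale * (quantized - zero_point), as in DequantizeBuffer.
      const double scale = 0.5;
      const int zero_point = 128;
      const uint8_t q[] = {126, 128, 130};
      for (uint8_t v : q) std::printf("%g ", scale * (v - zero_point));
      // prints: -1 0 1
      std::printf("\n");
    }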
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
new file mode 100644
index 0000000000..fea360740f
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_fake_quant.cc
@@ -0,0 +1,56 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ const auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ if (!output_array.minmax) {
+ return false;
+ }
+
+ // Drop min/max inputs
+ for (int i = 1; i < fakequant_op->inputs.size(); i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
new file mode 100644
index 0000000000..a3ed6663bc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/drop_im2col_arrays.cc
@@ -0,0 +1,42 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool DropIm2colArrays::Run(Model* model, std::size_t op_index) {
+ auto conv_it = model->operators.begin() + op_index;
+ if (conv_it->get()->type != OperatorType::kConv) {
+ return false;
+ }
+ auto* conv_op = static_cast<ConvOperator*>(conv_it->get());
+ if (conv_op->outputs.size() < 2) {
+ // Conv op does not have im2col.
+ return false;
+ }
+
+ // Drop the im2col array.
+ CHECK_EQ(conv_op->outputs.size(), 2);
+ model->arrays.erase(conv_op->outputs[1]);
+ conv_op->outputs.resize(1);
+ AddMessageF("Dropped an im2col array for %s", LogName(*conv_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
new file mode 100644
index 0000000000..badefeca88
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/ensure_bias_vectors.cc
@@ -0,0 +1,57 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ProcessLinearOperator(Model* model, Operator* op) {
+ if (op->inputs.size() >= 3) {
+ return false;
+ }
+ const string& output_name = op->outputs[0];
+ const string& bias_name = AvailableArrayName(*model, output_name + "_bias");
+ op->inputs.push_back(bias_name);
+ DCHECK_EQ(op->inputs.size(), 3);
+ auto& bias_array = model->GetOrCreateArray(bias_name);
+ bias_array.data_type = ArrayDataType::kFloat;
+
+ return true;
+}
+} // namespace
+
+bool EnsureBiasVectors::Run(Model* model, std::size_t op_index) {
+ auto* op = model->operators[op_index].get();
+ if (op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected) {
+ if (ProcessLinearOperator(model, op)) {
+ AddMessageF("Added bias vector to %s", LogName(*op));
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
new file mode 100644
index 0000000000..7a86510025
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -0,0 +1,98 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto ac_it = model->operators.begin() + op_index;
+ const auto* ac_op = ac_it->get();
+
+ if (ac_op->type != OperatorType::kRelu6 &&
+ ac_op->type != OperatorType::kRelu1 &&
+ ac_op->type != OperatorType::kRelu) {
+ return false;
+ }
+
+ // Find the op producing the array passed to this activation function
+ Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
+
+ if (!op) return false;
+
+ if (CountTrueOutputs(*model, *op) > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it has more than one "
+ " consumed output",
+ LogName(*op));
+ return false;
+ }
+
+ CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
+
+ int count_ops_consuming_output = CountOpsWithInput(*model, ac_op->inputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing activation function into %s because it is consumed by more "
+ "than 1 other operator",
+ LogName(*op));
+ return false;
+ }
+
+ if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing activation function into %s because it already has a fused "
+ "activation function",
+ LogName(*op));
+ return false;
+ }
+
+  // TODO(dkalenichenko): A great many ops don't support activation function
+  // fusing. Switch to a whitelist approach instead.
+ if (op->type == OperatorType::kConcatenation ||
+ op->type == OperatorType::kSlice) {
+ AddMessageF(
+ "Not fusing activation function because the %s op doesn't support it",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Fusing activation function %s into the preceding %s",
+ LogName(*ac_op), LogName(*op));
+ if (ac_op->type == OperatorType::kRelu6) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu6;
+ } else if (ac_op->type == OperatorType::kRelu1) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu1;
+ } else if (ac_op->type == OperatorType::kRelu) {
+ op->fused_activation_function = FusedActivationFunctionType::kRelu;
+ } else {
+ LOG(FATAL) << "Unhandled activation function type";
+ }
+ model->arrays.erase(ac_op->inputs[0]);
+ op->outputs[0] = ac_op->outputs[0];
+ model->operators.erase(ac_it);
+ return true;
+}
+
+} // namespace toco
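
Note: setting fused_activation_function and deleting the separate activation op relies on the runtime clamping the producer's output in place of that op. A sketch of the clamp semantics the fused flag stands for (an illustration, not the TF Lite runtime code):

    #include <algorithm>
    #include <cstdio>

    // kRelu clamps to [0, +inf), kRelu1 to [-1, 1], kRelu6 to [0, 6].
    float Clamp(float x, float lo, float hi) {
      return std::min(std::max(x, lo), hi);
    }

    int main() {
      std::printf("%g %g %g\n",
                  std::max(-2.f, 0.f),     // kRelu:  -2 -> 0
                  Clamp(-2.f, -1.f, 1.f),  // kRelu1: -2 -> -1
                  Clamp(7.f, 0.f, 6.f));   // kRelu6:  7 -> 6
    }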
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
new file mode 100644
index 0000000000..4619d8bbee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc
@@ -0,0 +1,300 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a subtraction, the constant input should be the right hand
+ // side.
+ // This should have been checked before this point.
+ CHECK(add_or_sub_op->type != OperatorType::kSub ||
+ index_of_constant_input == 1);
+ if (following_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ auto& bias = model->GetArray(following_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // At this point we reduce the case of subtraction to that of addition
+ // by negating the operand.
+ float add_scalar_operand = 0.f;
+ if (add_or_sub_op->type == OperatorType::kAdd) {
+ add_scalar_operand = scalar_operand;
+ } else if (add_or_sub_op->type == OperatorType::kSub &&
+ index_of_constant_input == 1) {
+ add_scalar_operand = -scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ // From here on we are fusing an addition. add_or_sub_op->type does not
+ // matter anymore.
+
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const auto& weights_buffer = weights.GetBuffer<ArrayDataType::kFloat>();
+ const float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+
+ if (following_op->type == OperatorType::kConv ||
+ following_op->type == OperatorType::kFullyConnected) {
+ const int output_depth = weights_shape.dims(0);
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ CHECK_EQ(output_depth, bias_shape.dims(bias_shape.dimensions_count() - 1));
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int d = 0; d < output_depth; d++) {
+ float accumulation = 0;
+ for (int i = 0; i < weights_per_depth; i++) {
+ accumulation +=
+ add_scalar_operand * weights_data[d * weights_per_depth + i];
+ }
+ bias_data[d] += accumulation;
+ }
+ } else if (following_op->type == OperatorType::kDepthwiseConv) {
+ const int output_depth =
+ weights_shape.dims(weights_shape.dimensions_count() - 1);
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ for (int c = 0; c < output_depth; c++) {
+ float accumulation = 0;
+ for (int k = 0; k < weights_per_depth; k++) {
+ accumulation += add_scalar_operand * weights_data[k * output_depth + c];
+ }
+ bias_data[c] += accumulation;
+ }
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+}
+
+void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ const auto& weights_name = following_op->inputs[1];
+ const auto& bias_name = following_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+ // We're only supporting the case of a scalar operand. Should have
+ // been checked earlier.
+ CHECK_EQ(RequiredBufferSizeForShape(operand.shape()), 1);
+
+ const float scalar_operand =
+ operand.GetBuffer<ArrayDataType::kFloat>().data[0];
+
+ float* weights_data =
+ weights.GetMutableBuffer<ArrayDataType::kFloat>().data.data();
+ const int weights_size = RequiredBufferSizeForShape(weights.shape());
+ for (int i = 0; i < weights_size; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[i] *= scalar_operand;
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[i] /= scalar_operand;
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+}
+
+} // namespace
+
+bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+  // We can only fuse a binary op when its two operands break down as follows:
+  // 1. One operand is the (variable) output of a typical affine (linear plus
+  //    bias) op of a finite list of possible types: at the moment Conv,
+  //    DepthwiseConv and FullyConnected are supported.
+ // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ const auto& operand_shape =
+ model->GetArray(binary_op->inputs[index_of_constant_input]).shape();
+ for (const auto& dim : operand_shape.dims()) {
+ if (dim > 1) {
+ AddMessageF(
+ "Not fusing %s into the following affine op, because we only know "
+ "how to do so when the constant operand is a scalar",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF("Not fusing %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]);
+
+ if (!following_op) {
+ AddMessageF(
+ "Not fusing %s because it is not consumed by exactly one other op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ if (following_op->type != OperatorType::kConv &&
+ following_op->type != OperatorType::kFullyConnected &&
+ following_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the following %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ if (following_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not have a bias vector",
+        LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(following_op->inputs[1]);
+ const auto& bias = model->GetArray(following_op->inputs[2]);
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the following %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+
+ // Try to fuse the binary params into the following op's params
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (following_op->type == OperatorType::kConv) {
+ if (static_cast<ConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ if (following_op->type == OperatorType::kDepthwiseConv) {
+ if (static_cast<DepthwiseConvOperator*>(following_op)->padding.type !=
+ PaddingType::kValid) {
+ AddMessageF(
+ "Not fusing %s because the following %s does not use VALID padding",
+ LogName(*binary_op), LogName(*following_op));
+ return false;
+ }
+ }
+ FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoFollowingAffine(model, following_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ AddMessageF("Fusing %s into the following %s", LogName(*binary_op),
+ LogName(*following_op));
+
+ model->arrays.erase(binary_op->outputs[0]);
+ following_op->inputs[0] = binary_op->inputs[index_of_variable_input];
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
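
Note: the Add/Sub fusion above uses the linearity identity W*(x + a) + b = W*x + (b + a*rowsum(W)) for a scalar a, which is exactly what the per-depth accumulation loop computes. A toy verification of that identity (illustration only, not TOCO code):

    #include <cstdio>

    int main() {
      const float W[2][3] = {{1, 2, 3}, {4, 5, 6}};
      const float b[2] = {10, 20}, x[3] = {7, 8, 9}, a = 0.5f;
      for (int d = 0; d < 2; ++d) {
        float lhs = b[d];         // W*(x + a) + b, computed directly
        float fused_bias = b[d];  // b + a * rowsum(W), as fused above
        float rhs = 0;
        for (int i = 0; i < 3; ++i) {
          lhs += W[d][i] * (x[i] + a);
          fused_bias += W[d][i] * a;
          rhs += W[d][i] * x[i];
        }
        rhs += fused_bias;
        std::printf("%g == %g\n", lhs, rhs);  // identical per output depth
      }
    }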
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
new file mode 100644
index 0000000000..8948653ec3
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_preceding_affine.cc
@@ -0,0 +1,326 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void FuseAddOrSubParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* add_or_sub_op,
+ int index_of_constant_input) {
+ CHECK(add_or_sub_op->type == OperatorType::kAdd ||
+ add_or_sub_op->type == OperatorType::kSub);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ auto& bias = model->GetArray(preceding_op->inputs[2]);
+ bias.minmax = nullptr;
+ const auto& operand =
+ model->GetArray(add_or_sub_op->inputs[index_of_constant_input]);
+
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // TODO(b/62904716): Bias array should become 1-D when padding removed.
+ const int depth = bias_shape.dims(bias_shape.dimensions_count() - 1);
+ CHECK_EQ(depth, operand_shape.dims(operand_shape.dimensions_count() - 1));
+
+ enum class OpType { BiasPlusOperand, BiasMinusOperand, OperandMinusBias };
+
+ const OpType optype = (add_or_sub_op->type == OperatorType::kAdd)
+ ? OpType::BiasPlusOperand
+ : (index_of_constant_input == 1)
+ ? OpType::BiasMinusOperand
+ : OpType::OperandMinusBias;
+
+ for (int i = 0; i < depth; i++) {
+ float& bias_val = bias_data[i];
+ const float operand_val = operand_data[i];
+ if (optype == OpType::BiasPlusOperand) {
+ bias_val += operand_val;
+ } else if (optype == OpType::BiasMinusOperand) {
+ bias_val -= operand_val;
+ } else if (optype == OpType::OperandMinusBias) {
+ bias_val = operand_val - bias_val;
+ } else {
+ LOG(FATAL) << "Should not get here.";
+ }
+ }
+}
+
+void FuseMulOrDivParamsIntoPrecedingAffine(Model* model, Operator* preceding_op,
+ const Operator* mul_or_div_op,
+ int index_of_constant_input) {
+ CHECK(mul_or_div_op->type == OperatorType::kMul ||
+ mul_or_div_op->type == OperatorType::kDiv);
+ CHECK(index_of_constant_input == 0 || index_of_constant_input == 1);
+ // If the op is a division, the constant input should be the right hand side.
+ // This should have been checked before this point.
+ CHECK(mul_or_div_op->type != OperatorType::kDiv ||
+ index_of_constant_input == 1);
+ if (preceding_op->inputs.size() < 3) {
+ LOG(FATAL) << "Missing bias parameter";
+ }
+ const auto& weights_name = preceding_op->inputs[1];
+ const auto& bias_name = preceding_op->inputs[2];
+ auto& weights = model->GetArray(weights_name);
+ DropMinMax(model, weights_name);
+ auto& bias = model->GetArray(bias_name);
+ DropMinMax(model, bias_name);
+ const auto& operand =
+ model->GetArray(mul_or_div_op->inputs[index_of_constant_input]);
+
+ const Shape& weights_shape = weights.shape();
+ const Shape& bias_shape = bias.shape();
+ const Shape& operand_shape = operand.shape();
+ auto& weights_buffer = weights.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const weights_data = weights_buffer.data.data();
+ auto& bias_buffer = bias.GetMutableBuffer<ArrayDataType::kFloat>();
+ float* const bias_data = bias_buffer.data.data();
+ const auto& operand_buffer = operand.GetBuffer<ArrayDataType::kFloat>();
+ const float* const operand_data = operand_buffer.data.data();
+
+ // We support broadcasting the operand along the depth dimension,
+ // when the operand's depth is 1.
+ int operand_channel_increment = 0;
+ if (operand_shape.dimensions_count() >= 1 &&
+ operand_shape.dims(operand_shape.dimensions_count() - 1) ==
+ bias_shape.dims(bias_shape.dimensions_count() - 1)) {
+ operand_channel_increment = 1;
+ } else if (operand_shape.dimensions_count() == 0 ||
+ operand_shape.dims(operand_shape.dimensions_count() - 1) == 1) {
+ operand_channel_increment = 0;
+ } else {
+ LOG(FATAL) << "Operand shape mismatch.";
+ }
+
+ int output_depth;
+
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ output_depth = weights_shape.dims(0);
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ output_depth = weights_shape.dims(weights_shape.dimensions_count() - 1);
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+
+ const int weights_size = RequiredBufferSizeForShape(weights_shape);
+ const int weights_per_depth = weights_size / output_depth;
+ CHECK_EQ(weights_size, weights_per_depth * output_depth);
+
+ int operand_channel = 0;
+ for (int c = 0; c < output_depth; c++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ bias_data[c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ bias_data[c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ if (preceding_op->type == OperatorType::kConv ||
+ preceding_op->type == OperatorType::kFullyConnected) {
+ for (int i = 0; i < weights_per_depth; i++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[c * weights_per_depth + i] *=
+ operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[c * weights_per_depth + i] /=
+ operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else if (preceding_op->type == OperatorType::kDepthwiseConv) {
+ for (int k = 0; k < weights_per_depth; k++) {
+ if (mul_or_div_op->type == OperatorType::kMul) {
+ weights_data[k * output_depth + c] *= operand_data[operand_channel];
+ } else if (mul_or_div_op->type == OperatorType::kDiv) {
+ weights_data[k * output_depth + c] /= operand_data[operand_channel];
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ }
+ } else {
+ LOG(FATAL) << "Should not get here";
+ }
+ operand_channel += operand_channel_increment;
+ }
+}
+} // namespace
+
+bool FuseBinaryIntoPrecedingAffine::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+  // We can only fuse a binary op when its two operands break down as follows:
+  // 1. One operand is the (variable) output of a typical affine (linear plus
+  //    bias) op of a finite list of possible types: at the moment Conv,
+  //    DepthwiseConv and FullyConnected are supported.
+ // 2. The other operand is a constant param array.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can fuse into a constant.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // For division, we can only fuse if the denominator is constant.
+ if (binary_op->type == OperatorType::kDiv) {
+ if (index_of_constant_input != 1) {
+ AddMessageF("Not fusing %s because the denominator is not constant",
+ LogName(*binary_op));
+ return false;
+ }
+ }
+
+ Operator* preceding_op =
+ GetOpWithOutput(*model, binary_op->inputs[index_of_variable_input]);
+ if (!preceding_op) {
+ AddMessageF("Not fusing %s because it is not the output of another op",
+ LogName(*binary_op));
+ return false;
+ }
+
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (preceding_op->outputs[0] == output_array) {
+ return false;
+ }
+ }
+
+ if (preceding_op->type != OperatorType::kConv &&
+ preceding_op->type != OperatorType::kFullyConnected &&
+ preceding_op->type != OperatorType::kDepthwiseConv) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s is not of one of the supported "
+ "types",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a fused activation "
+ "function",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ if (preceding_op->inputs.size() < 3) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s does not have a bias vector",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ const auto& weights = model->GetArray(preceding_op->inputs[1]);
+ const auto& bias = model->GetArray(preceding_op->inputs[2]);
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ if (!bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has a non-constant bias "
+ "array",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ } else {
+ if (!weights.buffer || !bias.buffer) {
+ AddMessageF(
+ "Not fusing %s because the preceding %s has non-constant weights or "
+ "bias arrays",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+ }
+
+ int count_ops_consuming_output =
+ CountOpsWithInput(*model, preceding_op->outputs[0]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not fusing %s because the output of the preceding %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*preceding_op));
+ return false;
+ }
+
+ AddMessageF("Fusing %s into the preceding %s", LogName(*binary_op),
+ LogName(*preceding_op));
+
+ if (binary_op->type == OperatorType::kAdd ||
+ binary_op->type == OperatorType::kSub) {
+ FuseAddOrSubParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else if (binary_op->type == OperatorType::kMul ||
+ binary_op->type == OperatorType::kDiv) {
+ FuseMulOrDivParamsIntoPrecedingAffine(model, preceding_op, binary_op,
+ index_of_constant_input);
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+
+ model->arrays.erase(preceding_op->outputs[0]);
+ preceding_op->outputs[0] = binary_op->outputs[0];
+ preceding_op->fused_activation_function =
+ binary_op->fused_activation_function;
+ const auto& old_constant_param_name =
+ binary_op->inputs[index_of_constant_input];
+ CHECK(IsConstantParameterArray(*model, old_constant_param_name));
+ if (CountOpsWithInput(*model, old_constant_param_name) == 1) {
+ model->arrays.erase(old_constant_param_name);
+ }
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
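
Note: the Mul/Div fusion above folds a per-channel (or scalar-broadcast) factor s into the preceding op via (W*x + b) * s = (s*W)*x + (s*b); Add/Sub only touches the bias, which is why it tolerates non-constant weights. A toy check of the Mul identity (illustration only):

    #include <cstdio>

    int main() {
      const float W[2] = {3, 4};  // one weight per output channel (toy case)
      const float b[2] = {1, 2}, s[2] = {0.5f, 2.f};
      const float x = 10;
      for (int c = 0; c < 2; ++c) {
        const float lhs = (W[c] * x + b[c]) * s[c];         // Mul after affine
        const float rhs = (s[c] * W[c]) * x + s[c] * b[c];  // folded params
        std::printf("%g == %g\n", lhs, rhs);
      }
    }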
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
new file mode 100644
index 0000000000..323fec6cf8
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.cc
@@ -0,0 +1,108 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void PrintModelStats(const string& label, const Model& model) {
+ int quantized_arrays = 0;
+ for (const auto& array : model.arrays) {
+ if (array.second->quantization_params) {
+ quantized_arrays++;
+ }
+ }
+ LOG(INFO) << label << ": " << model.operators.size() << " operators, "
+ << model.arrays.size() << " arrays (" << quantized_arrays
+ << " quantized)";
+}
+
+bool GraphTransformationsPass(int increment, Model* model,
+ const GraphTransformationsSet& transformations) {
+ CHECK(increment == 1 || increment == -1);
+ bool changed = false;
+ CHECK(!model->operators.empty());
+ int op_index = increment == 1 ? 0 : model->operators.size() - 1;
+ while (true) {
+ bool changed_now = false;
+ // Loop over all transformations at the current position in the graph.
+ for (const auto& transformation : transformations) {
+ CHECK(!changed_now);
+ CHECK(transformation->Messages().empty());
+ changed_now = transformation->Run(model, op_index);
+ if (changed_now) {
+ DumpGraphvizVideoFrame(*model);
+ CHECK(!model->operators.empty());
+ op_index = std::min<int>(op_index, model->operators.size() - 1);
+ // Uncomment for debugging
+ // CheckInvariants(*model);
+ }
+ const char* made_a_change_msg =
+ changed_now ? "made a change" : "did NOT make a change";
+ const int log_level =
+ changed_now ? kLogLevelModelChanged : kLogLevelModelUnchanged;
+ for (const string& message : transformation->Messages()) {
+ VLOG(log_level) << transformation->Name() << " " << made_a_change_msg
+ << " at op_index=" << op_index << "/"
+ << model->operators.size() - 1 << ": " << message;
+ }
+ transformation->ClearMessages();
+ if (changed_now) {
+ break;
+ }
+ }
+ if (changed_now) {
+ changed = true;
+ } else {
+ const int op_index_last =
+ increment == 1 ? model->operators.size() - 1 : 0;
+ if (op_index == op_index_last) {
+ break;
+ }
+ op_index += increment;
+ }
+ }
+ return changed;
+}
+
+} // namespace
+
+void RunGraphTransformations(Model* model, const string& msg,
+ const GraphTransformationsSet& transformations) {
+ PrintModelStats(toco::port::StringF("Before %s", msg), *model);
+ int pass_index = 0;
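+  // Alternate the sweep direction on each pass: even-numbered passes walk
+  // the operator list forward, odd-numbered passes walk it backward, until a
+  // full pass makes no change (a simple fixed-point iteration).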
+ while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
+ transformations)) {
+ pass_index++;
+ const auto& label =
+ toco::port::StringF("After %s pass %d", msg, pass_index);
+ PrintModelStats(label, *model);
+ CheckInvariants(*model);
+ }
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
new file mode 100644
index 0000000000..2cc24ff361
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
@@ -0,0 +1,186 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
+
+#include <cstddef>
+#include <initializer_list>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/toco_port.h"
+
+namespace toco {
+
+class GraphTransformation {
+ public:
+ virtual bool Run(Model* model, std::size_t op_index) = 0;
+ virtual const char* Name() const = 0;
+ virtual ~GraphTransformation() {}
+ // Returns the list of messages that this graph transformation
+ // generated since ClearMessages() was called.
+ const std::vector<string>& Messages() const { return messages_; }
+ // Clears the list of messages; should be called after every
+ // run of this graph transformation.
+  void ClearMessages() { messages_.clear(); }
+ // Adds a message; normally only called by the graph transformation
+ // itself during its run (this function could be protected).
+ template <typename... Args>
+ void AddMessageF(const char* format, const Args&... args) {
+    messages_.push_back(toco::port::StringF(format, args...));
+ }
+
+ protected:
+ GraphTransformation() {}
+
+ // List of messages generated by this graph transformation.
+ std::vector<string> messages_;
+
+ private:
+ GraphTransformation(const GraphTransformation& other) = delete;
+  GraphTransformation(GraphTransformation&& other) = delete;
+};
+
+class GraphTransformationsSet {
+ public:
+ // The choice of a container with fully-specified iteration order
+ // ensures that graph transformations are always run in the same order,
+ // which avoids having toco randomly fail or produce different results
+ // depending on the toolchain. Ideally success/results should be independent
+ // of the order in which graph transformations are run, but that's
+ // unfortunately not currently guaranteed to be the case.
+ using TransformationsContainer =
+ std::vector<std::unique_ptr<GraphTransformation>>;
+
+ GraphTransformationsSet() {}
+ GraphTransformationsSet(
+ const std::initializer_list<GraphTransformation*> transformations) {
+ for (GraphTransformation* t : transformations) {
+ Add(t);
+ }
+ }
+ void Add(GraphTransformation* transformation) {
+ const string& name = transformation->Name();
+ CHECK(!names_.count(name));
+ names_.insert(name);
+ transformations_.emplace_back(transformation);
+ }
+ TransformationsContainer::const_iterator begin() const {
+ return transformations_.begin();
+ }
+ TransformationsContainer::const_iterator end() const {
+ return transformations_.end();
+ }
+ bool empty() const { return transformations_.empty(); }
+
+ private:
+ GraphTransformationsSet(const GraphTransformationsSet& other) = delete;
+  GraphTransformationsSet(GraphTransformationsSet&& other) = delete;
+  TransformationsContainer transformations_;
+ // Names of transformations in the set. Only used to guard against dupes.
+ std::unordered_set<string> names_;
+};
+
+// Run the given list of graph transformations on the model.
+// The message is only for logging purposes.
+// The caller constructs GraphTransformation objects with 'new' and hands the
+// resulting raw pointers to a GraphTransformationsSet, which takes ownership
+// of them (storing them in unique_ptr) and deletes them when it is destroyed.
+// RunGraphTransformations itself only borrows the set by const reference.
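+//
+// Example usage (a minimal sketch; the transformation lists actually used by
+// the converter are assembled elsewhere in the toco tooling):
+//
+//   GraphTransformationsSet transformations;
+//   transformations.Add(new RemoveTensorFlowIdentity);
+//   transformations.Add(new ResolveBatchNormalization);
+//   RunGraphTransformations(model, "general transformations",
+//                           transformations);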
+void RunGraphTransformations(Model* model, const string& message,
+ const GraphTransformationsSet& transformations);
+
+#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
+ class GTName : public GraphTransformation { \
+ public: \
+ bool Run(Model* model, std::size_t op_index) override; \
+    const char* Name() const override { return #GTName; } \
+ };
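+
+// Each declared transformation defines Run(model, op_index) in its own .cc
+// file in this directory, returning true if and only if it mutated the model.
+// A minimal sketch of the expected shape (SomeTransformation and kSomeType
+// are placeholders, not real declarations):
+//
+//   bool SomeTransformation::Run(Model* model, std::size_t op_index) {
+//     const auto it = model->operators.begin() + op_index;
+//     if (it->get()->type != OperatorType::kSomeType) {
+//       return false;  // Nothing to match at this position.
+//     }
+//     // ... mutate *model here ...
+//     return true;
+//   }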
+
+// List of all graph transformations
+DECLARE_GRAPH_TRANSFORMATION(ConvertPureConvToDepthwise)
+DECLARE_GRAPH_TRANSFORMATION(EnsureBiasVectors)
+DECLARE_GRAPH_TRANSFORMATION(FuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoFollowingAffine)
+DECLARE_GRAPH_TRANSFORMATION(FuseBinaryIntoPrecedingAffine)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Normalization)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyL2Pool)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell)
+DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1)
+DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator)
+DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes)
+DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes)
+DECLARE_GRAPH_TRANSFORMATION(HardcodeMinMax)
+DECLARE_GRAPH_TRANSFORMATION(Quantize)
+DECLARE_GRAPH_TRANSFORMATION(RemoveFinalDequantizeOp)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowAssert)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTensorFlowIdentity)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialConcatenationInput)
+DECLARE_GRAPH_TRANSFORMATION(RemoveTrivialQuantizedActivationFunc)
+DECLARE_GRAPH_TRANSFORMATION(RemoveUnusedOp)
+DECLARE_GRAPH_TRANSFORMATION(ResolveBatchNormalization)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantBinaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantUnaryOperator)
+DECLARE_GRAPH_TRANSFORMATION(CreateIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(DropIm2colArrays)
+DECLARE_GRAPH_TRANSFORMATION(ReadFakeQuantMinMax)
+DECLARE_GRAPH_TRANSFORMATION(ResolveReorderAxes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowConcat)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMatMul)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowMerge)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSqueeze)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowSwitch)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTensorFlowTile)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantConcatenation)
+DECLARE_GRAPH_TRANSFORMATION(DropFakeQuant)
+DECLARE_GRAPH_TRANSFORMATION(UnfuseActivationFunctions)
+DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
+DECLARE_GRAPH_TRANSFORMATION(Dequantize)
+
+class ResolveReshapeAttributes : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "ResolveReshapeAttributes"; }
+};
+
+class RemoveTrivialReshape : public GraphTransformation {
+ public:
+ bool Run(Model* model, std::size_t op_index) override;
+ const char* Name() const override { return "RemoveTrivialReshape"; }
+ bool treat_expand_dims_as_trivial() const {
+ return treat_expand_dims_as_trivial_;
+ }
+ void set_treat_expand_dims_as_trivial(bool val) {
+ treat_expand_dims_as_trivial_ = val;
+ }
+
+ private:
+ bool treat_expand_dims_as_trivial_ = false;
+};
+
+#undef DECLARE_GRAPH_TRANSFORMATION
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_GRAPH_TRANSFORMATIONS_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
new file mode 100644
index 0000000000..d44b5dc7b0
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc
@@ -0,0 +1,229 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool HardcodeMinMaxForIm2colArray(Model* model, Operator* op) {
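+  // A Conv operator carries a second output only when an im2col array has
+  // been attached to it (see CreateIm2colArrays), so anything else is not a
+  // candidate here.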
+ if (op->outputs.size() != 2) {
+ return false;
+ }
+ auto& im2col_array = model->GetArray(op->outputs[1]);
+ if (im2col_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!im2col_array.minmax);
+ auto& im2col_minmax = im2col_array.GetOrCreateMinMax();
+ im2col_minmax.min = input_minmax.min;
+ im2col_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForL2Normalization(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min >= 0. ? 0. : -1.;
+ output_minmax.max = input_minmax.max <= 0. ? 0. : 1.;
+ return true;
+}
+
+bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) {
+ // Do not early return if the output already has min/max:
+ // we may still need to adjust the inputs min/max.
+ bool has_minmax = false;
+ double overall_min = std::numeric_limits<double>::infinity();
+ double overall_max = -std::numeric_limits<double>::infinity();
+  for (const auto& input : op->inputs) {
+    if (const auto* minmax = model->GetArray(input).minmax.get()) {
+      has_minmax = true;
+      overall_min = std::min(overall_min, minmax->min);
+      overall_max = std::max(overall_max, minmax->max);
+    }
+  }
+ auto& output = model->GetArray(op->outputs[0]);
+  if (const auto* minmax = output.minmax.get()) {
+    has_minmax = true;
+    overall_min = std::min(overall_min, minmax->min);
+    overall_max = std::max(overall_max, minmax->max);
+  }
+ if (!has_minmax) {
+ return false;
+ }
+ MinMax overall_minmax;
+ overall_minmax.min = overall_min;
+ overall_minmax.max = overall_max;
+ bool changed = false;
+ for (const auto& input : op->inputs) {
+ auto& array = model->GetArray(input);
+ if (!array.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == array.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of array " << input << ", which is "
+ << "an input to " << LogName(*op) << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no "
+ "arithmetic.";
+ }
+ array.GetOrCreateMinMax() = overall_minmax;
+ }
+ if (!output.minmax) {
+ changed = true;
+ } else if (!(overall_minmax == output.GetMinMax())) {
+ changed = true;
+ LOG(WARNING)
+ << "Tweaking the MinMax of the output array of " << LogName(*op)
+ << ", because we want all inputs "
+ << "and outputs of a Concatenation operator to have the same MinMax "
+ << "so that it can be implemented as a pure byte-copy, no arithmetic.";
+ }
+ output.GetOrCreateMinMax() = overall_minmax;
+
+ return changed;
+}
+
+// The output of average or max pooling is within the same range as its input.
+bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
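+  // Take the union of the input range with 0 so that the real value 0 stays
+  // exactly representable after quantization (the quantization scheme
+  // requires the zero point to lie within [min, max]).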
+ output_minmax.min = std::min(input_minmax.min, 0.);
+ output_minmax.max = std::max(input_minmax.max, 0.);
+ return true;
+}
+
+bool HardcodeMinMaxForReshape(Model* model, Operator* op) {
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ const auto& input_minmax = input_array.GetMinMax();
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = input_minmax.min;
+ output_minmax.max = input_minmax.max;
+ return true;
+}
+
+bool HardcodeMinMaxForOutput(Model* model, Operator* op, double min,
+ double max) {
+ CHECK_EQ(op->outputs.size(), 1);
+ auto& output_array = model->GetArray(op->outputs[0]);
+ if (output_array.minmax) {
+ return false;
+ }
+ const auto& input_array = model->GetArray(op->inputs[0]);
+ if (!input_array.minmax) {
+ return false;
+ }
+ CHECK(!output_array.minmax);
+ auto& output_minmax = output_array.GetOrCreateMinMax();
+ output_minmax.min = min;
+ output_minmax.max = max;
+ return true;
+}
+} // namespace
+
+bool HardcodeMinMax::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ bool changed = false;
+ switch (op->type) {
+ case OperatorType::kConv:
+ changed = HardcodeMinMaxForIm2colArray(model, op);
+ break;
+
+ case OperatorType::kL2Normalization:
+ changed = HardcodeMinMaxForL2Normalization(model, op);
+ break;
+
+ case OperatorType::kConcatenation:
+ changed = HardcodeMinMaxForConcatenation(model, op);
+ break;
+
+ case OperatorType::kAveragePool:
+ case OperatorType::kMaxPool:
+ changed = HardcodeMinMaxForAverageOrMaxPool(model, op);
+ break;
+
+ case OperatorType::kTensorFlowReshape:
+ changed = HardcodeMinMaxForReshape(model, op);
+ break;
+
+ case OperatorType::kLogistic:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
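+      // (With uint8 quantization, real = scale * (quantized - zero_point)
+      // = q / 256 for q in [0, 255], so the representable range is exactly
+      // [0, 255/256].)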
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ case OperatorType::kSoftmax:
+ // We hardcode quantization_params to: zero_point=0, scale=1/256.
+ // This choice of minmax is the one that is equivalent to that.
+ changed = HardcodeMinMaxForOutput(model, op, 0, 255. / 256.);
+ break;
+
+ default:
+ break;
+ }
+ if (changed) {
+ AddMessageF("Hardcoded min-max through %s", LogName(*op));
+ }
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
new file mode 100644
index 0000000000..01b75e37c6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -0,0 +1,170 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+ const auto div_it = model->operators.begin() + op_index;
+ const auto* div_or_mul_op = div_it->get();
+ OperatorType expected_op_type_producing_div_or_mul_input;
+ if (div_or_mul_op->type == OperatorType::kDiv) {
+ expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ } else if (div_or_mul_op->type == OperatorType::kMul) {
+ expected_op_type_producing_div_or_mul_input =
+ OperatorType::kTensorFlowRsqrt;
+ } else {
+ return false;
+ }
+ CHECK_EQ(div_or_mul_op->inputs.size(), 2);
+ Operator* op_producing_div_or_mul_input[2] = {
+ GetOpWithOutput(*model, div_or_mul_op->inputs[0]),
+ GetOpWithOutput(*model, div_or_mul_op->inputs[1]),
+ };
+ if (!op_producing_div_or_mul_input[1] ||
+ op_producing_div_or_mul_input[1]->type !=
+ expected_op_type_producing_div_or_mul_input) {
+ return false;
+ }
+ Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
+ CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
+ Operator* op_producing_sqrt_or_rsqrt_input =
+ GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
+ if (!op_producing_sqrt_or_rsqrt_input) {
+ return false;
+ }
+
+ // There may be an Add or a Maximum here, adding or clamping to a "small"
+ // constant scalar.
+ // Reported bug: b/29395854
+ Operator* add_op = nullptr;
+ Operator* op_producing_add_input = nullptr;
+ if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
+ op_producing_sqrt_or_rsqrt_input->type ==
+ OperatorType::kTensorFlowMaximum) {
+ add_op = op_producing_sqrt_or_rsqrt_input;
+ bool add_can_be_removed = false;
+ CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
+ for (int i = 0; i < 2; i++) {
+ const auto& input_array =
+ model->GetArray(op_producing_sqrt_or_rsqrt_input->inputs[i]);
+ if (!input_array.buffer) {
+ continue;
+ }
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ continue;
+ }
+ if (RequiredBufferSizeForShape(input_array.shape()) != 1) {
+ continue;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ if (std::abs(input_float_data[0]) > 1e-3f) {
+ continue;
+ }
+ add_can_be_removed = true;
+ op_producing_add_input = GetOpWithOutput(*model, add_op->inputs[1 - i]);
+ break;
+ }
+ if (!add_can_be_removed) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph "
+ " because the operator producing the input to the square root, %s,"
+ ", does not match the expected pattern",
+ LogName(*op_producing_sqrt_or_rsqrt_input));
+ return false;
+ }
+ }
+
+ Operator* sum_op =
+ add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
+ if (sum_op->type != OperatorType::kTensorFlowSum) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Sum op, got %s",
+ LogName(*sum_op));
+ return false;
+ }
+
+  Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
+  if (!square_op) {
+    return false;
+  }
+  if (square_op->type != OperatorType::kTensorFlowSquare) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: "
+ "expected Square op, got %s",
+ LogName(*square_op));
+ return false;
+ }
+
+ CHECK_EQ(square_op->inputs.size(), 1);
+
+ if (square_op->inputs[0] != div_or_mul_op->inputs[0]) {
+ AddMessageF(
+ "Giving up trying to identify L2Normalization subgraph: %s does not "
+ "take the same input as the Mul/Div node",
+ LogName(*square_op));
+ return false;
+ }
+
+ // Create and emplace the new L2Normalization
+ auto* l2norm_op = new L2NormalizationOperator;
+ l2norm_op->inputs = {div_or_mul_op->inputs[0]};
+ l2norm_op->outputs = div_or_mul_op->outputs;
+ model->operators.emplace(div_it, l2norm_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2norm_op));
+
+ // Erase the subgraph that is now replaced by L2Normalization
+ model->operators.erase(FindOperator(model, square_op));
+ model->arrays.erase(sum_op->inputs[0]);
+ if (sum_op->inputs.size() > 1) {
+ model->arrays.erase(sum_op->inputs[1]);
+ }
+ model->operators.erase(FindOperator(model, sum_op));
+ if (add_op) {
+ model->arrays.erase(add_op->inputs[0]);
+ model->arrays.erase(add_op->inputs[1]);
+ model->operators.erase(FindOperator(model, add_op));
+ }
+ model->arrays.erase(sqrt_or_rsqrt_op->inputs[0]);
+ model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
+ model->arrays.erase(div_or_mul_op->inputs[1]);
+ model->operators.erase(FindOperator(model, div_or_mul_op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
new file mode 100644
index 0000000000..1865416fc2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_pool.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+} // namespace
+
+bool IdentifyL2Pool::Run(Model* model, std::size_t op_index) {
+ const auto sqrt_it = model->operators.begin() + op_index;
+ const auto* sqrt_op = sqrt_it->get();
+ if (sqrt_op->type != OperatorType::kTensorFlowSqrt) {
+ return false;
+ }
+
+ CHECK_EQ(sqrt_op->inputs.size(), 1);
+ CHECK_EQ(sqrt_op->outputs.size(), 1);
+
+ const AveragePoolOperator* avpool_op;
+ const Operator* square_op;
+
+  Operator* prev_to_sqrt_op = GetOpWithOutput(*model, sqrt_op->inputs[0]);
+  if (!prev_to_sqrt_op) {
+    return false;
+  }
+  if (prev_to_sqrt_op->type != OperatorType::kAveragePool) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected AveragePool op, got %s",
+ LogName(*prev_to_sqrt_op));
+ return false;
+ }
+
+ avpool_op = static_cast<const AveragePoolOperator*>(prev_to_sqrt_op);
+ CHECK_EQ(avpool_op->inputs.size(), 1);
+
+  square_op = GetOpWithOutput(*model, avpool_op->inputs[0]);
+  if (!square_op) {
+    return false;
+  }
+  CHECK_EQ(square_op->inputs.size(), 1);
+  if (square_op->type != OperatorType::kTensorFlowSquare) {
+ AddMessageF(
+ "Giving up trying to identify L2Pool subgraph: "
+ "expected Square op, got %s",
+ LogName(*square_op));
+ return false;
+ }
+
+ // Create and emplace L2Pool node.
+ auto* l2pool_op = new L2PoolOperator;
+
+ l2pool_op->inputs = {square_op->inputs[0]};
+ l2pool_op->outputs = sqrt_op->outputs;
+
+ l2pool_op->padding.type = avpool_op->padding.type;
+ // Note that we do not setup avpool_op->padding.fixed here. This is done by
+ // the PropagateFixedSizes graph transformation.
+
+ l2pool_op->stride_height = avpool_op->stride_height;
+ l2pool_op->stride_width = avpool_op->stride_width;
+ l2pool_op->kheight = avpool_op->kheight;
+ l2pool_op->kwidth = avpool_op->kwidth;
+ model->operators.emplace(sqrt_it, l2pool_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*l2pool_op));
+
+ // Erase intermediate arrays, keeping input to square op.
+ model->arrays.erase(avpool_op->inputs[0]);
+ model->arrays.erase(sqrt_op->inputs[0]);
+
+ // Erase three operators being replaced.
+ model->operators.erase(FindOperator(model, square_op));
+ model->operators.erase(FindOperator(model, avpool_op));
+ model->operators.erase(FindOperator(model, sqrt_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
new file mode 100644
index 0000000000..082820fddc
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_lstm.cc
@@ -0,0 +1,396 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/string_view.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator& op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == &op) {
+ break;
+ }
+ }
+ return it;
+}
+
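+// Returns whether the given array is the source of an RNN back-edge declared
+// in the model flags and, if so, optionally returns the corresponding state
+// array via *state_array.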
+bool GetStateArrayForBackEdge(const Model& model,
+ const string& back_edge_source_array,
+ string* state_array = nullptr) {
+ for (const auto& rnn_state : model.flags.rnn_states()) {
+ if (back_edge_source_array == rnn_state.back_edge_source_array()) {
+ // Found LSTM cell output
+ if (state_array) {
+ *state_array = rnn_state.state_array();
+ }
+ return true;
+ }
+ }
+ return false;
+}
+
+// Returns true if the given operator has exactly 1 input, which is connected
+// to the given op_type.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType op_type, Operator** connected_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 1) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (connected_op) {
+ *connected_op = x;
+ }
+
+ return true;
+}
+
+// Returns true if the given operator has exactly 2 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 2) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ return true;
+}
+
+// Returns true if the given operator has exactly 3 inputs, which are connected
+// to the given op_types.
+// We use kNone to indicate an input unattached to an operator output. Usually
+// these are the static input arrays.
+bool MatchOperatorInputs(const Operator& op, const Model& model,
+ OperatorType a_op_type, Operator** a_op,
+ OperatorType b_op_type, Operator** b_op,
+ OperatorType c_op_type, Operator** c_op) {
+ // Check for required number of inputs
+ if (op.inputs.size() != 3) {
+ return false;
+ }
+
+ // Check if first input is disconnected/connected to an operator
+ Operator* x = GetOpWithOutput(model, op.inputs[0]);
+ if ((a_op_type == OperatorType::kNone) && (x != nullptr)) {
+ return false;
+ }
+ if ((a_op_type != OperatorType::kNone) && (x == nullptr)) {
+ return false;
+ }
+
+ // Check that first operator, if connected, is of correct type
+ if ((x != nullptr) && (x->type != a_op_type)) {
+ return false;
+ }
+
+ // Check if second input is disconnected/connected to an operator
+ Operator* y = GetOpWithOutput(model, op.inputs[1]);
+ if ((b_op_type == OperatorType::kNone) && (y != nullptr)) {
+ return false;
+ }
+ if ((b_op_type != OperatorType::kNone) && (y == nullptr)) {
+ return false;
+ }
+
+ // Check that second operator, if connected, is of correct type
+ if ((y != nullptr) && (y->type != b_op_type)) {
+ return false;
+ }
+
+ // Check if third input is disconnected/connected to an operator
+ Operator* z = GetOpWithOutput(model, op.inputs[2]);
+ if ((c_op_type == OperatorType::kNone) && (z != nullptr)) {
+ return false;
+ }
+ if ((c_op_type != OperatorType::kNone) && (z == nullptr)) {
+ return false;
+ }
+
+ // Check that third operator, if connected, is of correct type
+ if ((z != nullptr) && (z->type != c_op_type)) {
+ return false;
+ }
+
+ // Successfully matched. Optionally return matching input operators.
+ if (a_op != nullptr) {
+ *a_op = x;
+ }
+ if (b_op != nullptr) {
+ *b_op = y;
+ }
+ if (c_op != nullptr) {
+ *c_op = z;
+ }
+ return true;
+}
+
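+// Returns the longest common prefix of the two given names; used below to
+// derive readable names for the LSTM cell's temporary arrays.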
+absl::string_view FindLongestCommonPrefix(absl::string_view a,
+ absl::string_view b) {
+ if (a.empty() || b.empty()) return absl::string_view();
+
+ const char* pa = a.data();
+ const char* pb = b.data();
+ size_t count = 0;
+  const size_t limit = std::min(a.size(), b.size());
+ while (count < limit && *pa == *pb) {
+ ++pa;
+ ++pb;
+ ++count;
+ }
+
+ return absl::string_view(a.data(), count);
+}
+
+} // namespace
+
+bool IdentifyLstmCell::Run(Model* model, std::size_t op_index) {
+ // This LSTM cell identification method is not invariant to commutation of
+ // commutative operator inputs. For example, if input[0] and input[1] of the
+ // final output multiplication were swapped, this method would not identify it
+ // as an LSTM cell. This is OK in most cases, because
+  // tf.contrib.rnn.BasicLSTMCell always generates LSTM cells the same way.
+
+ // Final output multiply
+ auto op_it = model->operators.begin() + op_index;
+ Operator* final_output_mul = op_it->get();
+ if (final_output_mul->type != OperatorType::kMul) {
+ return false;
+ }
+ Operator *state_output_tanh, *fc_output_sig;
+ if (!MatchOperatorInputs(*final_output_mul, *model, OperatorType::kTanh,
+ &state_output_tanh, OperatorType::kLogistic,
+ &fc_output_sig)) {
+ return false;
+ }
+
+ // State output TanH
+ // (We don't count an operator as ID'd until we verify it has the correct
+ // operator types feeding into it.)
+ Operator* state_combine_add;
+ if (!MatchOperatorInputs(*state_output_tanh, *model, OperatorType::kAdd,
+ &state_combine_add)) {
+ return false;
+ }
+ string prev_state;
+ if (!GetStateArrayForBackEdge(*model, state_output_tanh->inputs[0],
+ &prev_state)) {
+ return false;
+ }
+
+ // State forget & remember addition
+ Operator *state_forget_mul, *state_remember_mul;
+ if (!MatchOperatorInputs(*state_combine_add, *model, OperatorType::kMul,
+ &state_forget_mul, OperatorType::kMul,
+ &state_remember_mul)) {
+ return false;
+ }
+ if (state_forget_mul->inputs[0] != prev_state) {
+ return false;
+ }
+
+ // State forget gate
+ Operator* state_forget_sig;
+ if (!MatchOperatorInputs(*state_forget_mul, *model, OperatorType::kNone,
+ nullptr, OperatorType::kLogistic,
+ &state_forget_sig)) {
+ return false;
+ }
+
+ // State remember gate
+ Operator *state_remember_sig, *state_info_tanh;
+ if (!MatchOperatorInputs(*state_remember_mul, *model, OperatorType::kLogistic,
+ &state_remember_sig, OperatorType::kTanh,
+ &state_info_tanh)) {
+ return false;
+ }
+
+ // State remember "information" activation function
+ Operator* fc_output_split;
+ if (!MatchOperatorInputs(*state_info_tanh, *model,
+ OperatorType::kTensorFlowSplit, &fc_output_split)) {
+ return false;
+ }
+ // State remember gate activation function
+ Operator* tmp;
+ if (!MatchOperatorInputs(*state_remember_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // State forget gate activation function
+ if (!MatchOperatorInputs(*state_forget_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output activation function
+ if (!MatchOperatorInputs(*fc_output_sig, *model,
+ OperatorType::kTensorFlowSplit, &tmp) ||
+ (tmp != fc_output_split)) {
+ return false;
+ }
+ // Fully connected output split
+ Operator* fully_connected;
+ if (!MatchOperatorInputs(*fc_output_split, *model, OperatorType::kNone,
+ nullptr, OperatorType::kFullyConnected,
+ &fully_connected)) {
+ return false;
+ }
+
+ // Fully connected op
+ Operator* concat_inputs;
+ if (!MatchOperatorInputs(*fully_connected, *model,
+ OperatorType::kConcatenation, &concat_inputs,
+ OperatorType::kNone, nullptr, OperatorType::kNone,
+ nullptr)) {
+ return false;
+ }
+
+ // Emplace a new LSTM cell operator
+ auto* lstm_cell_op = new LstmCellOperator;
+ lstm_cell_op->inputs.resize(LstmCellOperator::NUM_INPUTS);
+ lstm_cell_op->inputs[LstmCellOperator::DATA_INPUT] = concat_inputs->inputs[0];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_ACTIV_INPUT] =
+ concat_inputs->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::WEIGHTS_INPUT] =
+ fully_connected->inputs[1];
+ lstm_cell_op->inputs[LstmCellOperator::BIASES_INPUT] =
+ fully_connected->inputs[2];
+ lstm_cell_op->inputs[LstmCellOperator::PREV_STATE_INPUT] = prev_state;
+ lstm_cell_op->outputs.resize(LstmCellOperator::NUM_OUTPUTS);
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT] =
+ state_output_tanh->inputs[0];
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT] =
+ final_output_mul->outputs[0];
+ model->operators.emplace(op_it, lstm_cell_op);
+ AddMessageF("Creating %s replacing equivalent subgraph",
+ LogName(*lstm_cell_op));
+
+ // Create temp arrays used internally during runtime.
+ const string base_name(FindLongestCommonPrefix(
+ lstm_cell_op->outputs[LstmCellOperator::STATE_OUTPUT],
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_OUTPUT]));
+ const string& concat_temp_array_name =
+ AvailableArrayName(*model, base_name + "concat_temp");
+ model->GetOrCreateArray(concat_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::CONCAT_TEMP] = concat_temp_array_name;
+ const string& activ_temp_array_name =
+ AvailableArrayName(*model, base_name + "activ_temp");
+ model->GetOrCreateArray(activ_temp_array_name);
+ lstm_cell_op->outputs[LstmCellOperator::ACTIV_TEMP] = activ_temp_array_name;
+ AddMessageF("Created temp outputs %s and %s on operator %s",
+ concat_temp_array_name, activ_temp_array_name,
+ LogName(*lstm_cell_op));
+
+ // Delete arrays and operators replaced by the LSTM cell operator. Order is
+ // important - DeleteArrayIfUnused() only succeeds if dependent operators
+ // have been removed first. Start at the output and work towards the input.
+ model->operators.erase(FindOperator(model, *final_output_mul));
+ DeleteArrayIfUnused(state_output_tanh->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_output_tanh));
+ model->operators.erase(FindOperator(model, *fc_output_sig));
+ model->operators.erase(FindOperator(model, *state_combine_add));
+ DeleteArrayIfUnused(state_forget_mul->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_mul->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_mul));
+ model->operators.erase(FindOperator(model, *state_remember_mul));
+ DeleteArrayIfUnused(state_forget_sig->outputs[0], model);
+ DeleteArrayIfUnused(state_info_tanh->outputs[0], model);
+ DeleteArrayIfUnused(state_remember_sig->outputs[0], model);
+ model->operators.erase(FindOperator(model, *state_forget_sig));
+ model->operators.erase(FindOperator(model, *state_info_tanh));
+ model->operators.erase(FindOperator(model, *state_remember_sig));
+ DeleteArrayIfUnused(fc_output_split->outputs[0], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[1], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[2], model);
+ DeleteArrayIfUnused(fc_output_split->outputs[3], model);
+ string dims_array = fc_output_split->inputs[0];
+ model->operators.erase(FindOperator(model, *fc_output_split));
+ DeleteArrayIfUnused(dims_array, model);
+ DeleteArrayIfUnused(fully_connected->outputs[0], model);
+ model->operators.erase(FindOperator(model, *fully_connected));
+ DeleteArrayIfUnused(concat_inputs->outputs[0], model);
+ model->operators.erase(FindOperator(model, *concat_inputs));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
new file mode 100644
index 0000000000..cfc77024e7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -0,0 +1,103 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
+ Model* model, const Operator* op) {
+ auto it = model->operators.begin();
+ for (; it != model->operators.end(); ++it) {
+ if (it->get() == op) {
+ break;
+ }
+ }
+ return it;
+}
+
+bool CheckArrayIsScalarFloat(Model* model, const std::string& name, float val) {
+ const auto& op_array = model->GetArray(name);
+ if (!op_array.buffer || op_array.buffer->type != ArrayDataType::kFloat ||
+ RequiredBufferSizeForShape(op_array.shape()) != 1) {
+ return false;
+ }
+ const auto& op_data = op_array.GetBuffer<ArrayDataType::kFloat>().data;
+ return op_data[0] == val;
+}
+
+// Returns index of scalar input when there is exactly one scalar, -1 otherwise
+int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
+ float val) {
+ bool input0_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[0], val);
+ bool input1_is_scalar = CheckArrayIsScalarFloat(model, op->inputs[1], val);
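+  // Exactly one of the two inputs may be the matching scalar; if both or
+  // neither match, there is no unique scalar input.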
+ return input0_is_scalar == input1_is_scalar ? -1 : input0_is_scalar ? 0 : 1;
+}
+} // namespace
+
+bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+ const auto maximum_it = model->operators.begin() + op_index;
+ const auto* maximum_op = maximum_it->get();
+ if (maximum_op->type != OperatorType::kTensorFlowMaximum) {
+ return false;
+ }
+ CHECK_EQ(maximum_op->inputs.size(), 2);
+ if (maximum_op->outputs.size() != 1) {
+ return false;
+ }
+ int scalar_input_index =
+ GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f);
+ if (scalar_input_index == -1) {
+ return false;
+ }
+ const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]);
+ if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) {
+ return false;
+ }
+ if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) {
+ return false;
+ }
+ CHECK_EQ(minimum_op->inputs.size(), 2);
+
+ // Create and emplace Relu1 node
+ auto* relu1_op = new Relu1Operator;
+  relu1_op->inputs = {maximum_op->inputs[1 - scalar_input_index]};
+ relu1_op->outputs = minimum_op->outputs;
+ model->operators.emplace(maximum_it, relu1_op);
+
+ AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op));
+
+ // Erase Maximum scalar input & operator
+ model->arrays.erase(maximum_op->inputs[scalar_input_index]);
+ model->operators.erase(FindOperator(model, maximum_op));
+
+ // Erase Minimum inputs & operator
+ model->arrays.erase(minimum_op->inputs[0]);
+ model->arrays.erase(minimum_op->inputs[1]);
+ model->operators.erase(FindOperator(model, minimum_op));
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
new file mode 100644
index 0000000000..d83603e9a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/make_initial_dequantize_operator.cc
@@ -0,0 +1,120 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+// This inserts an operator whose output is a float array (name:
+// flags.input_array()). It has to wait for any existing operators that
+// generate this output to be removed by graph transformations. Note that there
+// may be more than one operator that takes the input_array as its input, and
+// that some of these may be removed by graph transformations.
+bool AddDequantizeOperatorToInput(const string& input_name, const Operator* op,
+ GraphTransformation* transformation,
+ Model* model) {
+ // An operator with the required output may be a dequantize operator already
+ // created. Alternatively it may be an operator that needs to be removed
+ // because it is unused, in which case we wait for RemoveUnusedOp to do its
+ // work.
+ if (GetOpWithOutput(*model, input_name)) {
+ return false;
+ }
+
+  // We only apply this to the first consuming operator if there is more than
+  // one. This is not strictly necessary for ordering correctness, since we
+  // insert the dequant operator at the beginning of the op sequence, but it
+  // makes the insertion more predictable (e.g. forward vs backward operator
+  // sweeps).
+ if (CountOpsWithInput(*model, input_name) > 1) {
+ if (op != GetFirstOpWithInput(*model, input_name)) {
+ return false;
+ }
+ }
+
+ auto& input_array = model->GetArray(input_name);
+ if (input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+
+ if (input_array.final_data_type == input_array.data_type ||
+ input_array.final_data_type == ArrayDataType::kNone) {
+ return false;
+ }
+
+ const auto& dequantized_input_name =
+ AvailableArrayName(*model, input_name + "_dequantized");
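+  // Reroute every existing consumer of the input array to read from the new
+  // dequantized array instead; the Dequantize operator created below becomes
+  // the sole consumer of the original input array.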
+ for (auto& other_op : model->operators) {
+ for (string& other_op_input : other_op->inputs) {
+ if (other_op_input == input_name) {
+ other_op_input = dequantized_input_name;
+ }
+ }
+ }
+
+ auto& dequantized_input_array =
+ model->GetOrCreateArray(dequantized_input_name);
+ auto* image_input_op = new DequantizeOperator;
+ image_input_op->inputs = {input_name};
+ image_input_op->outputs = {dequantized_input_name};
+ model->operators.emplace(model->operators.begin(), image_input_op);
+
+ CHECK(input_array.final_data_type == ArrayDataType::kUint8);
+ input_array.data_type = ArrayDataType::kUint8;
+ dequantized_input_array.data_type = ArrayDataType::kFloat;
+ const auto& input_minmax = input_array.GetMinMax();
+ auto& dequantized_input_minmax = dequantized_input_array.GetOrCreateMinMax();
+ dequantized_input_minmax = input_minmax;
+ auto& input_qparams = input_array.GetOrCreateQuantizationParams();
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, input_minmax, &input_qparams);
+
+ transformation->AddMessageF(
+ "Created %s"
+ " to handle quantized input image data, taking over existing"
+ " mean_value and std_value flags. Cleared those flags.",
+ LogName(*image_input_op));
+
+ return true;
+}
+
+bool MakeInitialDequantizeOperator::Run(Model* model, std::size_t op_index) {
+ // This is effectively a transformation applied to edges. We iterate over the
+ // specified node (op) and proceed for input edges.
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+ bool change_made = false;
+ for (auto& input : op->inputs) {
+ for (auto& input_array : *model->flags.mutable_input_arrays()) {
+ if (input_array.name() == input) {
+ if (AddDequantizeOperatorToInput(input_array.name(), op, this, model)) {
+ change_made = true;
+ input_array.clear_mean_value();
+ input_array.clear_std_value();
+ }
+ }
+ }
+ }
+ return change_made;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
new file mode 100644
index 0000000000..1ff4e827aa
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+ArrayDataType CommonDataTypeOfAllInputs(const Model& model,
+ const Operator& op) {
+ CHECK_GT(op.inputs.size(), 0);
+ const ArrayDataType data_type = model.GetArray(op.inputs[0]).data_type;
+ for (const auto& input : op.inputs) {
+ const auto& array = model.GetArray(input);
+ CHECK(array.data_type == data_type)
+ << " Unexpected: this operator has inputs with different data types.";
+ }
+ return data_type;
+}
+
+void SetDataTypeForAllOutputs(Model* model, Operator* op,
+ ArrayDataType data_type) {
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->data_type = data_type;
+ }
+}
+} // namespace
+
+bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+ // If the data type of some input is unknown, we need to yield.
+ for (const auto& input : op->inputs) {
+ if (model->arrays[input]->data_type == ArrayDataType::kNone) {
+ return false;
+ }
+ }
+  // Record data types of outputs before processing, so we can see at the
+ // end if we changed anything, and return the correct boolean value.
+ std::unordered_map<string, ArrayDataType> old_output_data_types;
+ for (const auto& output : op->outputs) {
+ old_output_data_types[output] = model->arrays[output]->data_type;
+ }
+ // Do the actual output data types propagation.
+ if (op->type == OperatorType::kDequantize ||
+ op->type == OperatorType::kResizeBilinear) {
+ // These operators unconditionally produce float outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kFloat);
+ } else if (op->type == OperatorType::kTensorFlowLess ||
+ op->type == OperatorType::kTensorFlowLessEqual ||
+ op->type == OperatorType::kTensorFlowGreater ||
+ op->type == OperatorType::kTensorFlowGreaterEqual) {
+ // These operators unconditionally produce bool outputs
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kBool);
+ } else if (op->type == OperatorType::kTensorFlowShape) {
+    // This operator is assumed to produce int32 outputs.
+ SetDataTypeForAllOutputs(model, op, ArrayDataType::kInt32);
+ } else if (op->type == OperatorType::kAveragePool ||
+ op->type == OperatorType::kMaxPool ||
+ op->type == OperatorType::kL2Pool ||
+ op->type == OperatorType::kConv ||
+ op->type == OperatorType::kDepthwiseConv ||
+ op->type == OperatorType::kFullyConnected ||
+ op->type == OperatorType::kTensorFlowMax ||
+ op->type == OperatorType::kTensorFlowMin ||
+ op->type == OperatorType::kPad ||
+ op->type == OperatorType::kStridedSlice ||
+ op->type == OperatorType::kTensorFlowReshape ||
+ op->type == OperatorType::kSlice ||
+ op->type == OperatorType::kSqueeze ||
+ op->type == OperatorType::kTensorFlowSum ||
+ op->type == OperatorType::kTensorFlowSwitch ||
+ op->type == OperatorType::kTensorFlowTile ||
+ op->type == OperatorType::kTensorFlowAll ||
+ op->type == OperatorType::kReorderAxes ||
+ op->type == OperatorType::kTensorFlowConcatV2 ||
+ op->type == OperatorType::kFloor ||
+ op->type == OperatorType::kGather ||
+ op->type == OperatorType::kSpaceToBatchND ||
+ op->type == OperatorType::kBatchToSpaceND ||
+ op->type == OperatorType::kMean) {
+ // These operators produce outputs with the same type as their 1st input
+ CHECK_GT(op->inputs.size(), 0);
+ const ArrayDataType data_type = model->arrays[op->inputs[0]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kTensorFlowSplit ||
+ op->type == OperatorType::kTensorFlowConcat) {
+ // These operators produce an output with the same type as their 2nd input
+ CHECK_GT(op->inputs.size(), 1);
+ const ArrayDataType data_type = model->arrays[op->inputs[1]]->data_type;
+ SetDataTypeForAllOutputs(model, op, data_type);
+ } else if (op->type == OperatorType::kCast) {
+ // Data type of the Cast op is specified.
+ CHECK_EQ(op->outputs.size(), 1);
+ auto* cast_op = static_cast<CastOperator*>(op);
+ model->arrays[op->outputs[0]]->data_type = cast_op->dst_data_type;
+ } else if (op->type == OperatorType::kTensorFlowUnsupported) {
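+    // Unsupported TensorFlow ops carry per-output data types recorded at
+    // import time; yield until one is known for every output.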
+ auto* unsupported_op = static_cast<TensorFlowUnsupportedOperator*>(op);
+ if (unsupported_op->output_data_types.size() != op->outputs.size()) {
+ return false;
+ }
+    for (size_t i = 0; i < unsupported_op->output_data_types.size(); ++i) {
+ auto output = op->outputs[i];
+ auto data_type = unsupported_op->output_data_types[i];
+ model->arrays[output]->data_type = data_type;
+ }
+ } else {
+    // All remaining operators produce outputs with the same type as their
+    // inputs, and all of their inputs are required to share a single type.
+ const ArrayDataType data_type = CommonDataTypeOfAllInputs(*model, *op);
+ SetDataTypeForAllOutputs(model, op, data_type);
+ }
+ // Return true if any output data type changed, false if none changed.
+ for (const auto& output : op->outputs) {
+ if (old_output_data_types[output] != model->arrays[output]->data_type) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
new file mode 100644
index 0000000000..82a43bc2ce
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -0,0 +1,1129 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
+ int kheight, int stride_width, int stride_height,
+ PaddingType padding_type, Shape* output_shape,
+ FixedPadding* fixed_padding) {
+ const int input_width = input_shape.dims(2);
+ const int input_height = input_shape.dims(1);
+ const int batch = input_shape.dims(0);
+
+ int output_height = 0;
+ int output_width = 0;
+ if (padding_type == PaddingType::kValid) {
+ output_height = (input_height + stride_height - kheight) / stride_height;
+ output_width = (input_width + stride_width - kwidth) / stride_width;
+ } else if (padding_type == PaddingType::kSame) {
+ output_height = (input_height + stride_height - 1) / stride_height;
+ output_width = (input_width + stride_width - 1) / stride_width;
+ } else {
+ LOG(FATAL) << "Only supporting SAME or VALID padding";
+ }
+
+ fixed_padding->height =
+ ((output_height - 1) * stride_height + kheight - input_height) / 2;
+ fixed_padding->width =
+ ((output_width - 1) * stride_width + kwidth - input_width) / 2;
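+  // Worked example: with input_height = 7, kheight = 3, stride_height = 2,
+  // VALID padding gives output_height = (7 + 2 - 3) / 2 = 3 and
+  // fixed_padding->height = ((3 - 1) * 2 + 3 - 7) / 2 = 0, while SAME
+  // padding gives output_height = (7 + 2 - 1) / 2 = 4 and
+  // fixed_padding->height = ((4 - 1) * 2 + 3 - 7) / 2 = 1.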
+
+  // We have actually had to debug cases where these were negative, due to
+  // bad propagation of placeholder -1 sizes in TensorFlowReshape.
+ CHECK_GT(output_width, 0);
+ CHECK_GT(output_height, 0);
+ output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
+}
+
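+// Sets the output shape of an elementwise binary operator to the shape of
+// whichever input has the larger flat size, breaking ties in favor of the
+// input with more dimensions. E.g. shapes {1, 6} and {1, 2, 3} both have
+// flat size 6, so the 3-dimensional {1, 2, 3} is chosen.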
+void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
+ const Shape& input_shape2,
+ Array* output_array) {
+ const int size1 = RequiredBufferSizeForShape(input_shape1);
+ const int size2 = RequiredBufferSizeForShape(input_shape2);
+ if (size1 > size2) {
+ output_array->copy_shape(input_shape1);
+ } else if (size2 > size1) {
+ output_array->copy_shape(input_shape2);
+ } else {
+ CHECK_EQ(size1, size2);
+ const int dims1 = input_shape1.dimensions_count();
+ const int dims2 = input_shape2.dimensions_count();
+ if (dims1 >= dims2) {
+ output_array->copy_shape(input_shape1);
+ } else {
+ output_array->copy_shape(input_shape2);
+ }
+ }
+ CHECK(output_array->has_shape());
+}
+
+int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
+ const string& weights_name = op.inputs[1];
+ const auto& weights_shape = model.arrays.at(weights_name)->shape();
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kFullyConnected) {
+ return weights_shape.dims(0);
+ } else if (op.type == OperatorType::kDepthwiseConv) {
+ return weights_shape.dims(3);
+ } else {
+ LOG(FATAL) << "Unhandled operator type";
+ }
+}
+
+bool EnsureBiasVectorShape(Model* model, Operator* op) {
+ const string& weights_name = op->inputs[1];
+ const auto& weights_array = *model->arrays[weights_name];
+ // Yield until weights shape has been resolved.
+ if (!weights_array.has_shape()) {
+ return false;
+ }
+
+ if (op->inputs.size() < 3) {
+ return false;
+ }
+ auto& bias_array = *model->arrays[op->inputs[2]];
+ if (bias_array.has_shape()) {
+ return true;
+ }
+
+ const int output_depth = GetOutputDepthFromWeights(*model, *op);
+ bias_array.copy_shape(Shape({output_depth}));
+
+ auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ float_buffer.data.resize(output_depth, 0);
+
+ return true;
+}
+
+void ProcessConvOperator(Model* model, ConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ const int output_depth = weights_shape.dims(0);
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ output_array.mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+ CHECK_EQ(output_array.shape().dimensions_count(), 4);
+
+ // Set im2col array dimensions if there is one.
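+  // The im2col array stores one flattened kheight x kwidth x input_depth
+  // patch per output spatial position, hence its innermost dimension below.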
+ if (op->outputs.size() == 2) {
+ const auto& output_shape = output_array.shape();
+ const int input_depth = weights_shape.dims(3);
+ auto& im2col_array = *model->arrays[op->outputs[1]];
+ im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
+ output_shape.dims(2),
+ input_depth * kheight * kwidth});
+ }
+}
+
+void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int input_depth = input_shape.dims(3);
+ const int output_depth = weights_shape.dims(3);
+ // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
+ // instead it has to be inferred from the weights dims. However, once we are
+ // here, weights dims have already been converted to our own internal format,
+ // where the multiplier is no longer readily apparent. So instead we get it
+  // as the quotient of output and input depths. We only want to do that when
+  // depth_multiplier has the zero (unset) value; any other value is
+  // validated against the depths by the QCHECK_EQ below.
+ if (!op->depth_multiplier) {
+ op->depth_multiplier = output_depth / input_depth;
+ }
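+  // E.g. with input_depth = 8 and output_depth = 24, the inferred
+  // depth_multiplier is 24 / 8 = 3.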
+ QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
+ << "input/output depths and depth_multiplier don't match";
+
+ const int kheight = weights_shape.dims(1);
+ const int kwidth = weights_shape.dims(2);
+ ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
+ op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(depth % (block_size * block_size), 0);
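+  // E.g. a {1, 4, 4, 16} input with block_size = 2 yields a {1, 8, 8, 4}
+  // output.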
+
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height * block_size, width * block_size,
+ depth / block_size / block_size}));
+}
+
+void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+
+ const string& output_name = op->outputs[0];
+ const int block_size = op->block_size;
+ CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
+ const int batch = input_shape.dims(0);
+ const int height = input_shape.dims(1);
+ const int width = input_shape.dims(2);
+ const int depth = input_shape.dims(3);
+ QCHECK_EQ(width % block_size, 0);
+ QCHECK_EQ(height % block_size, 0);
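+  // E.g. a {1, 8, 8, 4} input with block_size = 2 yields a {1, 4, 4, 16}
+  // output.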
+
+ model->GetArray(output_name)
+ .copy_shape(Shape({batch, height / block_size, width / block_size,
+ depth * block_size * block_size}));
+}
+
+void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
+ if (!EnsureBiasVectorShape(model, op)) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 1);
+
+ const auto& weights_array = *model->arrays[op->inputs[1]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+
+ const int weights_output_depth = weights_shape.dims(0);
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
+ const int input_overall_size = RequiredBufferSizeForShape(input_shape);
+ const int matmul_repeats = input_overall_size / weights_shape.dims(1);
+ CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
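+  // E.g. a {2, 1, 10} input against {16, 10} weights gives
+  // matmul_repeats = 20 / 10 = 2 and an output shape of {2, 16}.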
+
+ auto& output_array = model->GetArray(op->outputs[0]);
+ output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
+}
+
+void ProcessTensorFlowReshapeOperator(Model* model,
+ TensorFlowReshapeOperator* op) {
+ auto& output_array = *model->arrays[op->outputs[0]];
+ // Bail if we already have output dims
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+
+ const string& shape_name = op->inputs[1];
+ auto& shape_array = model->GetArray(shape_name);
+ // Yield until the shape is resolved as a constant array
+ if (!shape_array.buffer) {
+ return;
+ }
+ CHECK(shape_array.data_type == ArrayDataType::kInt32);
+ // shape_data is the raw array of ints describing the shape
+ // in the TensorFlow node. We intentionally make a copy here, rather than
+ // modify wildcards in-place below, because in some graphs, the same shape
+ // array with a wildcard may be referenced from multiple Reshape nodes, where
+  // the wildcard needs to be resolved to distinct values.
+ std::vector<int32> shape_data =
+ shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ // The Reshape shape may have a wildcard dim, encoded as -1.
+ bool has_wildcard = false;
+ int wildcard_index = 0;
+ int product_non_wildcard_dims = 1;
+ for (int i = 0; i < shape_data.size(); i++) {
+ if (shape_data[i] == -1) {
+ CHECK(!has_wildcard);
+ has_wildcard = true;
+ wildcard_index = i;
+ } else {
+ product_non_wildcard_dims *= shape_data[i];
+ }
+ }
+ const int input_flat_size = RequiredBufferSizeForShape(input_shape);
+ if (has_wildcard) {
+ shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
+ }
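+  // E.g. reshaping an input of flat size 24 with shape data {-1, 4} resolves
+  // the wildcard to 24 / 4 = 6, yielding the output shape {6, 4}.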
+ auto& output_shape = *output_array.mutable_shape();
+ *output_shape.mutable_dims() = shape_data;
+ const int output_flat_size = RequiredBufferSizeForShape(output_shape);
+ CHECK_EQ(output_flat_size, input_flat_size);
+}
+
+void ProcessSimpleOperator(Model* model, Operator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ output_array.copy_shape(input_array.shape());
+}
+
+void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const auto& input0_array = *model->arrays[op->inputs[0]];
+ const auto& input1_array = *model->arrays[op->inputs[1]];
+ // Yield until input dims have been resolved.
+ if (!input0_array.has_shape() || !input1_array.has_shape()) {
+ return;
+ }
+ const string& output_name = op->outputs[0];
+ auto& output_array = *model->arrays[output_name];
+ ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
+ &output_array);
+}
+
+void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
+ CHECK_LE(op->inputs.size(), 2);
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) {
+ return;
+ }
+ if (op->inputs.size() == 2) {
+ // There is a reduction_indices input.
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& reduction_array = *model->arrays[op->inputs[1]];
+ if (!reduction_array.buffer) {
+ return;
+ }
+ if (!input_array.has_shape()) {
+ return;
+ }
+ auto& input_shape = input_array.shape();
+ CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
+ const auto& reduction_array_vals =
+ reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto& output_dims = *output_array.mutable_shape()->mutable_dims();
+ output_dims.clear();
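+    // E.g. reducing a {2, 3, 4} input over reduction_indices {0, 2} keeps
+    // only dimension 1, yielding output dims {3}.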
+ for (int i = 0; i < input_shape.dimensions_count(); i++) {
+ bool is_reduction_dim = false;
+ for (int r : reduction_array_vals) {
+ if (i == r) {
+ is_reduction_dim = true;
+ }
+ }
+ if (!is_reduction_dim) {
+ output_dims.push_back(input_shape.dims(i));
+ }
+ }
+ } else {
+ // No reduction_indices means complete reduction to a single scalar.
+ output_array.copy_shape(Shape({}));
+ }
+}
+
+void ProcessSliceOperator(Model* model, SliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 3);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Yield until the Slice params have been resolved.
+ if (op->begin.empty()) return;
+
+ // Yield until input dims have been resolved.
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+ const Shape& input_shape = input_array.shape();
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ CHECK_EQ(input_shape.dims().size(), op->size.size());
+ CHECK_EQ(op->begin.size(), op->size.size());
+
+ std::vector<int> output_dims;
+ for (int i = 0; i < op->begin.size(); ++i) {
+ int size = op->size[i];
+ if (size == -1) {
+ size = input_array.shape().dims(i) - op->begin[i];
+ }
+ output_dims.push_back(size);
+ }
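+  // E.g. with input dims {10, 4}, begin {2, 0} and size {-1, 4}, the
+  // wildcard size resolves to 10 - 2 = 8, giving output dims {8, 4}.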
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ const string& output_name = op->outputs[0];
+ Shape* output_shape = model->GetArray(output_name).mutable_shape();
+ ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
+ output_shape);
+}
+
+void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
+ // Yield until input dims have been resolved.
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ if (!input_array.has_shape()) {
+ return;
+ }
+ }
+ auto& output_array = model->GetArray(op->outputs[0]);
+ // Use 0 input as basis for output dimensions.
+ const auto& first_input_array = *model->arrays[op->inputs[0]];
+ output_array.copy_shape(first_input_array.shape());
+  // Determine the concat size, and enforce that all inputs have
+ // the same dimensions count.
+ int concat_size = 0;
+ for (const auto& input_name : op->inputs) {
+ auto& input_array = *model->arrays[input_name];
+ CHECK(input_array.has_shape());
+ if (input_array.shape().dimensions_count() == 0) {
+ continue;
+ }
+ CHECK_EQ(input_array.shape().dimensions_count(),
+ output_array.shape().dimensions_count());
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ CHECK_LT(op->concat_dim, input_dims.size());
+ concat_size += input_dims[op->concat_dim];
+ }
+ // Write out the concat_size on the output array shape.
+ auto& output_shape = *output_array.mutable_shape();
+ auto& output_dims = *output_shape.mutable_dims();
+ CHECK_LT(op->concat_dim, output_shape.dimensions_count());
+ output_dims[op->concat_dim] = concat_size;
+}
+
+void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ const string& input_name = op->inputs[1];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const Shape& input_shape = input_array.shape();
+
+ // This code is slightly suspect. The TensorFlow docs say that the axis
+ // selection defaults to 0, but we are splitting across the final axis.
+ const int input_dims_count = input_shape.dimensions_count();
+ const int input_depth = input_shape.dims(input_dims_count - 1);
+ CHECK_EQ(input_depth % op->num_split, 0);
+ const int split_depth = input_depth / op->num_split;
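+  // E.g. splitting a {1, 8, 8, 6} input with num_split = 3 gives
+  // split_depth = 6 / 3 = 2 and three outputs of shape {1, 8, 8, 2}.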
+
+ Shape output_shape = input_shape;
+ (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth;
+
+ CHECK_EQ(op->outputs.size(), op->num_split);
+ for (const auto& output : op->outputs) {
+ model->arrays[output]->copy_shape(output_shape);
+ }
+}
+
+void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
+ const string& input_name = op->inputs[0];
+ const auto& input_array = *model->arrays[input_name];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ if (input_shape.dimensions_count() < 4) {
+ LOG(FATAL) << "missing dimensions for " << input_name;
+ }
+ const string& output_name = op->outputs[0];
+ const int output_depth = input_shape.dims(3);
+ ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
+ op->stride_width, op->stride_height, op->padding.type,
+ model->GetArray(output_name).mutable_shape(),
+ &op->padding.GetOrCreateFixedPadding());
+}
+
+void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ if (!model->arrays[op->inputs[0]]->has_shape() ||
+ !model->arrays[op->inputs[1]]->has_shape()) {
+ return;
+ }
+ const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();
+
+ const string& output_size_name = op->inputs[1];
+ const auto& output_size_array = *model->arrays[output_size_name];
+ CHECK(output_size_array.data_type == ArrayDataType::kInt32);
+ CHECK(output_size_array.has_shape());
+ const auto& output_size_shape = output_size_array.shape();
+ CHECK_EQ(output_size_shape.dimensions_count(), 1);
+ CHECK_EQ(output_size_shape.dims(0), 2);
+ std::vector<int32> output_shape =
+ output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
+ input_data_shape.dims(3)}));
+}
+
+void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
+ // I/O arrays should be allocated on creation of op.
+ QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
+ QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);
+
+ const auto& input_array =
+ *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_GE(input_shape.dimensions_count(), 2);
+
+ const auto& prev_activ_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_activ_array.has_shape()) {
+ return;
+ }
+ const auto& prev_activ_shape = prev_activ_array.shape();
+ CHECK_GE(prev_activ_shape.dimensions_count(), 2);
+
+ const auto& weights_array =
+ *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
+ // Yield until weights dims have been resolved.
+ if (!weights_array.has_shape()) {
+ return;
+ }
+ const auto& weights_shape = weights_array.shape();
+ CHECK_EQ(weights_shape.dimensions_count(), 2);
+
+ const auto& bias_array =
+ *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
+ // Yield until bias dims have been resolved.
+ if (!bias_array.has_shape()) {
+ return;
+ }
+ const auto& bias_shape = bias_array.shape();
+ CHECK_GE(bias_shape.dimensions_count(), 1);
+
+ const auto& prev_state_array =
+ *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
+ // Yield until all input dims have been resolved.
+ if (!prev_state_array.has_shape()) {
+ return;
+ }
+ const auto& prev_state_shape = prev_state_array.shape();
+ CHECK_GE(prev_state_shape.dimensions_count(), 2);
+
+ const int fc_output_depth = weights_shape.dims(0);
+ CHECK_EQ(fc_output_depth, bias_shape.dims(0));
+ CHECK_EQ(fc_output_depth % 4, 0);
+ const int depth = fc_output_depth / 4;
+
+ const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
+ const int fc_input_depth = weights_shape.dims(1);
+ CHECK_EQ(input_depth + depth, fc_input_depth);
+ Shape output_shape(input_shape);
+ (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
+
+ // Set output dimensions
+ model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
+ .copy_shape(output_shape);
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
+ .copy_shape(output_shape);
+
+ Shape concat_temp_shape(input_shape);
+ (*concat_temp_shape
+ .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
+ fc_input_depth;
+ model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
+ .copy_shape(concat_temp_shape);
+
+ Shape activ_temp_shape(input_shape);
+ (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
+ fc_output_depth;
+ model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
+ .copy_shape(activ_temp_shape);
+}
+
+void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& paddings_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& paddings_array_shape = paddings_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
+
+ // We only support two dimensions.
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(paddings_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!paddings_array.buffer) {
+ return;
+ }
+ QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
+ const auto& paddings_data =
+ paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
+ int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
+ int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
+ QCHECK_EQ(height_with_paddings % block_height, 0);
+ QCHECK_EQ(width_with_paddings % block_width, 0);
+ int output_height = height_with_paddings / block_height;
+ int output_width = width_with_paddings / block_width;
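+  // E.g. a {1, 8, 8, 3} input with block shape {2, 2} and zero paddings
+  // yields a {4, 4, 4, 3} output: batch 1 * 2 * 2 = 4, spatial 8 / 2 = 4.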
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) * block_height * block_width, output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) {
+ return;
+ }
+ const auto& input_shape = input_array.shape();
+ CHECK_EQ(input_shape.dimensions_count(), 4);
+ const auto input_height = input_shape.dims(1);
+ const auto input_width = input_shape.dims(2);
+
+ const auto& block_shape_array = *model->arrays[op->inputs[1]];
+ const auto& crops_array = *model->arrays[op->inputs[2]];
+ const auto& block_shape_array_shape = block_shape_array.shape();
+ const auto& crops_array_shape = crops_array.shape();
+ QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
+ QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
+
+ // We only support two dimensions.
+ QCHECK_EQ(block_shape_array_shape.dims(0), 2);
+ if (!block_shape_array.buffer) {
+ return;
+ }
+ QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
+ const auto& block_shape_data =
+ block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ auto block_height = block_shape_data[0];
+ auto block_width = block_shape_data[1];
+
+ QCHECK_EQ(crops_array_shape.dims(0), 2); // Number of block dimensions
+ QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
+ if (!crops_array.buffer) {
+ return;
+ }
+ QCHECK(crops_array.data_type == ArrayDataType::kInt32);
+ const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
+  // We don't support non-zero crops yet.
+ QCHECK_EQ(crops_data[0], 0);
+ QCHECK_EQ(crops_data[1], 0);
+ QCHECK_EQ(crops_data[2], 0);
+ QCHECK_EQ(crops_data[3], 0);
+
+ QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);
+
+ int output_height = input_height * block_height;
+ int output_width = input_width * block_width;
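+  // E.g. a {4, 4, 4, 3} input with block shape {2, 2} and zero crops yields
+  // a {1, 8, 8, 3} output: batch 4 / (2 * 2) = 1, spatial 4 * 2 = 8.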
+
+ model->arrays[op->outputs[0]]->copy_shape(
+ Shape({input_shape.dims(0) / (block_height * block_width), output_height,
+ output_width, input_shape.dims(3)}));
+}
+
+void ProcessGatherOperator(Model* model, GatherOperator* op) {
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ auto& output_array = *model->arrays[op->outputs[0]];
+
+ // Bail if we already know the output shape.
+ if (output_array.has_shape()) {
+ return;
+ }
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape() || !indices_array.has_shape()) {
+ return;
+ }
+
+ const auto& input_shape = input_array.shape();
+ const auto& indices_shape = indices_array.shape();
+ QCHECK_GE(input_shape.dimensions_count(), 1);
+ op->input_rank = input_shape.dimensions_count();
+
+ // We only support 1-D indices.
+ QCHECK_EQ(indices_shape.dimensions_count(), 1);
+
+ // Copy the input dimensions to the output except for dimension 0,
+ // where the dimension of indices_shape is used.
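+  // E.g. gathering from a {10, 4} input with indices of shape {7} yields
+  // output dims {7, 4}.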
+ auto output_dims = output_array.mutable_shape()->mutable_dims();
+ output_dims->push_back(indices_shape.dims(0));
+ for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
+ output_dims->push_back(input_shape.dims(dim));
+ }
+}
+
+void ProcessPadOperator(Model* model, PadOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->left_padding.empty()) return;
+ CHECK_EQ(op->left_padding.size(), op->right_padding.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->left_padding.size(), dims.size());
+
+ for (int i = 0; i < op->left_padding.size(); ++i) {
+ dims[i] += op->left_padding[i] + op->right_padding[i];
+ }
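+  // E.g. padding a {1, 4, 4, 1} input with left_padding {0, 1, 1, 0} and
+  // right_padding {0, 1, 1, 0} yields output dims {1, 6, 6, 1}.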
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessMeanOperator(Model* model, MeanOperator* op) {
+ CHECK_EQ(op->inputs.size(), 2);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+ const std::vector<int>& indices = op->reduction_indices;
+ if (indices.empty()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (std::find(indices.begin(), indices.end(), i) == indices.end()) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
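+  // E.g. taking the mean of a {1, 8, 8, 3} input over reduction_indices
+  // {1, 2} yields output dims {1, 3}.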
+ CHECK(!output_dims.empty());
+ CHECK_EQ(output_dims.size(), 2);
+
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
+ CHECK_EQ(op->inputs.size(), 4);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ if (op->start_indices.empty()) return;
+ CHECK_EQ(op->start_indices.size(), op->stop_indices.size());
+ CHECK_EQ(op->start_indices.size(), op->strides.size());
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ Shape output_shape = input_array.shape();
+ std::vector<int>& dims = *output_shape.mutable_dims();
+ CHECK_EQ(op->start_indices.size(), dims.size());
+
+ for (int i = 0; i < op->start_indices.size(); ++i) {
+ const int mask = 1 << i;
+ const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
+ const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i]
+ : op->stop_indices[i];
+ dims[i] = (stop - start) / op->strides[i];
+ }
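+  // E.g. with input dims {10}, start_indices {2}, stop_indices {8},
+  // strides {2} and no masks, dims[0] = (8 - 2) / 2 = 3.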
+
+ output_array.copy_shape(output_shape);
+}
+
+void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
+ CHECK_EQ(op->inputs.size(), 1);
+ CHECK_EQ(op->outputs.size(), 1);
+
+ const auto& input_array = *model->arrays[op->inputs[0]];
+
+ // Yield until input dims have been resolved.
+ if (!input_array.has_shape()) return;
+
+ auto& output_array = *model->arrays[op->outputs[0]];
+ if (output_array.has_shape()) return;
+
+ const std::vector<int>& input_dims = input_array.shape().dims();
+ std::vector<int> output_dims;
+
+ for (int i = 0; i < input_dims.size(); ++i) {
+ if (input_dims[i] != 1 ||
+ (!op->squeeze_dims.empty() &&
+ std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
+ op->squeeze_dims.end())) {
+ output_dims.push_back(input_dims[i]);
+ }
+ }
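+  // E.g. squeezing a {1, 8, 1, 3} input with empty squeeze_dims removes
+  // both size-1 dimensions, yielding output dims {8, 3}.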
+ *output_array.mutable_shape()->mutable_dims() = output_dims;
+}
+
+void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
+ CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
+ const auto& input_array = *model->arrays[op->inputs[0]];
+ if (!input_array.has_shape()) return;
+
+ auto& weights_feature_array = *model->arrays[op->inputs[1]];
+ if (!weights_feature_array.has_shape()) return;
+
+ const auto& weights_time_array = *model->arrays[op->inputs[2]];
+ if (!weights_time_array.has_shape()) return;
+
+ const bool has_bias = (op->inputs.size() == 4);
+ if (has_bias) {
+ const auto& bias_array = *model->arrays[op->inputs[3]];
+ if (!bias_array.has_shape()) return;
+ }
+
+ const int batch_size = input_array.shape().dims()[0];
+ const int num_units = weights_feature_array.shape().dims()[0];
+ const int memory_size = weights_time_array.shape().dims()[1];
+
+ auto& state_array = model->GetArray(op->outputs[0]);
+ state_array.mutable_shape()->ReplaceDims(
+ {batch_size, memory_size * num_units});
+
+ auto& output_array = model->GetArray(op->outputs[1]);
+ output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
+}
+} // namespace
+
+bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ std::unordered_map<string, std::vector<int>> old_output_dims;
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape()) {
+ old_output_dims[output] = model->arrays[output]->shape().dims();
+ }
+ }
+
+ switch (op->type) {
+ case OperatorType::kBatchNormalization:
+ case OperatorType::kL2Normalization:
+ case OperatorType::kDequantize:
+ case OperatorType::kRelu:
+ case OperatorType::kRelu1:
+ case OperatorType::kRelu6:
+ case OperatorType::kSoftmax:
+ case OperatorType::kLogistic:
+ case OperatorType::kTanh:
+ case OperatorType::kLocalResponseNormalization:
+ case OperatorType::kTensorFlowIdentity:
+ case OperatorType::kFakeQuant:
+ case OperatorType::kTensorFlowRsqrt:
+ case OperatorType::kTensorFlowSqrt:
+ case OperatorType::kTensorFlowSquare:
+ case OperatorType::kTensorFlowAll:
+ case OperatorType::kTensorFlowAssert:
+ case OperatorType::kCast:
+ case OperatorType::kFloor:
+ ProcessSimpleOperator(model, op);
+ break;
+ case OperatorType::kGather:
+ ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
+ break;
+
+ case OperatorType::kAdd:
+ case OperatorType::kSub:
+ case OperatorType::kMul:
+ case OperatorType::kDiv:
+ case OperatorType::kTensorFlowLess:
+ case OperatorType::kTensorFlowLessEqual:
+ case OperatorType::kTensorFlowGreater:
+ case OperatorType::kTensorFlowMaximum:
+ case OperatorType::kTensorFlowMinimum:
+ case OperatorType::kTensorFlowGreaterEqual:
+ ProcessSimpleBinaryOperator(model, op);
+ break;
+ case OperatorType::kConv:
+ ProcessConvOperator(model, static_cast<ConvOperator*>(op));
+ break;
+ case OperatorType::kDepthwiseConv:
+ ProcessDepthwiseConvOperator(model,
+ static_cast<DepthwiseConvOperator*>(op));
+ break;
+ case OperatorType::kDepthToSpace:
+ ProcessDepthToSpaceOperator(model,
+ static_cast<DepthToSpaceOperator*>(op));
+ break;
+ case OperatorType::kSpaceToDepth:
+ ProcessSpaceToDepthOperator(model,
+ static_cast<SpaceToDepthOperator*>(op));
+ break;
+ case OperatorType::kFullyConnected:
+ ProcessFullyConnectedOperator(model,
+ static_cast<FullyConnectedOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowReshape:
+ ProcessTensorFlowReshapeOperator(
+ model, static_cast<TensorFlowReshapeOperator*>(op));
+ break;
+ case OperatorType::kAveragePool:
+ ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
+ break;
+ case OperatorType::kMaxPool:
+ ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
+ break;
+ case OperatorType::kL2Pool:
+ ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMin:
+ case OperatorType::kTensorFlowMax:
+ case OperatorType::kTensorFlowSum:
+ ProcessTensorFlowReductionOperator(model, op);
+ break;
+
+ case OperatorType::kSlice:
+ ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
+ break;
+
+ case OperatorType::kTensorFlowTile:
+ // We don't currently implement the propagation of fixed sizes through
+ // a TensorFlow Tile.
+ //
+ // Fortunately, we don't need to: so far, we have only dealt with Tile
+ // or Slice ops in subgraphs that are identified as L2Normalization.
+ // See IdentifyL2Normalization.
+ break;
+ case OperatorType::kTensorFlowSwitch:
+ // We can't know the sizes of the outputs until we have resolved the
+ // predicate, and once we have resolved the predicate, the whole
+ // Switch node will get resolved away.
+ // See ResolveTensorFlowSwitch.
+ break;
+ case OperatorType::kTensorFlowMerge:
+ // No need to bother resolving TensorFlow Merge ops: other graph
+ // transformations will remove them anyway.
+ // See ResolveTensorFlowMerge.
+ break;
+ case OperatorType::kTensorFlowSplit:
+ ProcessTensorFlowSplitOperator(model,
+ static_cast<TensorFlowSplitOperator*>(op));
+ break;
+ case OperatorType::kSqueeze:
+ ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowConcat:
+ case OperatorType::kTensorFlowConcatV2:
+ // Unimplemented, hopefully another graph transformation will
+ // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
+      // will resolve this node to a DepthConcatenation, or else we have
+      // a more general non-depth concatenation that will hopefully be
+      // dropped; failing that, we currently abort.
+ break;
+ case OperatorType::kTensorFlowShape:
+ // Unimplemented, hopefully another graph transformation will drop it or
+ // rewrite it.
+ break;
+ case OperatorType::kReorderAxes:
+ ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
+ break;
+ case OperatorType::kConcatenation:
+ ProcessConcatenationOperator(model,
+ static_cast<ConcatenationOperator*>(op));
+ break;
+ case OperatorType::kResizeBilinear:
+ ProcessResizeBilinearOperator(model,
+ static_cast<ResizeBilinearOperator*>(op));
+ break;
+ case OperatorType::kLstmCell:
+ ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowMatMul:
+ // MatMul operators are converted to FullyConnected, after which their
+ // shapes are propagated.
+ break;
+ case OperatorType::kSpaceToBatchND:
+ ProcessSpaceToBatchNDOperator(model,
+ static_cast<SpaceToBatchNDOperator*>(op));
+ break;
+ case OperatorType::kBatchToSpaceND:
+ ProcessBatchToSpaceNDOperator(model,
+ static_cast<BatchToSpaceNDOperator*>(op));
+ break;
+ case OperatorType::kPad:
+ ProcessPadOperator(model, static_cast<PadOperator*>(op));
+ break;
+ case OperatorType::kMean:
+ ProcessMeanOperator(model, static_cast<MeanOperator*>(op));
+ break;
+ case OperatorType::kStridedSlice:
+ ProcessStridedSliceOperator(model,
+ static_cast<StridedSliceOperator*>(op));
+ break;
+ case OperatorType::kTensorFlowUnsupported:
+ break;
+ case OperatorType::kSvdf:
+ ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
+ break;
+ default:
+ // Unimplemented, another graph transformation should drop it.
+ LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
+ }
+
+ // Return true if any output dim changed, false if none changed.
+  // Assumption: no transformation clears an output shape; they only add shapes.
+ for (const auto& output : op->outputs) {
+ if (model->arrays[output]->has_shape() &&
+ (old_output_dims[output] != model->arrays[output]->shape().dims())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
new file mode 100644
index 0000000000..5551755ea7
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -0,0 +1,467 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool SupportsQuantization(const Operator& op) {
+ auto type = op.type;
+ if (type == OperatorType::kTensorFlowUnsupported) {
+ auto* unsupported = static_cast<const TensorFlowUnsupportedOperator*>(&op);
+ return unsupported->quantized;
+ }
+ return type == OperatorType::kConv || type == OperatorType::kDepthwiseConv ||
+ type == OperatorType::kFullyConnected ||
+ type == OperatorType::kConcatenation ||
+ type == OperatorType::kL2Normalization || type == OperatorType::kAdd ||
+ type == OperatorType::kAveragePool || type == OperatorType::kMaxPool ||
+ type == OperatorType::kLogistic || type == OperatorType::kSoftmax ||
+ type == OperatorType::kTensorFlowReshape ||
+ type == OperatorType::kMul || type == OperatorType::kSpaceToDepth ||
+ type == OperatorType::kDepthToSpace;
+}
+
+template <ArrayDataType A>
+std::unique_ptr<GenericBuffer> QuantizeBuffer(
+ const GenericBuffer& buffer,
+ const QuantizationParams& quantization_params) {
+ const auto inverse_scale = 1. / quantization_params.scale;
+ CHECK(buffer.type == ArrayDataType::kFloat);
+ const auto& float_buffer =
+ static_cast<const Buffer<ArrayDataType::kFloat>&>(buffer);
+ auto* quantized_buffer = new Buffer<A>;
+ quantized_buffer->data.resize(float_buffer.data.size());
+ const auto qmin = static_cast<int32>(std::numeric_limits<DataType<A>>::min());
+ const auto qmax = static_cast<int32>(std::numeric_limits<DataType<A>>::max());
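+  // E.g. with scale = 0.5 and zero_point = 128, a float value of -3.0 maps
+  // to round(128 + (-3.0) / 0.5) = 122 before clamping.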
+ for (std::size_t i = 0; i < float_buffer.data.size(); i++) {
+ const float src_val = float_buffer.data[i];
+ double scaled_val; // Astonishingly, using 'float' degrades accuracy just
+ // enough to make a few tests fail!
+ if (quantization_params.scale == 0) {
+ CHECK_EQ(src_val, 0) << "The quantization scale for this array is 0, "
+ << "so all its values should be 0.";
+ scaled_val = quantization_params.zero_point;
+ } else {
+ scaled_val = quantization_params.zero_point + inverse_scale * src_val;
+ }
+ const auto rounded_val = static_cast<int32>(std::round(scaled_val));
+ const auto clamped_val = std::min(qmax, std::max(qmin, rounded_val));
+ quantized_buffer->data[i] = static_cast<DataType<A>>(clamped_val);
+ }
+ return std::unique_ptr<GenericBuffer>(quantized_buffer);
+}
+
+template <ArrayDataType A>
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name,
+ const QuantizationParams& quantization_params) {
+ auto& array = model->GetArray(name);
+ CHECK(array.data_type == ArrayDataType::kFloat);
+ CHECK(!array.quantization_params);
+ array.GetOrCreateQuantizationParams() = quantization_params;
+ if (array.buffer) {
+ array.buffer = QuantizeBuffer<A>(*array.buffer, quantization_params);
+ }
+ array.data_type = A;
+ transformation->AddMessageF("Quantized array %s", name);
+}
+
+void QuantizeArray(GraphTransformation* transformation, Model* model,
+ const string& name, ArrayDataType quantized_data_type,
+ const QuantizationParams& quantization_params) {
+ switch (quantized_data_type) {
+ case ArrayDataType::kUint8:
+ return QuantizeArray<ArrayDataType::kUint8>(transformation, model, name,
+ quantization_params);
+ case ArrayDataType::kInt32:
+ return QuantizeArray<ArrayDataType::kInt32>(transformation, model, name,
+ quantization_params);
+ default:
+ LOG(FATAL) << "Unhandled case.";
+ }
+}
+
+const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
+ auto& array = model->GetArray(array_name);
+ // Normally we should have a MinMax recorded on this Array,
+ // so we just use it.
+ if (array.minmax != nullptr) {
+ return *array.minmax;
+ }
+
+ // We don't have a MinMax. That's bad news: we need
+ // the graph to provide MinMax info for all arrays in order
+ // for inference to reproduce faithfully the same quantization
+ // error as the training process had.
+ //
+ // But we still want to support a fallback for constant arrays,
+ // just using the plain min and max computed from array elements.
+ // We should hopefully never rely on that in production, as that
+ // will not give very good accuracy as that typically won't be
+ // exactly what the training process used. But it will be useful
+ // to allow easily trying out quantization even if the graph
+ // lacks some minmax information.
+ if (array.buffer != nullptr) {
+ LOG(WARNING)
+ << "Constant array " << array_name
+ << " lacks MinMax information. To make up for that, we will now compute"
+ << " the MinMax from actual array elements. That will result in"
+ << " quantization parameters that probably do not match whichever "
+ "arithmetic"
+ << " was used during training, and thus will probably be a cause of "
+ "poor"
+ << " inference accuracy.";
+ CHECK(array.buffer->type == ArrayDataType::kFloat);
+ const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
+ // We always want [min, max] to contain 0.
+ float min = 0.f;
+ float max = 0.f;
+ for (auto val : data) {
+ min = std::min(min, val);
+ max = std::max(max, val);
+ }
+ auto& minmax = array.GetOrCreateMinMax();
+ minmax.min = min;
+ minmax.max = max;
+ return minmax;
+ }
+
+ LOG(FATAL) << "Array " << array_name
+ << " does not have MinMax information, "
+ "and is not a constant array. Cannot "
+ "proceed with quantization.";
+}
+
+bool ChooseQuantizationForOperatorInput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t input_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& input = op.inputs[input_index];
+ auto& array = model->GetArray(input);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (op.type == OperatorType::kConv ||
+ op.type == OperatorType::kDepthwiseConv ||
+ op.type == OperatorType::kFullyConnected) {
+ if (input_index == 2) {
+ // Quantization of bias vector.
+      // We need both of the mandatory inputs (input activations and
+      // weights) to have already been quantized.
+ const auto& input_activations = model->GetArray(op.inputs[0]);
+ const auto& input_weights = model->GetArray(op.inputs[1]);
+ if (!input_activations.quantization_params ||
+ !input_weights.quantization_params) {
+ return false;
+ }
+ const auto input_activations_scale =
+ input_activations.quantization_params->scale;
+ const auto input_weights_scale = input_weights.quantization_params->scale;
+ quantization_params->scale =
+ input_activations_scale * input_weights_scale;
+ quantization_params->zero_point = 0;
+ *quantized_data_type = ArrayDataType::kInt32;
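+      // The bias gets added to int32 accumulators holding products of uint8
+      // input activations and weights, so its scale must be the product of
+      // their scales and its zero_point must be 0.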
+ transformation->AddMessageF(
+ "Input array %s is a bias vector. Choosing quantization params "
+ "accordingly.",
+ input);
+ return true;
+ }
+ }
+
+ const MinMax& minmax = GetOrComputeMinMax(model, input);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ transformation->AddMessageF(
+ "For input array %s with min=%g"
+ ", max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ input, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+ *quantized_data_type = ArrayDataType::kUint8;
+ return true;
+}
+
+bool IsExactlyRepresentable(double real_value, ArrayDataType data_type,
+ const QuantizationParams& quantization_params) {
+ const double scaled_value =
+ quantization_params.zero_point + real_value / quantization_params.scale;
+ const double fractional_scaled_value =
+ scaled_value - std::round(scaled_value);
+ if (std::abs(fractional_scaled_value) > 1e-12) {
+ return false;
+ }
+ const double rounded_scaled_value = std::round(scaled_value);
+ if (data_type == ArrayDataType::kUint8) {
+ if (rounded_scaled_value < 0 || rounded_scaled_value > 255) {
+ return false;
+ }
+ }
+ return true;
+}
+
+bool ChooseHardcodedQuantizationForOperatorOutput(
+ const Operator& op, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ if (op.type == OperatorType::kL2Normalization) {
+ // L2Normalization has range: [-1, 1].
+ // 0 should be exactly representable, as values will typically be centered
+ // around 0, with many values near 0.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 128;
+ quantization_params->scale = 1. / 128.;
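+    // This gives the representable range [-1, 127/128], with 0 mapping
+    // exactly to the quantized value 128.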
+ CHECK(
+ IsExactlyRepresentable(0., *quantized_data_type, *quantization_params));
+ return true;
+ }
+ if ((op.type == OperatorType::kLogistic) ||
+ (op.type == OperatorType::kSoftmax)) {
+ // Logistic and Softmax have range: [0, 1].
+ //
+ // For Logistic, 0.5 should be exactly representable, as implementations
+ // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and
+    // the gluing of the two halves of the graph will only be seamless if we
+ // are accurately representing logistic(0) == 0.5.
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = 0;
+ quantization_params->scale = 1. / 256.;
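+    // This gives the representable range [0, 255/256], with 0.5 mapping
+    // exactly to the quantized value 128.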
+ CHECK(IsExactlyRepresentable(0.5, *quantized_data_type,
+ *quantization_params));
+ return true;
+ }
+ return false;
+}
+
+bool ChooseQuantizationForOperatorOutput(
+ GraphTransformation* transformation, Model* model, const Operator& op,
+ std::size_t output_index, ArrayDataType* quantized_data_type,
+ QuantizationParams* quantization_params) {
+ const auto& output = op.outputs[output_index];
+ auto& array = model->GetArray(output);
+ if (array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type,
+ quantization_params)) {
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Choosing fixed "
+ "quantization params accordingly.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ if ((op.type == OperatorType::kDepthToSpace) ||
+ (op.type == OperatorType::kSpaceToDepth)) {
+ // DepthToSpace and SpaceToDepth should preserve the quantization parameters
+ // of the input array, as these are simple reshape operations.
+ const auto& input_quantization_params =
+ model->GetArray(op.inputs[0]).GetQuantizationParams();
+ *quantized_data_type = ArrayDataType::kUint8;
+ quantization_params->zero_point = input_quantization_params.zero_point;
+ quantization_params->scale = input_quantization_params.scale;
+
+ transformation->AddMessageF(
+ "Output array %s is produced by a %s operator. Copying quantization "
+ "params from input array.",
+ output, OperatorTypeName(op.type));
+ return true;
+ }
+ const MinMax& minmax = GetOrComputeMinMax(model, output);
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(model->flags, minmax,
+ quantization_params);
+ *quantized_data_type = ArrayDataType::kUint8;
+ transformation->AddMessageF(
+ "For output array %s with min=%g, max=%g"
+ ", chose to quantize as uint8 with zero_point=%d"
+ ", scale=%g",
+ output, minmax.min, minmax.max, quantization_params->zero_point,
+ quantization_params->scale);
+
+ return true;
+}
+} // namespace
+
+bool Quantize::Run(Model* model, std::size_t op_index) {
+ // Our general "quantization" graph transformation consists in replacing
+ // QuantizedInputArrays[] ->
+ // DequantizeOperators[] ->
+ // FloatInputArrays[] ->
+ // Operator ->
+ // FloatOutputArray
+ // by
+ // QuantizedInputArrays[] ->
+ // Operator ->
+ // QuantizedOutputArray ->
+ // DequantizeOperator ->
+ // FloatOutputArray
+ //
+ // In other words, this is pushing Dequantize operators to the right of
+ // other operators.
+ //
+
+ auto& op = *model->operators[op_index];
+ if (op.type == OperatorType::kDequantize ||
+ op.type == OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ // Our assumption here is that the input arrays are already quantized -
+ // that is typically the case in models operating on an input bitmap
+ // image, and MakeInitialDequantizeOp should have already resolved
+ // the handling of the input image as an initial Dequantize op.
+ //
+ // Thus we are building around the assumption that the graph always starts
+ // with a quantized input array, and only after some Dequantize op do we have
+ // float arrays. The problem of quantizing the graph thus becomes a problem of
+ // pushing Dequantize ops to the right of other ops.
+ //
+ // Let us just guard this assumption by the following assertion:
+ for (const auto& input : op.inputs) {
+ if (IsInputArray(*model, input)) {
+ const auto& input_array = model->GetArray(input);
+ CHECK(input_array.quantization_params);
+ }
+ }
+ if (!SupportsQuantization(op)) {
+ LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
+ << HelpfulOperatorTypeName(op)
+ << " for which the quantized form is not yet implemented. "
+ "Sorry, and patches welcome (that's a relatively fun patch "
+ "to write, mostly providing the actual quantized arithmetic "
+ "code for this op).";
+ }
+
+ for (const auto& input : op.inputs) {
+ const auto& array = model->GetArray(input);
+ if (array.data_type == ArrayDataType::kFloat) {
+ if (!array.minmax && !array.buffer) {
+ LOG(ERROR) << "Can't quantize input array " << input
+ << " because it lacks min/max info";
+ return false;
+ }
+ const auto* other_op = GetOpWithOutput(*model, input);
+ if (other_op && other_op->type != OperatorType::kDequantize) {
+ AddMessageF(
+ "Not quantizing %s for now, because its input array %s is not "
+ "produced by a Dequantize op, "
+ "which means that we should yield and let other ops "
+ "get quantized first",
+ LogName(op), input);
+ return false;
+ }
+ }
+ }
+
+ bool changed = false;
+
+ // Quantize inputs, remove any Dequantize op on the inputs side
+ for (std::size_t input_index = 0; input_index < op.inputs.size();
+ input_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorInput(this, model, op, input_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& input = op.inputs[input_index];
+ if (IsConstantParameterArray(*model, input)) {
+ QuantizeArray(this, model, input, quantized_data_type,
+ quantization_params);
+ } else {
+ auto dequantize_it = FindOpWithOutput(*model, input);
+ CHECK(dequantize_it != model->operators.end());
+ auto* dequantize_op = dequantize_it->get();
+ CHECK(dequantize_op->type == OperatorType::kDequantize);
+ op.inputs[input_index] = dequantize_op->inputs[0];
+ // Check if the output of that Dequantize op was not used by any
+ // other operator. We will then erase that Dequantize op.
+ if (!CountOpsWithInput(*model, dequantize_op->outputs[0])) {
+ // If any of the model's output_arrays was pointing to the
+ // Dequantize op's output, let it point to the Dequantize op's
+ // input instead.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == dequantize_op->outputs[0]) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+ model->arrays.erase(dequantize_op->outputs[0]);
+ model->operators.erase(dequantize_it);
+ }
+ }
+ }
+ }
+
+ // Quantize outputs, add Dequantize ops as needed on the outputs side
+ for (std::size_t output_index = 0; output_index < op.outputs.size();
+ output_index++) {
+ ArrayDataType quantized_data_type;
+ QuantizationParams quantization_params;
+ if (ChooseQuantizationForOperatorOutput(this, model, op, output_index,
+ &quantized_data_type,
+ &quantization_params)) {
+ changed = true;
+ const auto& output = op.outputs[output_index];
+ QuantizeArray(this, model, output, quantized_data_type,
+ quantization_params);
+ const auto& dequantized_output =
+ AvailableArrayName(*model, output + "_dequantized");
+ const auto& output_array = model->GetArray(output);
+ const auto& output_minmax = output_array.GetMinMax();
+ auto& dequantized_output_array =
+ model->GetOrCreateArray(dequantized_output);
+ dequantized_output_array.data_type = ArrayDataType::kFloat;
+ auto& dequantized_output_minmax =
+ dequantized_output_array.GetOrCreateMinMax();
+ dequantized_output_minmax.min = output_minmax.min;
+ dequantized_output_minmax.max = output_minmax.max;
+ for (const auto& other_op : model->operators) {
+ for (auto& other_op_input : other_op->inputs) {
+ if (other_op_input == output) {
+ other_op_input = dequantized_output;
+ }
+ }
+ }
+ auto* dequantize_op = new DequantizeOperator;
+ dequantize_op->inputs = {output};
+ dequantize_op->outputs = {dequantized_output};
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (model->flags.output_arrays(i) == output) {
+ model->flags.set_output_arrays(i, dequantized_output);
+ }
+ }
+ const auto op_it = FindOp(*model, &op);
+ model->operators.emplace(op_it + 1, dequantize_op);
+ }
+ }
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
new file mode 100644
index 0000000000..371ced388a
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_fake_quant_min_max.cc
@@ -0,0 +1,105 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool ApplyMinMaxToArray(GraphTransformation* transformation, Model* model,
+ const MinMax& minmax, const string& array_name) {
+ auto& annotated_array = model->GetArray(array_name);
+ if (annotated_array.minmax) {
+ return false;
+ }
+ annotated_array.GetOrCreateMinMax() = minmax;
+ transformation->AddMessageF(
+ "Read min/max annotation for array %s: min=%g, max=%g", array_name,
+ minmax.min, minmax.max);
+ return true;
+}
+
+} // end namespace
+
+bool ReadFakeQuantMinMax::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+ auto* fakequant_op = static_cast<FakeQuantOperator*>(fakequant_base_op);
+
+ bool changed = false;
+
+ if (!fakequant_op->minmax) {
+ CHECK_EQ(fakequant_op->inputs.size(), 3);
+ // We need to yield until the min and max parameters have been
+ // resolved to constant arrays.
+ for (int i = 1; i <= 2; i++) {
+      if (!IsConstantParameterArray(*model, fakequant_op->inputs[i])) {
+ return false;
+ }
+ }
+
+ // Obtain the final min/max values
+ const auto& min_array = model->GetArray(fakequant_op->inputs[1]);
+ const auto& max_array = model->GetArray(fakequant_op->inputs[2]);
+ CHECK_EQ(RequiredBufferSizeForShape(min_array.shape()), 1);
+ CHECK_EQ(RequiredBufferSizeForShape(max_array.shape()), 1);
+ fakequant_op->minmax.reset(new MinMax);
+ MinMax& minmax = *fakequant_op->minmax;
+ minmax.min = min_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ minmax.max = max_array.GetBuffer<ArrayDataType::kFloat>().data[0];
+ // We always want [min, max] to contain 0.
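+    // For example, an annotated range of [0.5, 6.0] is widened to [0.0, 6.0],
+    // so that the value 0 remains exactly representable after quantization.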
+ minmax.min = std::min(minmax.min, 0.);
+ minmax.max = std::max(minmax.max, 0.);
+
+ // We won't use the input arrays that provided these min and max
+ // values, anymore. Delete them unless they are used by something
+ // else.
+ for (int i = 1; i <= 2; i++) {
+ if (CountOpsWithInput(*model, fakequant_op->inputs[i]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[i]);
+ }
+ }
+ fakequant_op->inputs.resize(1);
+ changed = true;
+ }
+
+ // At this point, this FakeQuantOperator should have a MinMax
+ // attached to it, and should only have 1 input (it should not have
+ // 2nd and 3rd input arrays giving min and max anymore).
+ CHECK(fakequant_op->minmax);
+ CHECK_EQ(1, fakequant_op->inputs.size());
+
+ const MinMax& minmax = *fakequant_op->minmax;
+
+ // Record the MinMax info on the input and output arrays
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->inputs[0]);
+ changed |= ApplyMinMaxToArray(this, model, minmax, fakequant_op->outputs[0]);
+
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
new file mode 100644
index 0000000000..3992e7d1ef
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_final_dequantize_op.cc
@@ -0,0 +1,59 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveFinalDequantizeOp::Run(Model* model, std::size_t op_index) {
+ const auto dequantize_it = model->operators.begin() + op_index;
+ const auto* dequantize_op = dequantize_it->get();
+ if (dequantize_op->type != OperatorType::kDequantize) {
+ return false;
+ }
+ const auto& output = dequantize_op->outputs[0];
+ // We can remove any dequantize op whose output is not consumed by
+ // any op. This is not necessarily equivalent to the output being
+ // one of the model's output arrays, as some intermediate array
+ // in the middle of the graph might be designated as an output
+ // array.
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+
+ // If one of the model's output arrays was actually the Dequantize op's
+ // output, then we need to update it to point to the Dequantize op's input.
+ for (int i = 0; i < model->flags.output_arrays_size(); i++) {
+ if (output == model->flags.output_arrays(i)) {
+ model->flags.set_output_arrays(i, dequantize_op->inputs[0]);
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removed final %s", LogName(*dequantize_op));
+ model->arrays.erase(output);
+ model->operators.erase(dequantize_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
new file mode 100644
index 0000000000..35a0c46532
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_assert.cc
@@ -0,0 +1,60 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowAssert::Run(Model* model, std::size_t op_index) {
+ const auto assert_it = model->operators.begin() + op_index;
+ const auto* assert_op = assert_it->get();
+ if (assert_op->type != OperatorType::kTensorFlowAssert) {
+ return false;
+ }
+
+ bool changed = false;
+ // Remove any other node's dependency on this assert node
+ for (const auto& op : model->operators) {
+ auto it = op->inputs.begin();
+ while (it != op->inputs.end()) {
+ if (*it == assert_op->outputs[0]) {
+        it = op->inputs.erase(it);
+ changed = true;
+ } else {
+ ++it;
+ }
+ }
+ }
+ CHECK(!CountOpsWithInput(*model, assert_op->outputs[0]));
+
+ if (changed) {
+ AddMessageF(
+ "Prepared for the removal of %s by removing any other op's dependency "
+ "on it",
+ LogName(*assert_op));
+ }
+
+ // That's it. We can stop here, no need to duplicate the work that
+ // RemoveUnusedOp will do removing this now-unused node.
+ return changed;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
new file mode 100644
index 0000000000..404269bbfd
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_tensorflow_identity.cc
@@ -0,0 +1,38 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTensorFlowIdentity::Run(Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ const auto* passthru_op = passthru_it->get();
+ if (passthru_op->type != OperatorType::kTensorFlowIdentity) {
+ return false;
+ }
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
new file mode 100644
index 0000000000..6add443f2d
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_binary.cc
@@ -0,0 +1,113 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+template <typename Scalar>
+bool AreAllBufferElementsEqualTo(const std::vector<Scalar>& buffer_data,
+ Scalar value) {
+ for (auto x : buffer_data) {
+ if (x != value) {
+ return false;
+ }
+ }
+ return true;
+}
+} // namespace
+
+// A binary operator is called trivial when exactly one of its operands is
+// a constant whose value makes the binary operation equivalent to the
+// identity operation on its other input.
+// For example, an Add operator is trivial if
+// one of its operands is constant 0, a Mul operator is trivial
+// if one of its operands is constant 1, etc.
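+// For Sub and Div, the constant operand must additionally be the second one:
+// x - 0 == x and x / 1 == x, but 0 - x and 1 / x are not identities.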
+bool RemoveTrivialBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ // This graph transformation is only concerned with the case
+ // when one input is constant and the other is not constant.
+ const bool is_input_constant[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+ if (!is_input_constant[0] && !is_input_constant[1]) {
+ // Neither input is constant, so nothing we can resolve here.
+ return false;
+ }
+ if (is_input_constant[0] && is_input_constant[1]) {
+ // Both inputs are constants. That's a job for constants
+ // propagation, not for us to handle here.
+ return false;
+ }
+ const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
+ const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
+ CHECK(is_input_constant[index_of_constant_input]);
+ CHECK(!is_input_constant[index_of_variable_input]);
+
+ // Now check if the constant operand makes this binary
+ // operator trivial.
+ const auto& constant_input_array =
+ *model->arrays[binary_op->inputs[index_of_constant_input]];
+ // For now, we only handle floats here.
+ if (constant_input_array.data_type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& constant_input_float_data =
+ constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ bool is_trivial = false;
+  if (binary_op->type == OperatorType::kAdd) {
+    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+  } else if (binary_op->type == OperatorType::kSub) {
+    is_trivial = index_of_constant_input == 1 &&
+                 AreAllBufferElementsEqualTo(constant_input_float_data, 0.f);
+  } else if (binary_op->type == OperatorType::kMul) {
+    is_trivial = AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+  } else if (binary_op->type == OperatorType::kDiv) {
+    is_trivial = index_of_constant_input == 1 &&
+                 AreAllBufferElementsEqualTo(constant_input_float_data, 1.f);
+  }
+
+ if (!is_trivial) {
+ return false;
+ }
+
+ // Now we know that this node is trivial, so we can remove it.
+ AddMessageF("Removing trivial %s", LogName(*binary_op));
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
new file mode 100644
index 0000000000..3ceb93d8ee
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation.cc
@@ -0,0 +1,40 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ if (concat_op->inputs.size() != 1) {
+ return false;
+ }
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
new file mode 100644
index 0000000000..b603735704
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_concatenation_input.cc
@@ -0,0 +1,68 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialConcatenationInput::Run(Model* model, std::size_t op_index) {
+ // TensorFlow allows Concatenation nodes to have 0-D inputs,
+  // and they are then treated as empty, i.e. omitted from concatenation,
+ // in violation of the notion that 0-D is equivalent to 1x1x1x1.
+ // Thus we have to drop these 0-D inputs from Concatenation nodes.
+ // Sometimes, there will remain only one non-trivial input, and
+ // the other graph transformation RemoveTrivialConcatenation will then drop
+ // it.
+ const auto concat_it = model->operators.begin() + op_index;
+ auto* concat_op = concat_it->get();
+ if (concat_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ std::vector<string> trivial_inputs;
+ std::vector<string> nontrivial_inputs;
+ for (const string& input : concat_op->inputs) {
+ const auto& input_array = model->GetArray(input);
+ const bool is_trivial =
+ input_array.has_shape() && input_array.shape().dimensions_count() == 0;
+ if (is_trivial) {
+ trivial_inputs.push_back(input);
+ } else {
+ nontrivial_inputs.push_back(input);
+ }
+ }
+
+ if (trivial_inputs.empty()) {
+ return false;
+ }
+
+ // Drop trivial inputs.
+ for (const string& input : trivial_inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ concat_op->inputs = nontrivial_inputs;
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
new file mode 100644
index 0000000000..a0d1338298
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.cc
@@ -0,0 +1,107 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+// Reroute all edges involving a given discardable array to another
+// array instead. from_array is assumed to be discardable, and consequently
+// this only updates operator edges (since discardable arrays only
+// appear there, and not e.g. in model flags).
+void RerouteEdges(const string& from_array, const string& to_array,
+ Model* model) {
+ for (const auto& op : model->operators) {
+ for (auto& output : op->outputs) {
+ if (output == from_array) {
+ output = to_array;
+ }
+ }
+ for (auto& input : op->inputs) {
+ if (input == from_array) {
+ input = to_array;
+ }
+ }
+ }
+}
+
+} // end anonymous namespace
+
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index) {
+ const auto passthru_it = model->operators.begin() + op_index;
+ auto* passthru_op = passthru_it->get();
+ CHECK_EQ(passthru_op->outputs.size(), 1);
+ CHECK_GE(passthru_op->inputs.size(), 1);
+ int count_nonconstant_input_arrays = 0;
+ // We call 'main input' the unique nonconstant input array if there is one,
+ // or else the 0-th input.
+ int main_input_array_index = 0;
+ for (int i = 0; i < passthru_op->inputs.size(); i++) {
+ if (!model->GetArray(passthru_op->inputs[i]).buffer) {
+ count_nonconstant_input_arrays++;
+ main_input_array_index = i;
+ }
+ }
+ CHECK_LE(count_nonconstant_input_arrays, 1);
+
+ const string main_input_name = passthru_op->inputs[main_input_array_index];
+ const string output_name = passthru_op->outputs[0];
+ if (IsDiscardableArray(*model, output_name)) {
+ transformation->AddMessageF(
+ "Removing %s, keeping its non-constant input array",
+ LogName(*passthru_op));
+ model->arrays.erase(output_name);
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) && input != main_input_name &&
+ CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(output_name, main_input_name, model);
+ } else if (IsDiscardableArray(*model, main_input_name)) {
+ transformation->AddMessageF("Removing %s, keeping its output array",
+ LogName(*passthru_op));
+ for (const string& input : passthru_op->inputs) {
+ if (IsDiscardableArray(*model, input) &&
+ (input == main_input_name || CountOpsWithInput(*model, input) == 1)) {
+ model->arrays.erase(input);
+ }
+ }
+ RerouteEdges(main_input_name, output_name, model);
+ } else {
+ transformation->AddMessageF(
+ "Cannot remove %s, neither its nonconstant input nor its output may be "
+ "discarded",
+ LogName(*passthru_op));
+ return false;
+ }
+
+ // Remove the pass-through node.
+ model->operators.erase(passthru_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
new file mode 100644
index 0000000000..b72c85c0e5
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+
+namespace toco {
+
+// A "passthrough op" is an op that satisfies the following conditions:
+// 1. It has at most one non-constant input (it may have other constant
+// inputs).
+// 2. It has exactly one output.
+// 3. It forwards exactly its single non-constant input to its single output.
+//
+// Examples include:
+// 1. TensorFlow Identity ops. (Have one input).
+// 2. TensorFlow Reshape ops when the input and output shapes agree.
+// 3. Any binary operator, one of whose two inputs is a constant and is the
+// neutral value for that operation. For example, a binary Add operator
+// where one of its inputs is a constant array filled with zeros.
+//
+// A passthrough op is "trivial" and can be removed when it is possible to
+// discard either its single non-constant input or output array, rerouting any
+// edge involving it to the other of these two arrays.
+//
+// It is only possible to discard such an array if it is not explicitly
+// designated as a global input/output array of the graph, e.g. the model's
+// input arrays, output arrays, and any array involved in a RNN back-edge
+// specified by the model.
+//
+// This function does not check that the given operator is a passthrough op:
+// that's the responsibility of the caller.
+// Given that it is a passthrough op, this function checks whether it is trivial
+// and then discards it and returns true, or, if it's not trivial (if neither
+// the input nor the output may be discarded), returns false.
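+//
+// For example, given A -> Identity -> B where B is discardable, the Identity
+// op and the array B are removed, and every consumer of B is rerouted to A.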
+bool RemoveTrivialPassthroughOp(GraphTransformation* transformation,
+ Model* model, std::size_t op_index);
+
+} // namespace toco
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_TOCO_GRAPH_TRANSFORMATIONS_REMOVE_TRIVIAL_PASSTHROUGH_H_
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
new file mode 100644
index 0000000000..28f76c9d36
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_quantized_activation_func.cc
@@ -0,0 +1,87 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/toco_types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveTrivialQuantizedActivationFunc::Run(Model* model,
+ std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+ if (op->fused_activation_function != FusedActivationFunctionType::kRelu &&
+ op->fused_activation_function != FusedActivationFunctionType::kRelu6) {
+ return false;
+ }
+ const auto& output_array = model->GetArray(op->outputs[0]);
+ if (!output_array.quantization_params) {
+ return false;
+ }
+ if (output_array.data_type != ArrayDataType::kUint8) {
+ return false;
+ }
+ const auto& quantization_params = output_array.GetQuantizationParams();
+
+ bool has_nontrivial_min_bound = false;
+ bool has_nontrivial_max_bound = false;
+
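+  // For example, with zero_point = 0 and scale = 6/255, the representable
+  // uint8 outputs already span exactly [0, 6], so a fused Relu or Relu6
+  // clamps nothing.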
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu ||
+ op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double lowest_representable_output =
+ (0. - quantization_params.zero_point) * quantization_params.scale;
+ if (lowest_representable_output < 0.) {
+ has_nontrivial_min_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the lowest representable output value %g"
+ " less than the clamp min bound.",
+ lowest_representable_output);
+ }
+ }
+ if (op->fused_activation_function == FusedActivationFunctionType::kRelu6) {
+ double highest_representable_output =
+ (255. - quantization_params.zero_point) * quantization_params.scale;
+ if (highest_representable_output > 6.) {
+ has_nontrivial_max_bound = true;
+ AddMessageF(
+ "Quantized activation function is not trivial: "
+ "the highest representable output value %g"
+ " is greater than the clamp max bound.",
+ highest_representable_output);
+ }
+ }
+
+ if (has_nontrivial_min_bound || has_nontrivial_max_bound) {
+ return false;
+ }
+
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+ AddMessageF(
+ "Removing trivial quantized activation function on %s"
+ " because the output quantization parameters imply at least as tight"
+ " a clamp anyway.",
+ LogName(*op));
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
new file mode 100644
index 0000000000..90f9381ec1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_reshape.cc
@@ -0,0 +1,92 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool IsReshapeTrivial(const Model& model, const Operator& op,
+ RemoveTrivialReshape* transformation) {
+ CHECK(op.type == OperatorType::kTensorFlowReshape);
+
+  // One way in which a reshape can be trivial is if its
+  // output shape is equal to its input shape.
+ const auto& input_array = model.GetArray(op.inputs[0]);
+ const auto& output_array = model.GetArray(op.outputs[0]);
+ if (input_array.has_shape() && output_array.has_shape()) {
+ if (transformation->treat_expand_dims_as_trivial() &&
+ ShapesAgreeUpToExtending(input_array.shape(), output_array.shape())) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal up to "
+ "extending "
+ "by 1's, and we are told to aggressively discard such Reshape ops.",
+ LogName(op));
+ return true;
+ }
+ if (input_array.shape().dims() == output_array.shape().dims()) {
+ transformation->AddMessageF(
+ "%s is trivial because its input and output shapes are equal",
+ LogName(op));
+ return true;
+ }
+ }
+
+ // Another way in which a reshape can be trivial is if its output
+ // is only consumed by another reshape.
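+  // The second reshape alone determines the final output shape, so the first
+  // one performs no useful work and can be bypassed.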
+ if (CountOpsWithInput(model, op.outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(model, op.outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ transformation->AddMessageF(
+ "%s is trivial because its output is only consumed by another "
+ "Reshape op",
+ LogName(op));
+ return true;
+ }
+ }
+
+ return false;
+}
+
+} // namespace
+
+bool RemoveTrivialReshape::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ if (!IsReshapeTrivial(*model, *reshape_op, this)) {
+ return false;
+ }
+
+ AddMessageF("Removing trivial %s", LogName(*reshape_op));
+
+ CHECK_EQ(reshape_op->inputs.size(), 2);
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
new file mode 100644
index 0000000000..1f1f1f6948
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/remove_unused_op.cc
@@ -0,0 +1,122 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/model_flags.pb.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool RemoveUnusedOp::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ const auto* op = it->get();
+
+ // Bail if any output is used, and is not an input_array of
+ // the model. We allow specifying an arbitrary input_array,
+ // treating the part of the graph leading up to it as unused.
+ for (const auto& output : op->outputs) {
+ CHECK(model->arrays.count(output));
+ // If this output is provided as the model's input array,
+ // then we don't need this operator to produce its contents.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // If this output is provided as a RNN's state array,
+ // then we don't need this operator to produce its contents.
+ // So far this case has only been encountered with TensorFlow
+ // Fill ops used to zero-initialize RNN states, which is
+ // redundant for us as we zero-initialize RNN states anyway.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ CHECK(op->type == OperatorType::kTensorFlowUnsupported);
+ CHECK_EQ(static_cast<const TensorFlowUnsupportedOperator*>(op)
+ ->tensorflow_op,
+ "Fill");
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (output == output_array) {
+ return false;
+ }
+ }
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ }
+ if (CountOpsWithInput(*model, output)) {
+ return false;
+ }
+ }
+
+ if (op->unresolved_outputs) {
+ AddMessageF("Not discarding %s because it has unresolved outputs.",
+ LogName(*op));
+ return false;
+ }
+
+ AddMessageF("Discarding %s because none of its outputs is used.",
+ LogName(*op));
+
+ // At that point we know that none of the outputs is used, so we will
+ // definitely remove the node and all its outputs.
+
+ // Remove any input array that is not used by anything else,
+ // and that is not the output of some other operator.
+ for (const auto& input : op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+
+ // Remove the node and its now-unused output arrays.
+ for (const auto& output : op->outputs) {
+ // If the output array is the model's input array, don't remove that.
+ // That's the case when cropping a model at a given --input_array.
+ if (IsInputArray(*model, output)) {
+ continue;
+ }
+ // Likewise, if the output array is a RNN state array, don't remove that.
+ bool found_output_as_rnn_state_array = false;
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (output == rnn_state.state_array()) {
+ found_output_as_rnn_state_array = true;
+ break;
+ }
+ }
+ if (found_output_as_rnn_state_array) {
+ continue;
+ }
+ // Generic case: do delete this output array.
+ model->arrays.erase(output);
+ }
+ model->operators.erase(it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
new file mode 100644
index 0000000000..3eb7fa3896
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc
@@ -0,0 +1,135 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) {
+ auto bn_it = model->operators.begin() + op_index;
+ if (bn_it->get()->type != OperatorType::kBatchNormalization) {
+ return false;
+ }
+ const auto* bn_op =
+ static_cast<const BatchNormalizationOperator*>(bn_it->get());
+
+ const auto& mean_array = model->GetArray(bn_op->inputs[1]);
+ const auto& multiplier_array = model->GetArray(bn_op->inputs[2]);
+ const auto& offset_array = model->GetArray(bn_op->inputs[3]);
+
+ CHECK(IsConstantParameterArray(*model, bn_op->inputs[1]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[2]) &&
+ IsConstantParameterArray(*model, bn_op->inputs[3]))
+ << "Batch normalization resolution requires that mean, multiplier and "
+ "offset arrays be constant.";
+
+ // We should only have *float* BatchNormalizations... let's guard this
+ // assumption by CHECK's.
+ CHECK(mean_array.data_type == ArrayDataType::kFloat);
+ CHECK(multiplier_array.data_type == ArrayDataType::kFloat);
+ CHECK(offset_array.data_type == ArrayDataType::kFloat);
+
+ // Create the new Mul, Add operators
+ auto* mul_op = new MulOperator;
+ auto* add_op = new AddOperator;
+ const string mul_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_mul");
+ const string add_name =
+ AvailableArrayName(*model, bn_op->outputs[0] + "_add");
+ const string mul_param_name = AvailableArrayName(*model, mul_name + "_param");
+ const string add_param_name = AvailableArrayName(*model, add_name + "_param");
+ mul_op->inputs = {bn_op->inputs[0], mul_param_name};
+ mul_op->outputs = {mul_name};
+ add_op->inputs = {mul_name, add_param_name};
+ add_op->outputs = {bn_op->outputs[0]};
+ AddMessageF("Splitting %s into %s and %s", LogName(*bn_op), LogName(*mul_op),
+ LogName(*add_op));
+
+ // Create the intermediate activation array (output of mul, input of add)
+ auto& intermediate_array = model->GetOrCreateArray(mul_op->outputs[0]);
+ intermediate_array.data_type = model->GetArray(bn_op->inputs[0]).data_type;
+
+ // Insert the new operators in the graph
+ auto add_it = model->operators.emplace(bn_it, add_op);
+ auto mul_it = model->operators.emplace(add_it, mul_op);
+ // update invalidated iterators.
+ DCHECK_EQ(mul_it->get(), mul_op);
+ add_it = mul_it + 1;
+ DCHECK_EQ(add_it->get(), add_op);
+ bn_it = add_it + 1;
+ DCHECK_EQ(bn_it->get(), bn_op);
+
+ // Create the new param arrays
+ const auto& mean_shape = mean_array.shape();
+ const auto& multiplier_shape = multiplier_array.shape();
+ const auto& offset_shape = offset_array.shape();
+ CHECK(mean_shape.dims() == multiplier_shape.dims());
+ CHECK(mean_shape.dims() == offset_shape.dims());
+ const auto& param_shape = mean_shape;
+ const int buffer_size = RequiredBufferSizeForShape(param_shape);
+ auto& mul_param_array = model->GetOrCreateArray(mul_param_name);
+ auto& add_param_array = model->GetOrCreateArray(add_param_name);
+ DropMinMax(model, mul_param_name);
+ DropMinMax(model, add_param_name);
+ mul_param_array.copy_shape(param_shape);
+ add_param_array.copy_shape(param_shape);
+ mul_param_array.data_type = ArrayDataType::kFloat;
+ add_param_array.data_type = ArrayDataType::kFloat;
+ auto& mul_float_data =
+ mul_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ auto& add_float_data =
+ add_param_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ mul_float_data.resize(buffer_size);
+ add_float_data.resize(buffer_size);
+ const auto& mean_float_data =
+ mean_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& multiplier_float_data =
+ multiplier_array.GetBuffer<ArrayDataType::kFloat>().data;
+ const auto& offset_float_data =
+ offset_array.GetBuffer<ArrayDataType::kFloat>().data;
+
+ CHECK(mul_float_data.size() == buffer_size);
+ CHECK(add_float_data.size() == buffer_size);
+ CHECK(mean_float_data.size() == buffer_size);
+ CHECK(multiplier_float_data.size() == buffer_size);
+ CHECK(offset_float_data.size() == buffer_size);
+
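+  // BatchNormalization computes output = (input - mean) * multiplier + offset,
+  // which rewrites as input * multiplier + (offset - mean * multiplier);
+  // hence the Mul parameter is the multiplier itself and the Add parameter is
+  // offset - mean * multiplier.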
+ for (int i = 0; i < buffer_size; i++) {
+ mul_float_data[i] = multiplier_float_data[i];
+ add_float_data[i] =
+ offset_float_data[i] - mean_float_data[i] * multiplier_float_data[i];
+ }
+
+ // Remove the old param arrays
+ model->arrays.erase(bn_op->inputs[1]);
+ model->arrays.erase(bn_op->inputs[2]);
+ model->arrays.erase(bn_op->inputs[3]);
+
+ // Remove the old operator
+ DCHECK_EQ(bn_it->get(), bn_op);
+ model->operators.erase(bn_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
new file mode 100644
index 0000000000..53e1be7a05
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -0,0 +1,247 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+std::vector<bool> VectorGreaterThan(const std::vector<int>& a,
+ const std::vector<int>& b) {
+ DCHECK_EQ(a.size(), b.size());
+ const int size = a.size();
+ std::vector<bool> result(size);
+ for (int i = 0; i < size; i++) {
+ result[i] = a[i] > b[i];
+ }
+ return result;
+}
+
+void PairwiseVectorSelect(const std::vector<bool>& selector,
+ const std::vector<int>& input_a,
+ const std::vector<int>& input_b,
+ std::vector<int>* output_a,
+ std::vector<int>* output_b) {
+ DCHECK_EQ(input_a.size(), input_b.size());
+ DCHECK_EQ(output_a->size(), output_b->size());
+ DCHECK_EQ(input_a.size(), output_a->size());
+ DCHECK_EQ(selector.size(), input_a.size());
+ const int size = input_a.size();
+ for (int i = 0; i < size; i++) {
+ if (selector[i]) {
+ (*output_a)[i] = input_a[i];
+ (*output_b)[i] = input_b[i];
+ } else {
+ (*output_a)[i] = input_b[i];
+ (*output_b)[i] = input_a[i];
+ }
+ }
+}
+
+template <ArrayDataType InputsDataType, ArrayDataType OutputDataType>
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[0]));
+ CHECK(IsConstantParameterArray(*model, binary_op->inputs[1]));
+ CHECK(binary_op->fused_activation_function ==
+ FusedActivationFunctionType::kNone);
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ const auto& output_name = binary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ CHECK(input0_array.data_type == InputsDataType);
+ CHECK(input1_array.data_type == InputsDataType);
+ CHECK(output_array.data_type == OutputDataType);
+
+ // We have already tested above for existence of input buffers
+ // (synonymous to being a constant param).
+ CHECK(input0_array.buffer);
+ CHECK(input1_array.buffer);
+ // On the other hand, the output should not already have a buffer.
+ CHECK(!output_array.buffer);
+
+ const auto& input0_data = input0_array.GetBuffer<InputsDataType>().data;
+ const auto& input1_data = input1_array.GetBuffer<InputsDataType>().data;
+ // Create the buffer on the output array, effectively turning it into
+ // a constant parameter
+
+ const Shape& output_shape = output_array.shape();
+ auto& output_data = output_array.GetMutableBuffer<OutputDataType>().data;
+ const int output_buffer_size = RequiredBufferSizeForShape(output_shape);
+ output_data.resize(output_buffer_size);
+ const int dims_count = output_shape.dimensions_count();
+
+ // It will be convenient here to have copies of the operands shapes
+ // extended to match the number of dimensions of the output shape.
+ Shape input0_shape = input0_array.shape();
+ Shape input1_shape = input1_array.shape();
+ ExtendShape(&input0_shape, dims_count);
+ ExtendShape(&input1_shape, dims_count);
+ // Now we may still have operands of different sizes, which would indicate
+ // that we have to "broadcast" the smaller dimension. We do this using a
+ // a vector of Booleans indicating which input is the larger in each
+ // dimension.
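+  // For example, with extended shapes [2, 3] and [1, 3], input0 is larger in
+  // dimension 0, so input1's index in that dimension is taken modulo 1: its
+  // single row is broadcast across both rows of input0.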
+ CHECK_EQ(input0_shape.dimensions_count(), input1_shape.dimensions_count());
+ CHECK_EQ(input0_shape.dimensions_count(), dims_count);
+ const std::vector<bool> input0_larger =
+ VectorGreaterThan(input0_shape.dims(), input1_shape.dims());
+
+ std::vector<int> big_sizes(dims_count);
+ std::vector<int> small_sizes(dims_count);
+ PairwiseVectorSelect(input0_larger, input0_shape.dims(), input1_shape.dims(),
+ &big_sizes, &small_sizes);
+
+ // The output should already be correctly sized to match the big dimensions.
+ for (int i = 0; i < dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), big_sizes[i]);
+ }
+
+ std::vector<int> input0_indices(dims_count);
+ std::vector<int> input1_indices(dims_count);
+ std::vector<int> modulo_indices(dims_count);
+
+ for (int k = 0; k < output_buffer_size; k++) {
+ const std::vector<int> output_indices = ReverseOffset(output_shape, k);
+ for (int i = 0; i < dims_count; i++) {
+ modulo_indices[i] = output_indices[i] % small_sizes[i];
+ }
+ PairwiseVectorSelect(input0_larger, output_indices, modulo_indices,
+ &input0_indices, &input1_indices);
+ const auto val0 = input0_data[Offset(input0_shape, input0_indices)];
+ const auto val1 = input1_data[Offset(input1_shape, input1_indices)];
+
+ DataType<OutputDataType> outval;
+ if (binary_op->type == OperatorType::kAdd) {
+ outval = val0 + val1;
+ } else if (binary_op->type == OperatorType::kMul) {
+ outval = val0 * val1;
+ } else if (binary_op->type == OperatorType::kSub) {
+ outval = val0 - val1;
+ } else if (binary_op->type == OperatorType::kDiv) {
+ outval = val0 / val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowMinimum) {
+ outval = std::min(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowMaximum) {
+ outval = std::max(val0, val1);
+ } else if (binary_op->type == OperatorType::kTensorFlowLess) {
+ outval = val0 < val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowLessEqual) {
+ outval = val0 <= val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreater) {
+ outval = val0 > val1;
+ } else if (binary_op->type == OperatorType::kTensorFlowGreaterEqual) {
+ outval = val0 >= val1;
+ } else {
+ LOG(FATAL) << "should not get here";
+ }
+ output_data[Offset(output_shape, output_indices)] = outval;
+ }
+}
+
+void EvaluateBinaryOperatorOnConstantInputs(Model* model,
+ const Operator* binary_op) {
+ const auto inputs_data_type = model->arrays[binary_op->inputs[0]]->data_type;
+ const auto output_data_type = model->arrays[binary_op->outputs[0]]->data_type;
+#define TOCO_HANDLE_CASE(InputsDataType, OutputDataType) \
+ if (inputs_data_type == InputsDataType && \
+ output_data_type == OutputDataType) { \
+ EvaluateBinaryOperatorOnConstantInputs<InputsDataType, OutputDataType>( \
+ model, binary_op); \
+ return; \
+ }
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kFloat)
+ TOCO_HANDLE_CASE(ArrayDataType::kFloat, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kInt32)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt32, ArrayDataType::kBool)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kInt64)
+ TOCO_HANDLE_CASE(ArrayDataType::kInt64, ArrayDataType::kBool)
+ LOG(FATAL) << "Unimplemented: don't know how to resolve a constant "
+ << "binary operator for these data types.";
+#undef TOCO_HANDLE_CASE
+}
+} // namespace
+
+bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ const auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kTensorFlowMinimum &&
+ binary_op->type != OperatorType::kTensorFlowMaximum &&
+ binary_op->type != OperatorType::kTensorFlowLess &&
+ binary_op->type != OperatorType::kTensorFlowLessEqual &&
+ binary_op->type != OperatorType::kTensorFlowGreater &&
+ binary_op->type != OperatorType::kTensorFlowGreaterEqual) {
+ return false;
+ }
+ CHECK_EQ(binary_op->inputs.size(), 2);
+
+ const auto& input0_array = model->GetArray(binary_op->inputs[0]);
+ const auto& input1_array = model->GetArray(binary_op->inputs[1]);
+ // Check if both inputs are constant parameters.
+ if (!input0_array.buffer || !input1_array.buffer) {
+ return false;
+ }
+
+ auto& output_array = *model->arrays[binary_op->outputs[0]];
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (binary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+ AddMessageF(
+ "Not resolving constant %s because it has a fused activation function",
+ LogName(*binary_op));
+ return false;
+ }
+
+ // Check that input data types agree.
+ CHECK(input0_array.data_type == input1_array.data_type);
+
+ // Do the actual constants propagation
+ EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
+
+ // Remove the binary operator and its inputs
+ if (CountOpsWithInput(*model, binary_op->inputs[0]) == 1) {
+ model->arrays.erase(binary_op->inputs[0]);
+ }
+ if (CountOpsWithInput(*model, binary_op->inputs[1]) == 1) {
+ model->arrays.erase(binary_op->inputs[1]);
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*binary_op));
+ model->operators.erase(binary_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
new file mode 100644
index 0000000000..0983c43849
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_concatenation.cc
@@ -0,0 +1,196 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+// Copies data from multiple source arrays to a destination array based on a
+// concatenation dimension. From each array in input_arrays, it copies chunks
+// of the size given (per array) in the array_copy_size vector. It uses the
+// buffer in concatenated_array as the destination buffer.
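+// For example, concatenating shapes [2, 3] and [2, 5] along axis 1 uses
+// array_copy_size = {3, 5}: each of the 2 copy steps copies one row of 3
+// elements and then one row of 5 elements into the [2, 8] destination.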
+template <ArrayDataType A, typename T>
+void CopyTensorSegments(const std::vector<Array*>& input_arrays,
+ const std::vector<int>& array_copy_size,
+ const int num_elements_concatenated_array,
+ Array* concatenated_array) {
+ for (Array* input_array : input_arrays) {
+ if (!input_array->buffer) {
+ return;
+ }
+ }
+
+ auto& concatenated_array_buffer =
+ concatenated_array->GetMutableBuffer<A>().data;
+ concatenated_array_buffer.resize(num_elements_concatenated_array);
+
+ // It does not matter which array to use to find the value for the total
+ // number of copy steps.
+ CHECK(!input_arrays.empty());
+ CHECK_NE(array_copy_size[0], 0);
+ const int total_copy_steps =
+ input_arrays[0]->GetBuffer<A>().data.size() / array_copy_size[0];
+
+ // Initialize the source pointers to point to beginning of the array buffers.
+ std::vector<const T*> src_ptr;
+ src_ptr.reserve(input_arrays.size());
+ for (Array* input_array : input_arrays) {
+ src_ptr.push_back(input_array->GetBuffer<A>().data.data());
+ }
+
+ // Copy the data from input_arrays to concatenated_array_buffer.
+ T* dest_ptr = concatenated_array_buffer.data();
+ for (int s = 0; s < total_copy_steps; s++) {
+ for (int i = 0; i < input_arrays.size(); i++) {
+ std::copy(src_ptr[i], src_ptr[i] + array_copy_size[i], dest_ptr);
+ src_ptr[i] += array_copy_size[i];
+ dest_ptr += array_copy_size[i];
+ }
+ }
+}
+
+// Receives a series of input arrays of type Array and an integer giving the
+// axis on which those arrays will be concatenated. It writes the concatenated
+// result into concatenated_array.
+template <ArrayDataType A>
+void ConcatenateTensorBuffers(const std::vector<Array*>& input_arrays,
+ int concatenation_axis,
+ Array* concatenated_array) {
+ int num_elements_concatenated_array = 1;
+ for (int i = 0; i < concatenated_array->shape().dimensions_count(); i++) {
+ num_elements_concatenated_array *= concatenated_array->shape().dims()[i];
+ }
+  // Prepare the data needed for segmented copy from multiple source arrays to
+  // a destination array based on the concatenation dimension.
+ std::vector<int> array_copy_size(input_arrays.size());
+ int count = 0;
+ for (Array* input_array : input_arrays) {
+ const Shape array_shape = input_array->shape();
+ array_copy_size[count] = 1;
+ for (int i = concatenation_axis; i < array_shape.dimensions_count(); i++) {
+ array_copy_size[count] *= array_shape.dims()[i];
+ }
+ count++;
+ }
+
+ // Do the actual data copy.
+ CopyTensorSegments<A, DataType<A>>(input_arrays, array_copy_size,
+ num_elements_concatenated_array,
+ concatenated_array);
+}
+
+// Sets the minimum and maximum values for the concatenated array. If they are
+// already set (e.g. because of a previous pass in TOCO), it leaves them
+// unchanged and returns. Otherwise it uses the input arrays' min and max
+// values to compute the concatenated array's min and max.
+void SetMinMaxForConcatenatedArray(const std::vector<Array*>& input_arrays,
+                                   Array* concatenated_array) {
+ CHECK(concatenated_array->data_type == ArrayDataType::kFloat);
+ // If the minmax is already set, use it
+ if (concatenated_array->minmax) return;
+
+ double concat_min = std::numeric_limits<double>::infinity();
+ double concat_max = -std::numeric_limits<double>::infinity();
+
+ for (Array* input_array : input_arrays) {
+ // If any of the input arrays minmax is not set, return.
+ // TODO(ghodrat): shall we add the logic to compute the minmax?
+ if (!input_array->minmax) return;
+ const MinMax& input_minmax = input_array->GetMinMax();
+ concat_min = std::min(concat_min, input_minmax.min);
+ concat_max = std::max(concat_max, input_minmax.max);
+ }
+ MinMax& minmax = concatenated_array->GetOrCreateMinMax();
+ minmax.min = concat_min;
+ minmax.max = concat_max;
+}
+
+} // namespace
+
+// Resolves the concatenation operator if all its inputs are constant arrays.
+bool ResolveConstantConcatenation::Run(Model* model, std::size_t op_index) {
+ const auto concat_it = model->operators.begin() + op_index;
+ const auto* concat_base_op = concat_it->get();
+ if (concat_base_op->type != OperatorType::kConcatenation) {
+ return false;
+ }
+ const auto* concat_op =
+ static_cast<const ConcatenationOperator*>(concat_base_op);
+
+ for (const string& input_name : concat_op->inputs) {
+ // We only expect constant unquantized arrays as input, otherwise we return.
+ // We also make sure the shapes of the input arrays are known and they are
+ // all discardable.
+ const Operator* input_op = GetOpWithOutput(*model, input_name);
+ if (input_op) return false;
+ if (!IsConstantParameterArray(*model, input_name)) return false;
+ if (!model->GetArray(input_name).has_shape()) return false;
+ if (model->GetArray(input_name).quantization_params) return false;
+ if (!IsDiscardableArray(*model, input_name)) return false;
+ }
+
+ const int concatenation_axis = concat_op->concat_dim;
+
+ CHECK_EQ(concat_op->outputs.size(), 1);
+ string concatenated_array_name = concat_op->outputs[0];
+ Array& concatenated_array = model->GetOrCreateArray(concatenated_array_name);
+ std::vector<Array*> input_arrays;
+ for (const string& input_name : concat_op->inputs) {
+ input_arrays.push_back(&model->GetArray(input_name));
+ }
+
+ switch (concatenated_array.data_type) {
+ case ArrayDataType::kFloat:
+ ConcatenateTensorBuffers<ArrayDataType::kFloat>(
+ input_arrays, concatenation_axis, &concatenated_array);
+      SetMinMaxForConcatenatedArray(input_arrays, &concatenated_array);
+ break;
+ case ArrayDataType::kUint8:
+ ConcatenateTensorBuffers<ArrayDataType::kUint8>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt32:
+ ConcatenateTensorBuffers<ArrayDataType::kInt32>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ case ArrayDataType::kInt64:
+ ConcatenateTensorBuffers<ArrayDataType::kInt64>(
+ input_arrays, concatenation_axis, &concatenated_array);
+ break;
+ default:
+ LOG(FATAL) << "ArrayDataType not supported";
+ }
+
+ // Remove all the resolved arrays.
+ for (const string& input_name : concat_op->inputs) {
+ model->arrays.erase(input_name);
+ }
+
+ // Remove concatenate operator
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
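
As a standalone illustration of the segmented-copy scheme in CopyTensorSegments
above, the sketch below concatenates two row-major 2x2 matrices along axis 1.
The per-input copy size is the product of the dimensions from the concat axis
onward (here, 2 elements per row); ConcatAxis1 and the hard-coded shapes are
illustrative only:

    #include <algorithm>
    #include <vector>

    // Concatenate two row-major 2x2 float matrices along axis 1, yielding a
    // 2x4 matrix. Each copy step moves one row-chunk from each input.
    std::vector<float> ConcatAxis1(const std::vector<float>& a,
                                   const std::vector<float>& b) {
      const int copy_size = 2;  // product of dims from the concat axis onward
      const int steps = 2;      // input elements / copy_size
      std::vector<float> out(a.size() + b.size());
      const float* src[2] = {a.data(), b.data()};
      float* dst = out.data();
      for (int s = 0; s < steps; ++s) {
        for (int i = 0; i < 2; ++i) {
          std::copy(src[i], src[i] + copy_size, dst);
          src[i] += copy_size;
          dst += copy_size;
        }
      }
      return out;  // {a00, a01, b00, b01, a10, a11, b10, b11}
    }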
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
new file mode 100644
index 0000000000..244adcc4c4
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_fake_quant.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantFakeQuant::Run(Model* model, std::size_t op_index) {
+ const auto fakequant_it = model->operators.begin() + op_index;
+ const auto* fakequant_base_op = fakequant_it->get();
+ if (fakequant_base_op->type != OperatorType::kFakeQuant) {
+ return false;
+ }
+
+ const auto* fakequant_op =
+ static_cast<const FakeQuantOperator*>(fakequant_base_op);
+
+ // Yield until the fakequant MinMax has been resolved.
+ if (!fakequant_op->minmax) {
+ return false;
+ }
+
+ // This transformation only applies when the input array is constant.
+ if (!IsConstantParameterArray(*model, fakequant_op->inputs[0])) {
+ return false;
+ }
+
+ const auto& input_array = model->GetArray(fakequant_op->inputs[0]);
+ auto& output_array = model->GetArray(fakequant_op->outputs[0]);
+ CHECK(input_array.data_type == ArrayDataType::kFloat);
+ output_array.data_type = ArrayDataType::kFloat;
+ CHECK(!output_array.buffer);
+ const auto& input_buffer = input_array.GetBuffer<ArrayDataType::kFloat>();
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kFloat>();
+ const int size = input_buffer.data.size();
+ output_buffer.data.resize(size);
+ QuantizationParams qparams;
+ GetQuantizationParamsFromMinMax<ArrayDataType::kUint8>(
+ model->flags, *fakequant_op->minmax, &qparams);
+ for (int i = 0; i < size; i++) {
+ const double src_val = input_buffer.data[i];
+ const double unclamped_quantized_val =
+ std::round(qparams.zero_point + src_val / qparams.scale);
+ const double quantized_val =
+ std::min(255., std::max(0., unclamped_quantized_val));
+ const double dst_val = qparams.scale * (quantized_val - qparams.zero_point);
+ output_buffer.data[i] = dst_val;
+ }
+ if (CountOpsWithInput(*model, fakequant_op->inputs[0]) == 1) {
+ model->arrays.erase(fakequant_op->inputs[0]);
+ }
+ model->operators.erase(fakequant_it);
+
+ return true;
+}
+
+} // namespace toco
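
The loop above is a quantize-dequantize round trip: each float is quantized to
uint8 with the parameters implied by the FakeQuant op's MinMax, clamped, and
converted back to float, so the constant output already carries the rounding
error that quantized inference would introduce. A freestanding sketch of the
per-element arithmetic, assuming a given scale and zero point (the real pass
derives them via GetQuantizationParamsFromMinMax; FakeQuant here is an
illustrative name):

    #include <algorithm>
    #include <cmath>

    // Quantize one float value to the uint8 grid, clamp to [0, 255], then
    // dequantize back to float.
    float FakeQuant(float src, double scale, double zero_point) {
      const double quantized = std::round(zero_point + src / scale);
      const double clamped = std::min(255.0, std::max(0.0, quantized));
      return static_cast<float>(scale * (clamped - zero_point));
    }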
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
new file mode 100644
index 0000000000..8cc6db1619
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_tensorflow_shape.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstddef>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantTensorFlowShape::Run(Model* model, std::size_t op_index) {
+ const auto tfshape_it = model->operators.begin() + op_index;
+ const auto* tfshape_base_op = tfshape_it->get();
+ if (tfshape_base_op->type != OperatorType::kTensorFlowShape) {
+ return false;
+ }
+
+ const auto* tfshape_op =
+ static_cast<const TensorFlowShapeOperator*>(tfshape_base_op);
+
+ const auto& input_array = model->GetArray(tfshape_op->inputs[0]);
+ auto& output_array = model->GetArray(tfshape_op->outputs[0]);
+
+ // Yield until the input array's shape has been resolved.
+ if (!input_array.has_shape()) {
+ return false;
+ }
+
+ // Create a buffer for the output array, making it a constant array, and
+ // copy the input shape into the output buffer.
+ CHECK(!output_array.buffer);
+ auto& output_buffer = output_array.GetMutableBuffer<ArrayDataType::kInt32>();
+ output_buffer.data = input_array.shape().dims();
+
+ // Erase the input array if no longer used
+ if (IsDiscardableArray(*model, tfshape_op->inputs[0]) &&
+ CountOpsWithInput(*model, tfshape_op->inputs[0]) == 1) {
+ model->arrays.erase(tfshape_op->inputs[0]);
+ }
+ model->operators.erase(tfshape_it);
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
new file mode 100644
index 0000000000..bb9bda3c82
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -0,0 +1,175 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
+ const auto unary_it = model->operators.begin() + op_index;
+ const auto* unary_op = unary_it->get();
+ // Test for unary ops of types that we know how to resolve
+ if (unary_op->type != OperatorType::kTensorFlowRsqrt &&
+ unary_op->type != OperatorType::kTensorFlowSqrt &&
+ unary_op->type != OperatorType::kTensorFlowSquare &&
+ unary_op->type != OperatorType::kTensorFlowSum &&
+ unary_op->type != OperatorType::kTensorFlowMin &&
+ unary_op->type != OperatorType::kTensorFlowMax &&
+ unary_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+ // Check if the input is a constant parameter.
+ if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
+ return false;
+ }
+
+  // If the unary op involves a tensor required by an RNN state, ignore it.
+ for (const auto& rnn_state : model->flags.rnn_states()) {
+ if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
+ return false;
+ }
+ if (unary_op->inputs[0] == rnn_state.state_array()) {
+ return false;
+ }
+ }
+
+ // At the moment we don't want to care about fused activation functions.
+ // The idea is that we should do the present constants-propagation before
+ // activation functions get fused.
+ if (unary_op->fused_activation_function !=
+ FusedActivationFunctionType::kNone) {
+    AddMessageF(
+        "Not resolving constant %s because it has a fused activation "
+        "function",
+        LogName(*unary_op));
+ return false;
+ }
+ const auto& input_array = model->GetArray(unary_op->inputs[0]);
+  // We have already tested above for the existence of buffers (synonymous
+  // with being a constant param).
+ CHECK(input_array.buffer);
+ // At the moment we only support float buffers.
+ if (input_array.buffer->type != ArrayDataType::kFloat) {
+ return false;
+ }
+ const auto& input_float_data =
+ input_array.GetBuffer<ArrayDataType::kFloat>().data;
+ // Create the float buffer on the output array, effectively turning it into
+ // a constant parameter
+ const auto& output_name = unary_op->outputs[0];
+ auto& output_array = model->GetArray(output_name);
+ // Yield until the output array dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+
+ int input_buffer_size = RequiredBufferSizeForShape(input_array.shape());
+ int output_buffer_size = RequiredBufferSizeForShape(output_array.shape());
+ const Shape& input_shape = input_array.shape();
+ const Shape& output_shape = output_array.shape();
+
+ auto& output_float_data =
+ output_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ output_float_data.resize(output_buffer_size);
+
+ const int output_dims_count = output_shape.dimensions_count();
+ if (unary_op->type == OperatorType::kTensorFlowReshape) {
+ CHECK(input_buffer_size == output_buffer_size);
+ memcpy(output_float_data.data(), input_float_data.data(),
+ input_buffer_size * sizeof(input_float_data[0]));
+ } else if (unary_op->type == OperatorType::kTensorFlowSum) {
+ // At the moment only full reduction across all dimensions is supported.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float sum = 0.f;
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ sum += input_float_data[i];
+ }
+ output_float_data[0] = sum;
+ } else if (unary_op->type == OperatorType::kTensorFlowMin) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float min = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ min = std::min(min, input_float_data[i]);
+ }
+ output_float_data[0] = min;
+ } else if (unary_op->type == OperatorType::kTensorFlowMax) {
+ // At the moment only full reduction across all dimensions is supported.
+ // TODO(starka): Output should not be padded.
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), 1);
+ }
+ float max = input_float_data[0];
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < input_size; i++) {
+ max = std::max(max, input_float_data[i]);
+ }
+ output_float_data[0] = max;
+ } else if (unary_op->type == OperatorType::kTensorFlowRsqrt ||
+ unary_op->type == OperatorType::kTensorFlowSqrt ||
+ unary_op->type == OperatorType::kTensorFlowSquare) {
+ // Element-wise ops. Should have perfectly matching sizes here.
+ const int input_size = RequiredBufferSizeForShape(input_shape);
+ for (int i = 0; i < output_dims_count; i++) {
+ CHECK_EQ(output_shape.dims(i), input_shape.dims(i));
+ }
+
+ for (int i = 0; i < input_size; i++) {
+ const float val = input_float_data[i];
+ float outval = 0.f;
+ if (unary_op->type == OperatorType::kTensorFlowRsqrt) {
+ outval = 1.0f / std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSqrt) {
+ outval = std::sqrt(val);
+ } else if (unary_op->type == OperatorType::kTensorFlowSquare) {
+ outval = val * val;
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ output_float_data[i] = outval;
+ }
+ } else {
+ LOG(FATAL) << "should not get here.";
+ }
+ for (const auto& input : unary_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1) {
+ model->arrays.erase(input);
+ }
+ }
+ AddMessageF("Resolved constant %s to the equivalent constant array",
+ LogName(*unary_op));
+ model->operators.erase(unary_it);
+ return true;
+}
+
+} // namespace toco
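
The branches above fall into two families: full reductions that collapse the
whole buffer to a single value, and element-wise maps where input and output
shapes match exactly. A minimal sketch of one of each, on plain vectors rather
than toco Arrays (illustrative names):

    #include <cmath>
    #include <cstddef>
    #include <numeric>
    #include <vector>

    // Full reduction: a constant Sum over all elements yields one value.
    float ReduceSum(const std::vector<float>& data) {
      return std::accumulate(data.begin(), data.end(), 0.0f);
    }

    // Element-wise map: Rsqrt transforms each element independently.
    std::vector<float> Rsqrt(const std::vector<float>& data) {
      std::vector<float> out(data.size());
      for (std::size_t i = 0; i < data.size(); ++i) {
        out[i] = 1.0f / std::sqrt(data[i]);
      }
      return out;
    }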
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
new file mode 100644
index 0000000000..d25c773f19
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_mean_attributes.cc
@@ -0,0 +1,51 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveMeanAttributes::Run(Model* model, std::size_t op_index) {
+ auto* mean_op = model->operators[op_index].get();
+ if (mean_op->type != OperatorType::kMean) return false;
+ auto* op = static_cast<MeanOperator*>(mean_op);
+
+ if (!op->reduction_indices.empty()) return false;
+ if (op->inputs.size() != 2) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& indices_array = *model->arrays[op->inputs[1]];
+ if (!indices_array.has_shape()) return false;
+
+ op->reduction_indices = indices_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // At the moment, we only support simultaneous reduction over width and
+ // height. This is mainly limited by the fact that currently, the runtime
+ // arrays are always 4-dimensional.
+ CHECK_EQ(op->reduction_indices.size(), 2);
+ CHECK((op->reduction_indices[0] == 1 && op->reduction_indices[1] == 2) ||
+ (op->reduction_indices[0] == 2 && op->reduction_indices[1] == 1));
+
+ return true;
+}
+
+} // namespace toco
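
The two CHECKs above encode the only supported case: a simultaneous reduction
over the height and width axes of a 4-D NHWC array. Restated as a standalone
predicate (illustrative name, not part of toco):

    #include <vector>

    // Returns true iff the reduction indices select axes 1 and 2 (height and
    // width in NHWC layout), in either order.
    bool IsSupportedMeanReduction(const std::vector<int>& indices) {
      return indices.size() == 2 &&
             ((indices[0] == 1 && indices[1] == 2) ||
              (indices[0] == 2 && indices[1] == 1));
    }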
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
new file mode 100644
index 0000000000..d5f5869c62
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_pad_attributes.cc
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolvePadAttributes::Run(Model* model, std::size_t op_index) {
+ const auto pad_it = model->operators.begin() + op_index;
+ auto* pad_op = pad_it->get();
+ if (pad_op->type != OperatorType::kPad) return false;
+
+ auto* op = static_cast<PadOperator*>(pad_op);
+ if (!op->left_padding.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 2);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ const auto& array = *model->arrays[op->inputs[1]];
+ if (!array.has_shape()) return false;
+
+ const std::vector<int>& dims = array.shape().dims();
+ CHECK_EQ(dims.size(), 2);
+
+ std::vector<int> buffer = array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ for (int i = 0; i < dims[0]; ++i) {
+ op->left_padding.push_back(buffer[i * 2]);
+ op->right_padding.push_back(buffer[i * 2 + 1]);
+ }
+
+ // TODO(dkalenichenko): Delete the extra input?
+
+ return true;
+}
+} // namespace toco
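
The loop above unpacks a TensorFlow paddings tensor of shape [num_dims, 2],
flattened row-major, so even offsets are "before" paddings and odd offsets are
"after" paddings. A standalone sketch of the same unpacking (illustrative
names):

    #include <vector>

    // Split a flattened [num_dims, 2] paddings buffer into per-dimension
    // left (before) and right (after) padding vectors.
    void UnpackPaddings(const std::vector<int>& flat, int num_dims,
                        std::vector<int>* left, std::vector<int>* right) {
      for (int i = 0; i < num_dims; ++i) {
        left->push_back(flat[i * 2]);
        right->push_back(flat[i * 2 + 1]);
      }
    }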
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
new file mode 100644
index 0000000000..8fa7b83bed
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reorder_axes.cc
@@ -0,0 +1,93 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReorderAxes::Run(Model* model, std::size_t op_index) {
+ auto reorder_it = model->operators.begin() + op_index;
+ auto* reorder_op = static_cast<ReorderAxesOperator*>(reorder_it->get());
+ if (reorder_op->type != OperatorType::kReorderAxes) {
+ return false;
+ }
+ const auto& input_array_name = reorder_op->inputs[0];
+ const auto& output_array_name = reorder_op->outputs[0];
+ auto& input_array = model->GetArray(input_array_name);
+ auto& output_array = model->GetArray(output_array_name);
+ string constant_input_array_name = input_array_name;
+ if (!input_array.buffer) {
+ const auto* op_producing_input = GetOpWithOutput(*model, input_array_name);
+ if (op_producing_input &&
+ op_producing_input->type == OperatorType::kFakeQuant) {
+ constant_input_array_name = op_producing_input->inputs[0];
+ }
+ }
+ auto& constant_input_array = model->GetArray(constant_input_array_name);
+ if (!constant_input_array.buffer) {
+ return false;
+ }
+ // Yield until output dims have been resolved.
+ if (!output_array.has_shape()) {
+ return false;
+ }
+ // Reorder the input array dims and buffer data
+ CHECK(constant_input_array.buffer->type == ArrayDataType::kFloat);
+ CHECK(!output_array.buffer);
+ auto& input_data =
+ constant_input_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
+ std::vector<float> reordered_data;
+ reordered_data.resize(RequiredBufferSizeForShape(output_array.shape()));
+ const auto input_axes_order = reorder_op->input_axes_order;
+ const auto output_axes_order = reorder_op->output_axes_order;
+ // TODO(b/62904716) Shapes should be used directly.
+ Shape input_shape = constant_input_array.shape();
+ Shape output_shape = output_array.shape();
+ if (AxesCount(input_axes_order) == 2) {
+ UnextendShape(&input_shape, 2);
+ UnextendShape(&output_shape, 2);
+ }
+ ShuffleArray(input_shape, input_axes_order, output_axes_order, output_shape,
+ input_data.data(), reordered_data.data());
+ input_data = reordered_data;
+ input_array.copy_shape(output_array.shape());
+ constant_input_array.copy_shape(output_array.shape());
+
+ // Update the edges of the graph to point to the input array
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == output_array_name) {
+ input = input_array_name;
+ }
+ }
+ }
+
+ AddMessageF("Reordered axes for array %s", input_array_name);
+
+ // Remove the op and output array.
+ model->arrays.erase(output_array_name);
+ model->operators.erase(reorder_it);
+ return true;
+}
+
+} // namespace toco
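
In the 2-D case that UnextendShape reduces to, shuffling the constant buffer
amounts to a transpose of a row-major matrix. A standalone sketch with plain
vectors instead of toco buffers (illustrative name):

    #include <vector>

    // Reorder a row-major [rows x cols] buffer into [cols x rows] order, the
    // simplest instance of shuffling constant weights into a new axes order.
    std::vector<float> Transpose2D(const std::vector<float>& in, int rows,
                                   int cols) {
      std::vector<float> out(in.size());
      for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
          out[c * rows + r] = in[r * cols + c];
        }
      }
      return out;
    }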
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
new file mode 100644
index 0000000000..bed2a85bd2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_reshape_attributes.cc
@@ -0,0 +1,49 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <iterator>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveReshapeAttributes::Run(Model* model, std::size_t op_index) {
+ const auto reshape_it = model->operators.begin() + op_index;
+ auto* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kTensorFlowReshape) {
+ return false;
+ }
+
+ auto* op = static_cast<TensorFlowReshapeOperator*>(reshape_op);
+
+ if (!op->shape.empty()) return false;
+
+ if (IsConstantParameterArray(*model, reshape_op->inputs[1])) {
+ const auto& constant_input_array = *model->arrays[reshape_op->inputs[1]];
+ op->shape = constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
+ }
+
+ if (op->shape.empty()) return false;
+
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
new file mode 100644
index 0000000000..1d0a2ec8f6
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_slice_attributes.cc
@@ -0,0 +1,52 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kSlice) return false;
+
+ auto* op = static_cast<SliceOperator*>(slice_op);
+ if (!op->begin.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 3);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+
+ const auto& begin_array = *model->arrays[op->inputs[1]];
+ if (!begin_array.has_shape()) return false;
+
+ const auto& size_array = *model->arrays[op->inputs[2]];
+ if (!size_array.has_shape()) return false;
+
+ op->begin = begin_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->size = size_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
new file mode 100644
index 0000000000..5fc3b25bc1
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -0,0 +1,62 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
+ const auto slice_it = model->operators.begin() + op_index;
+ auto* slice_op = slice_it->get();
+ if (slice_op->type != OperatorType::kStridedSlice) return false;
+
+ auto* op = static_cast<StridedSliceOperator*>(slice_op);
+ if (!op->start_indices.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 4);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+
+ const auto& start_array = *model->arrays[op->inputs[1]];
+ if (!start_array.has_shape()) return false;
+
+ const auto& stop_array = *model->arrays[op->inputs[2]];
+ if (!stop_array.has_shape()) return false;
+
+ const auto& stride_array = *model->arrays[op->inputs[3]];
+ if (!stride_array.has_shape()) return false;
+
+ op->start_indices = start_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
+ op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
+
+ // Only 4D arrays are supported for now.
+ CHECK_EQ(op->start_indices.size(), 4);
+ CHECK_EQ(op->stop_indices.size(), 4);
+ CHECK_EQ(op->strides.size(), 4);
+
+ // TODO(dkalenichenko): Delete the extra inputs?
+
+ return true;
+}
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
new file mode 100644
index 0000000000..b482f5cf51
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_concat.cc
@@ -0,0 +1,86 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowConcat::Run(Model* model, std::size_t op_index) {
+ auto concat_it = model->operators.begin() + op_index;
+ const auto* tf_concat_op = concat_it->get();
+ if (tf_concat_op->type != OperatorType::kTensorFlowConcat &&
+ tf_concat_op->type != OperatorType::kTensorFlowConcatV2) {
+ return false;
+ }
+
+ CHECK_GE(tf_concat_op->inputs.size(), 2);
+ // TensorFlow Concat and ConcatV2 nodes only differ by the ordering
+ // of inputs: in Concat, the concat_dim is the first input, while in
+ // ConcatV2, it is the last input.
+ std::size_t concat_dim_pos = 0;
+ if (tf_concat_op->type == OperatorType::kTensorFlowConcatV2) {
+ concat_dim_pos = tf_concat_op->inputs.size() - 1;
+ }
+ const string concat_dim_name = tf_concat_op->inputs[concat_dim_pos];
+ std::vector<string> concat_input_names;
+ for (std::size_t i = 0; i < tf_concat_op->inputs.size(); i++) {
+ if (i != concat_dim_pos) {
+ concat_input_names.push_back(tf_concat_op->inputs[i]);
+ }
+ }
+ // If the concat_dim array hasn't been resolved to a constant yet,
+ // we need to yield.
+ const auto& concat_dim_array = model->GetArray(concat_dim_name);
+ if (!concat_dim_array.buffer) {
+ AddMessageF("Waiting for the concat_dim of %s to be resolved to a constant",
+ LogName(*tf_concat_op));
+ return false;
+ }
+
+ CHECK(concat_dim_array.data_type == ArrayDataType::kInt32);
+ const auto& concat_dim_data =
+ concat_dim_array.GetBuffer<ArrayDataType::kInt32>().data;
+ CHECK_EQ(concat_dim_data.size(), 1);
+ const int concat_dim = concat_dim_data[0];
+
+ // Create the Concatenation op replacing the TensorFlowConcat op.
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {tf_concat_op->outputs[0]};
+ auto depth_concat_it = model->operators.emplace(concat_it, concatenation_op);
+ CHECK_EQ(depth_concat_it->get(), concatenation_op);
+ // Update invalidated iterator
+ concat_it = depth_concat_it + 1;
+ CHECK_EQ(concat_it->get(), tf_concat_op);
+
+ // Remove the concat_dim array if it is not used by anything else.
+ if (CountOpsWithInput(*model, concat_dim_name) == 1) {
+ model->arrays.erase(concat_dim_name);
+ }
+ // Remove the TensorFlowConcat op
+ model->operators.erase(concat_it);
+ return true;
+}
+
+} // namespace toco
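
The only structural difference handled above is where the axis operand sits in
the input list. Restated on its own (illustrative name):

    #include <cstddef>

    // Concat passes the axis as the first input; ConcatV2 passes it last.
    // Every other input is a data operand of the resulting Concatenation op.
    std::size_t ConcatAxisPosition(bool is_concat_v2, std::size_t num_inputs) {
      return is_concat_v2 ? num_inputs - 1 : 0;
    }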
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
new file mode 100644
index 0000000000..bea7487051
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_matmul.cc
@@ -0,0 +1,106 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMatMul::Run(Model* model, std::size_t op_index) {
+ auto matmul_it = model->operators.begin() + op_index;
+ if (matmul_it->get()->type != OperatorType::kTensorFlowMatMul) {
+ return false;
+ }
+ const auto* matmul_op = matmul_it->get();
+
+ // Find the op producing the array passed to this MatMul
+ auto previous_op_it = model->operators.begin();
+ bool found = false;
+ for (; previous_op_it != model->operators.end(); ++previous_op_it) {
+ for (const auto& output : (*previous_op_it)->outputs) {
+ if (output == matmul_op->inputs[0]) {
+ found = true;
+ break;
+ }
+ }
+ if (found) {
+ break;
+ }
+ }
+ Operator* previous_op = (found) ? previous_op_it->get() : nullptr;
+
+ // construct the new FullyConnectedOperator
+ auto* fc_op = new FullyConnectedOperator;
+ fc_op->outputs = matmul_op->outputs;
+
+ // insert the newly constructed FullyConnectedOperator
+ auto fc_it = model->operators.emplace(matmul_it, fc_op);
+
+ // refresh invalidated iterator
+ matmul_it = fc_it + 1;
+ DCHECK_EQ(matmul_it->get(), matmul_op);
+
+  // The way that TensorFlow encodes FullyConnected ops is as a pair
+  // (Reshape, MatMul), so we want to remove the Reshape op and rewrite the
+  // MatMul op as a FullyConnected. However, TensorFlow skips the Reshape ops
+  // if the input doesn't need reshaping, so we can't just match
+  // (Reshape, MatMul) pairs.
+ if (previous_op && previous_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF("Combining %s and %s into %s", LogName(*previous_op),
+ LogName(*matmul_op), LogName(*fc_op));
+ const auto& previous_op_output = previous_op->outputs[0];
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ model->arrays.erase(previous_op_output);
+ }
+ CHECK_EQ(previous_op->inputs.size(), 2);
+ fc_op->inputs = {previous_op->inputs[0], matmul_op->inputs[1]};
+ // Only remove Reshape node if no other node uses its output.
+ if (CountOpsWithInput(*model, previous_op_output) == 1) {
+ const auto& previous_op_shape = previous_op->inputs[1];
+ if (CountOpsWithInput(*model, previous_op_shape) == 1 &&
+ !GetOpWithOutput(*model, previous_op_shape)) {
+ model->arrays.erase(previous_op_shape);
+ }
+ model->operators.erase(previous_op_it);
+ }
+
+ // We may have just invalidated matmul_it, so let's refresh it now.
+ matmul_it = model->operators.begin();
+ for (; matmul_it != model->operators.end(); ++matmul_it) {
+ if (matmul_it->get() == matmul_op) {
+ break;
+ }
+ }
+ CHECK(matmul_it != model->operators.end());
+ CHECK(matmul_it->get() == matmul_op);
+ } else {
+ AddMessageF("Replacing %s by a FullyConnected operator",
+ LogName(*matmul_op));
+ fc_op->inputs = {matmul_op->inputs[0], matmul_op->inputs[1]};
+ }
+
+ // erase the MatMul operator
+ model->operators.erase(matmul_it);
+ return true;
+}
+
+} // namespace toco
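
A minimal sketch of the (Reshape, MatMul) to FullyConnected rewrite above,
using throwaway stand-in types rather than the toco model classes (Node and
MakeFullyConnected are illustrative):

    #include <string>
    #include <vector>

    struct Node {
      std::string type;
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    // Collapse a Reshape feeding a MatMul into one FullyConnected node: take
    // the Reshape's original input plus the MatMul's weights, and keep the
    // MatMul's outputs, so the Reshape can be dropped from the graph.
    Node MakeFullyConnected(const Node& reshape, const Node& matmul) {
      Node fc;
      fc.type = "FullyConnected";
      fc.inputs = {reshape.inputs[0], matmul.inputs[1]};
      fc.outputs = matmul.outputs;
      return fc;
    }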
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
new file mode 100644
index 0000000000..cfa5ce0716
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_merge.cc
@@ -0,0 +1,63 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowMerge::Run(Model* model, std::size_t op_index) {
+ const auto merge_it = model->operators.begin() + op_index;
+ const auto* merge_op = merge_it->get();
+ if (merge_op->type != OperatorType::kTensorFlowMerge) {
+ return false;
+ }
+
+  // We need to yield until this Merge node has only 1 input, which will mean
+  // that it is the selected input. Other graph transformations on other
+  // nodes, such as ResolveTensorFlowSwitch, will take care of trimming the
+  // non-selected inputs, so that at some point there will be only 1 input left.
+ if (merge_op->inputs.size() > 1) {
+ AddMessageF("Waiting for %s to be resolved", LogName(*merge_op));
+ return false;
+ }
+
+  // Now that the merge node has exactly 1 input, it is the same as an Identity
+  // node and can be resolved trivially.
+ CHECK_EQ(merge_op->inputs.size(), 1);
+
+ // Update the edges of the graph ahead of removing the node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == merge_op->outputs[0]) {
+ input = merge_op->inputs[0];
+ }
+ }
+ }
+
+ // Remove the node and its output array.
+ AddMessageF("Removing already-resolved %s", LogName(*merge_op));
+ model->arrays.erase(merge_op->outputs[0]);
+ model->operators.erase(merge_it);
+ return true;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
new file mode 100644
index 0000000000..1d3f42b5ec
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_squeeze.cc
@@ -0,0 +1,54 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSqueeze::Run(Model* model, std::size_t op_index) {
+ const auto squeeze_it = model->operators.begin() + op_index;
+ const auto* squeeze_op = squeeze_it->get();
+ if (squeeze_op->type != OperatorType::kSqueeze) {
+ return false;
+ }
+
+ CHECK_EQ(squeeze_op->inputs.size(), 1);
+ CHECK_EQ(squeeze_op->outputs.size(), 1);
+
+ // If the output is consumed by a reshape op, it's a trivial squeeze.
+ if (CountOpsWithInput(*model, squeeze_op->outputs[0]) == 1) {
+ const auto* next_op = GetOpWithInput(*model, squeeze_op->outputs[0]);
+ if (next_op->type == OperatorType::kTensorFlowReshape) {
+ AddMessageF(
+ "%s is trivial because its output is only consumed by a "
+ "Reshape op",
+ LogName(*squeeze_op));
+
+ return RemoveTrivialPassthroughOp(this, model, op_index);
+ }
+ }
+
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
new file mode 100644
index 0000000000..55adfca037
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_switch.cc
@@ -0,0 +1,123 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTensorFlowSwitch::Run(Model* model, std::size_t op_index) {
+ const auto switch_it = model->operators.begin() + op_index;
+ const auto* switch_op = switch_it->get();
+ if (switch_op->type != OperatorType::kTensorFlowSwitch) {
+ return false;
+ }
+
+ CHECK_EQ(switch_op->inputs.size(), 2);
+ CHECK_EQ(switch_op->outputs.size(), 2);
+ const string& predicate_name = switch_op->inputs[1];
+ // If the predicate array hasn't been resolved to a constant yet,
+ // we need to yield.
+ if (!IsConstantParameterArray(*model, predicate_name)) {
+ AddMessageF(
+ "Waiting for the boolean predicate of %s to be resolved to a constant",
+ LogName(*switch_op));
+ return false;
+ }
+
+ // The predicate should be boolean, and should consist of a single value.
+ const auto& predicate_array = model->GetArray(predicate_name);
+ CHECK(predicate_array.data_type == ArrayDataType::kBool);
+ for (const auto& dim : predicate_array.shape().dims()) {
+ CHECK_EQ(dim, 1);
+ }
+
+ // Obtain the predicate boolean value.
+ const auto& predicate_data =
+ predicate_array.GetBuffer<ArrayDataType::kBool>().data;
+ CHECK_EQ(predicate_data.size(), 1);
+ const bool predicate_value = predicate_data[0];
+
+ // From the TensorFlow docs on .switch() in
+ // third_party/tensorflow/python/ops/control_flow_ops.py
+ //
+  //   If `pred` is false, the `data` input is forwarded to the first output.
+ // Otherwise, the data goes to the second output.
+ //
+ // Note that this comment used to say the opposite and was recently fixed:
+ // https://github.com/tensorflow/tensorflow/commit/bc456e361d49d1d89a74b80060c70efb51fd7d87#diff-76ab9dafbe12c20ddc3769c6b108986c
+ const int selected_output_index = predicate_value ? 1 : 0;
+ const int nonselected_output_index = predicate_value ? 0 : 1;
+
+ // Update the edges of the graph ahead of removing the node:
+ // edges that were pointing to the selected output, should instead
+ // point to the input of the Switch node.
+ for (const auto& other_op : model->operators) {
+ for (auto& input : other_op->inputs) {
+ if (input == switch_op->outputs[selected_output_index]) {
+ input = switch_op->inputs[0];
+ }
+ }
+ }
+
+  // It remains to handle the edges that were pointing to the nonselected
+  // output. We will just discard those edges. Concretely, at the moment,
+  // our only examples of graphs with Switch nodes have them feeding into Merge
+  // nodes, so what we're saying here is that we'll make the convention,
+  // in our toco internal representation, that Merge nodes with only 1 input
+  // are Merge nodes that have been resolved already and should behave as
+  // Identity nodes, simply forwarding their input.
+ //
+ for (const auto& other_op : model->operators) {
+ auto input_it = other_op->inputs.begin();
+ while (input_it != other_op->inputs.end()) {
+ if (*input_it == switch_op->outputs[nonselected_output_index]) {
+ // Let us guard our assumption that only Merge nodes consume the outputs
+ // of Switch nodes:
+ CHECK(other_op->type == OperatorType::kTensorFlowMerge);
+ input_it = other_op->inputs.erase(input_it);
+ } else {
+ ++input_it;
+ }
+ }
+ }
+
+ // Remove the output arrays if they are now unused.
+ for (int i = 0; i < 2; i++) {
+ if (!GetOpWithInput(*model, switch_op->outputs[i])) {
+ model->arrays.erase(switch_op->outputs[i]);
+ }
+ }
+ // Remove input arrays if they are only used by the switch itself and aren't
+ // the output of another op (will get handled by RemoveUnusedOp in that case).
+ for (const auto& input : switch_op->inputs) {
+ if (CountOpsWithInput(*model, input) == 1 &&
+ !GetOpWithOutput(*model, input)) {
+ model->arrays.erase(input);
+ }
+ }
+ // Remove the switch node itself.
+ AddMessageF("Removing already-resolved %s", LogName(*switch_op));
+ model->operators.erase(switch_it);
+ return true;
+}
+
+} // namespace toco
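
The selection rule above, in isolation (illustrative name):

    // A resolved Switch forwards its data input to exactly one output:
    // output 0 when the predicate is false, output 1 when it is true.
    // Consumers of the other output are simply disconnected.
    int SelectedSwitchOutput(bool predicate) { return predicate ? 1 : 0; }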
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
new file mode 100644
index 0000000000..9f7e7c42a2
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_tensorflow_tile.cc
@@ -0,0 +1,97 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+void RemoveTileOperator(Model* model, Operator* tile_op, Operator* binary_op,
+ int operand_index) {
+ CHECK(tile_op->type == OperatorType::kTensorFlowTile);
+ CHECK_EQ(binary_op->inputs.size(), 2);
+ CHECK_EQ(tile_op->inputs.size(), 2);
+ const string tile_multiplier_array = tile_op->inputs[1];
+ const string tile_output_array = tile_op->outputs[0];
+ binary_op->inputs[operand_index] = tile_op->inputs[0];
+ auto tile_it = model->operators.begin();
+ for (; tile_it != model->operators.end(); ++tile_it) {
+ if (tile_it->get() == tile_op) {
+ break;
+ }
+ }
+ CHECK(tile_it != model->operators.end());
+ CHECK(tile_it->get() == tile_op);
+ model->operators.erase(tile_it);
+ if (!CountOpsWithInput(*model, tile_multiplier_array) &&
+ !GetOpWithOutput(*model, tile_multiplier_array)) {
+ model->arrays.erase(tile_multiplier_array);
+ }
+ if (!CountOpsWithInput(*model, tile_output_array)) {
+ model->arrays.erase(tile_output_array);
+ }
+}
+} // namespace
+
+bool ResolveTensorFlowTile::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ auto* binary_op = binary_it->get();
+ // Test for binary ops of types that we know how to resolve
+ if (binary_op->inputs.size() != 2) {
+ return false;
+ }
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv) {
+ return false;
+ }
+
+ Operator* const op[2] = {
+ GetOpWithOutput(*model, binary_op->inputs[0]),
+ GetOpWithOutput(*model, binary_op->inputs[1]),
+ };
+
+  // In the unlikely case where both operands are Tile, we can't infer the
+  // output size without the Tile nodes, so we have to bail out.
+ if (op[0] && op[0]->type == OperatorType::kTensorFlowTile && op[1] &&
+ op[1]->type == OperatorType::kTensorFlowTile) {
+ return false;
+ }
+
+ for (int i = 0; i < 2; i++) {
+ if (op[i] && op[i]->type == OperatorType::kTensorFlowTile) {
+      // We can only remove a Tile operator if no other op than the present
+      // binary op is consuming its tiled output.
+ if (CountOpsWithInput(*model, binary_op->inputs[i]) == 1) {
+ AddMessageF("Removing %s", LogName(*op[i]));
+ RemoveTileOperator(model, op[i], binary_op, i);
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
new file mode 100644
index 0000000000..8931498782
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/BUILD
@@ -0,0 +1,31 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+tf_cc_test(
+ name = "resolve_constant_concatenation_test",
+ srcs = ["resolve_constant_concatenation_test.cc"],
+ deps = [
+ "//tensorflow/contrib/lite/toco:graph_transformations",
+ "//tensorflow/contrib/lite/toco:model",
+ "//tensorflow/contrib/lite/toco:tooling_util",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
new file mode 100644
index 0000000000..c6705ad305
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/tests/resolve_constant_concatenation_test.cc
@@ -0,0 +1,221 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+// Returns a vector of gmock matchers, one per expected value, each checking
+// that a float element is within max_abs_error of that value.
+std::vector<testing::Matcher<float>> ArrayFloatNear(
+ const std::vector<float>& values, float max_abs_error = 1e-5) {
+ std::vector<testing::Matcher<float>> matchers;
+ matchers.reserve(values.size());
+ for (const float& v : values) {
+ matchers.emplace_back(testing::FloatNear(v, max_abs_error));
+ }
+ return matchers;
+}
+} // namespace
+
+// The following 3 tests make sure that the concatenation operation, applied
+// along different axis values, matches the TensorFlow results listed below:
+//
+// x0 = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
+// x1 = [[[10, 11], [12, 13]], [[14, 15], [16, 17]]]
+// x2 = [[[20, 21], [22, 23]], [[24, 25], [26, 27]]]
+// x3 = [[[30, 31], [32, 33]], [[34, 35], [36, 37]]]
+//
+// ConcatAtAxis0 test:
+// t0 = tf.concat([x0, x1, x2, x3], 0)
+// [[[ 0 1]
+// [ 2 3]]
+//
+// [[ 4 5]
+// [ 6 7]]
+//
+// [[10 11]
+// [12 13]]
+//
+// [[14 15]
+// [16 17]]
+//
+// [[20 21]
+// [22 23]]
+//
+// [[24 25]
+// [26 27]]
+//
+// [[30 31]
+// [32 33]]
+//
+// [[34 35]
+// [36 37]]]
+//
+// ConcatAtAxis1 test:
+// t1 = tf.concat([x0, x1, x2, x3], 1)
+// [[[ 0 1]
+// [ 2 3]
+// [10 11]
+// [12 13]
+// [20 21]
+// [22 23]
+// [30 31]
+// [32 33]]
+//
+// [[ 4 5]
+// [ 6 7]
+// [14 15]
+// [16 17]
+// [24 25]
+// [26 27]
+// [34 35]
+// [36 37]]]
+//
+// ConcatAtAxis2 test:
+// t2 = tf.concat([x0, x1, x2, x3], 2)
+// [[[ 0 1 10 11 20 21 30 31]
+// [ 2 3 12 13 22 23 32 33]]
+//
+// [[ 4 5 14 15 24 25 34 35]
+// [ 6 7 16 17 26 27 36 37]]]
+
+class ResolveConstantConcatenationTest : public ::testing::Test {
+ protected:
+ ResolveConstantConcatenationTest() {}
+
+  // Prepares a hypothetical TOCO model with a single Concatenation operator
+  // and 4 constant arrays as its inputs. The axis to concatenate along is
+  // passed in as concat_dim.
+ void PrepareModel(Model* model, int concat_dim) {
+ std::vector<string> concat_input_names = {"array0", "array1", "array2",
+ "array3"};
+
+ const int kDim = 3;
+ const int kElementPerDim = 2;
+ const int kBufSize = 8;
+ const int kNumArrays = 4;
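+    // One row of constant data per input array, matching x0..x3 above.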
+ static float in_buf[kNumArrays][kBufSize] = {
+ {0., 1., 2., 3., 4., 5., 6., 7.},
+ {10., 11., 12., 13., 14., 15., 16., 17.},
+ {20., 21., 22., 23., 24., 25., 26., 27.},
+ {30., 31., 32., 33., 34., 35., 36., 37.}};
+ int cnt = 0;
+ for (const string& concat_input_name : concat_input_names) {
+ Array& in_array = model->GetOrCreateArray(concat_input_name);
+ in_array.data_type = ArrayDataType::kFloat;
+
+ // Initialize shape for the input array.
+ Shape* in_array_shape = in_array.mutable_shape();
+ std::vector<int>* in_array_shape_dim = in_array_shape->mutable_dims();
+ for (int i = 0; i < kDim; i++) {
+ in_array_shape_dim->push_back(kElementPerDim);
+ }
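+      // Allocate the float buffer and copy in this array's constant data.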
+ auto& in_array_buffer =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>();
+ in_array_buffer.data.resize(kBufSize);
+ float* buf_ptr =
+ in_array.GetMutableBuffer<toco::ArrayDataType::kFloat>().data.data();
+ std::copy(in_buf[cnt], in_buf[cnt] + kBufSize, buf_ptr);
+ cnt++;
+ }
+ auto* concatenation_op = new ConcatenationOperator;
+ concatenation_op->concat_dim = concat_dim;
+ concatenation_op->inputs = concat_input_names;
+ concatenation_op->outputs = {"concat_op_outputs"};
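+    // Create the output array. Its shape matches the inputs except along
+    // concat_dim, where it is the sum of the input extents.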
+ Array& out_array = model->GetOrCreateArray(concatenation_op->outputs[0]);
+ out_array.data_type = ArrayDataType::kFloat;
+ Shape* out_array_shape = out_array.mutable_shape();
+ std::vector<int>* out_array_shape_dim = out_array_shape->mutable_dims();
+ out_array_shape_dim->resize(kDim);
+ for (int i = 0; i < kDim; i++) {
+ if (i == concat_dim) {
+ (*out_array_shape_dim)[i] = kNumArrays * kElementPerDim;
+ } else {
+ (*out_array_shape_dim)[i] = kElementPerDim;
+ }
+ }
+ model->operators.push_back(std::unique_ptr<Operator>(concatenation_op));
+ }
+};
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis0) {
+ Model model;
+ const int concat_dim = 0;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
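+  // Before the transformation the model holds 4 input arrays plus 1 output
+  // array; afterwards only the concatenated output remains.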
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12.,
+ 13., 14., 15., 16., 17., 20., 21., 22., 23., 24., 25.,
+ 26., 27., 30., 31., 32., 33., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis1) {
+ Model model;
+ const int concat_dim = 1;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 2., 3., 10., 11., 12., 13., 20., 21., 22.,
+ 23., 30., 31., 32., 33., 4., 5., 6., 7., 14., 15.,
+ 16., 17., 24., 25., 26., 27., 34., 35., 36., 37.})));
+}
+
+TEST_F(ResolveConstantConcatenationTest, ConcatAtAxis2) {
+ Model model;
+ const int concat_dim = 2;
+ PrepareModel(&model, concat_dim);
+
+ GraphTransformationsSet graph_transformation_set;
+ graph_transformation_set.Add(new toco::ResolveConstantConcatenation);
+ EXPECT_THAT(model.arrays.size(), 5);
+ (*graph_transformation_set.begin())->Run(&model, /*op_index=*/0);
+ EXPECT_THAT(model.arrays.size(), 1);
+
+ auto& concatenated_array = (*model.arrays.begin()).second;
+ EXPECT_THAT(concatenated_array->GetBuffer<toco::ArrayDataType::kFloat>().data,
+ ElementsAreArray(ArrayFloatNear(
+ {0., 1., 10., 11., 20., 21., 30., 31., 2., 3., 12.,
+ 13., 22., 23., 32., 33., 4., 5., 14., 15., 24., 25.,
+ 34., 35., 6., 7., 16., 17., 26., 27., 36., 37.})));
+}
+
+} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
new file mode 100644
index 0000000000..4e273343df
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/unfuse_activation_functions.cc
@@ -0,0 +1,73 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
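+// Splits an operator's fused activation function out into a separate
+// activation operator wired immediately after it in the graph.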
+bool UnfuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+ const auto it = model->operators.begin() + op_index;
+ auto* op = it->get();
+
+  // If a conv operation still has an im2col array as a second output, yield:
+  // the im2col array should be dropped first.
+ if ((op->type == OperatorType::kConv) && (op->outputs.size() == 2)) {
+ return false;
+ }
+
+ Operator* ac_op = nullptr;
+ switch (op->fused_activation_function) {
+ case FusedActivationFunctionType::kRelu:
+ ac_op = new ReluOperator;
+ break;
+ case FusedActivationFunctionType::kRelu6:
+ ac_op = new Relu6Operator;
+ break;
+ case FusedActivationFunctionType::kRelu1:
+ ac_op = new Relu1Operator;
+ break;
+ default:
+ return false;
+ }
+
+  // At this point we know that the op has a fused activation function. At the
+  // moment that only happens with ops having a single output; this may be
+  // relaxed in the future.
+ CHECK_EQ(op->outputs.size(), 1);
+
+ // Emplace unfused activation function, drop the fused one.
+ model->operators.emplace(it + 1, ac_op);
+ op->fused_activation_function = FusedActivationFunctionType::kNone;
+
+ // Wire up arrays, constructing a new intermediate array to connect the
+ // op to its new unfused activation function.
+ ac_op->outputs = op->outputs;
+  const string tmp_array_name =
+      AvailableArrayName(*model, op->outputs[0] + "_unfused");
+ CHECK(!model->arrays.count(tmp_array_name));
+ model->GetOrCreateArray(tmp_array_name);
+ ac_op->inputs = {tmp_array_name};
+ op->outputs = {tmp_array_name};
+ return true;
+}
+
+} // namespace toco