aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-03 10:35:00 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 10:37:47 -0700
commit895566e3af6e3e44eb22d0c2d20b4890e42982fc (patch)
tree46a3b5fe397a5afeacc5c927bd5203218be2d458 /tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
parentf005999d571536b859229229e4487cf749d9a786 (diff)
Three operators are added that can be used to decrease the number of
transpose/reshape occurrences. These operators are: - Swap elementwise - Swap reshape-transpose - Fuse transpose-reshape Swap elementwise groups operations that operate on values, and moves values. This allows all movement operations (reshape, transpose, etc.) to group at one portion of the graph while value manipulations group on the other end. Swap reshape/transpose finds cases where the reshape-transpose can be reordered. This allows the reshape-transpose-reshape case to be transformed to transpose-reshape-reshape, which can then merge the reshape-reshape. Finally, the transpose-reshape merge handles cases where, when the reshape maintains the same dimensions and does not affect memory ordering, the transpose and reshape can be merged into a single transpose operator. PiperOrigin-RevId: 191462038
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc190
1 files changed, 190 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
new file mode 100644
index 0000000000..5065004093
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/merge_reshape_into_preceding_transpose.cc
@@ -0,0 +1,190 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/graph_transformations/remove_trivial_passthrough.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/runtime/types.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+namespace {
+
+bool OperatorReady(const Model& model, const Operator* op) {
+ if (!model.HasArray(op->inputs[0]) || !model.HasArray(op->inputs[1]) ||
+ !model.HasArray(op->outputs[0])) {
+ // Arrays are missing.
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[0]).has_shape() ||
+ !model.GetArray(op->outputs[0]).has_shape()) {
+ // Input and output needs the shape.
+ return false;
+ }
+
+ if (!model.GetArray(op->inputs[1]).buffer) {
+ // Buffer needs to be a constant.
+ return false;
+ }
+
+ return true;
+}
+
+// Returns whether the reshape could be a transpose.
+std::vector<int32> ReshapeToTranspose(const Model& model,
+ const TensorFlowReshapeOperator* op) {
+ CHECK(!op->shape.empty());
+ CHECK(model.HasArray(op->inputs[0]));
+ CHECK(model.HasArray(op->outputs[0]));
+
+ const auto& input_array = model.GetArray(op->inputs[0]);
+ const auto& output_array = model.GetArray(op->outputs[0]);
+
+ CHECK(input_array.has_shape());
+ CHECK(output_array.has_shape());
+
+ std::vector<int> in_shape = input_array.shape().dims();
+ std::vector<int> out_shape = output_array.shape().dims();
+
+ std::vector<int> one_indices;
+ std::vector<int> not_one_indices;
+
+ // Separate into one indices and not one indices.
+ for (int i = 0; i < in_shape.size(); i++) {
+ if (in_shape[i] == 1) {
+ one_indices.push_back(i);
+ } else {
+ not_one_indices.push_back(i);
+ }
+ }
+
+ // Reorder the vertices.
+ std::vector<int> perm;
+ perm.reserve(in_shape.size());
+ int one_index = 0;
+ int not_one_index = 0;
+ for (const auto val : out_shape) {
+ if (val == 1) {
+ perm.push_back(one_indices[one_index]);
+ one_index++;
+ } else {
+ perm.push_back(not_one_indices[not_one_index]);
+ not_one_index++;
+ }
+ }
+
+ return perm;
+}
+
+} // namespace
+
+// When a transpose is fed into a reshape, it is possible for the two operators
+// to be merged if the reshape does not affect memory ordering and does not
+// affects the number of dimensions. This only occurs when only unary dimensions
+// are shifting position.
+bool MergeReshapeIntoPrecedingTranspose::Run(Model* model,
+ std::size_t op_index) {
+ auto it = model->operators.begin() + op_index;
+ auto* reshape_op = ConvertOperator<TensorFlowReshapeOperator*>(
+ it->get(), OperatorType::kTensorFlowReshape);
+
+ if (reshape_op == nullptr) {
+ return false;
+ }
+
+ if (!OperatorReady(*model, reshape_op) || reshape_op->shape.empty()) {
+ return false;
+ }
+
+ const string intermediate_name = reshape_op->inputs[0];
+ const string output_name = reshape_op->outputs[0];
+
+ // Guarantee the input is only consume by the reshape.
+ if (CountOpsWithInput(*model, intermediate_name) != 1) {
+ return false;
+ }
+
+ // Check for the parent operator.
+ const auto& transpose_it = FindOpWithOutput(*model, intermediate_name);
+ if (transpose_it == model->operators.end()) {
+ return false;
+ }
+
+ // Find the parent operator and guarantee it is a transpose.
+ TransposeOperator* transpose_op = ConvertOperator<TransposeOperator*>(
+ transpose_it->get(), OperatorType::kTranspose);
+
+ if (transpose_op == nullptr) {
+ return false;
+ }
+
+ if (!OperatorReady(*model, transpose_op) || transpose_op->perm.empty()) {
+ return false;
+ }
+
+ if (!ReshapeIsEquivalentToTranspose(*model, reshape_op,
+ false /*allow_extra_unary_dimensions*/)) {
+ return false;
+ }
+
+ // Check that the intermediate is not an output array.
+ if (!IsDiscardableArray(*model, intermediate_name)) {
+ AddMessageF(
+ "Cannot fuse %s and %s as it would invalidate the transpose "
+ "output array.",
+ LogName(*transpose_op), LogName(*reshape_op));
+ return false;
+ }
+
+ AddMessageF("Merging operations %s and %s", LogName(*transpose_op),
+ LogName(*reshape_op));
+
+ // const auto& intermediate_array = model->GetArray(intermediate_name);
+ // const auto& output_array = model->GetArray(output_name);
+
+ auto merged_perm = ReshapeToTranspose(*model, reshape_op);
+
+ // Combine the permutations.
+ const auto& transpose_perm = transpose_op->perm;
+ for (int i = 0; i < merged_perm.size(); i++) {
+ merged_perm[i] = transpose_perm[merged_perm[i]];
+ }
+
+ // Remove the reshape as passthrough operation.
+ if (!RemoveTrivialPassthroughOp(this, model, op_index)) {
+ return false;
+ }
+
+ // Update transpose_op's constant buffer to contain the new permutation.
+ model->GetArray(transpose_op->inputs[1])
+ .GetMutableBuffer<ArrayDataType::kInt32>()
+ .data = merged_perm;
+ transpose_op->perm = merged_perm;
+
+ // transpose_ops's shape will likely has changed.
+ model->GetArray(transpose_op->outputs[0]).clear_shape();
+
+ return true;
+}
+
+} // namespace toco