Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc')
 tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc | 178 ++++++++++++++++++
 1 file changed, 178 insertions(+)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
new file mode 100644
index 0000000000..7f44c65285
--- /dev/null
+++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -0,0 +1,178 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ ==============================================================================*/
+#include <algorithm>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+
+namespace toco {
+
+namespace {
+
+bool IsTailOfShape(const Shape& tail, const Shape& shape) {
+ // Return true if 'tail' dimensions are the same as the ending dimensions of
+ // 'shape'.
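+  //
+  // For example, [5, 10] is a tail of [1, 5, 10] and of [5, 10], but not of
+  // [10, 5] or of [5, 10, 1].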
+
+ int shape_end = shape.dimensions_count() - 1;
+ int tail_end = tail.dimensions_count() - 1;
+
+ if (tail_end > shape_end) {
+ // tail cannot be longer than shape.
+ return false;
+ }
+
+ // Walk dimensions back to front and compare
+ for (int i = 0; i <= tail_end; i++) {
+ if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
+} // namespace
+
+// If a binary operator is doing a broadcast operation from a constant array,
+// and the constant array's shape is a tail of both the other input's shape
+// and the preceding reshape op's input shape, we can swap their order. Since
+// we prefer to have reshape ops after mathematical ops, this can allow for
+// the collapsing of some reshapes. The WaveNet model in particular benefits
+// from this transformation.
+//
+// Note we are testing for one particular case of a broader set of possible
+// binary-reshape op transformations. This transformation could be generalized.
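+//
+// Concretely, for a constant array c whose shape is a tail of both shapes:
+//
+//   Before:  x -> Reshape -> y -> Binary(y, c) -> z
+//   After:   x -> Binary(x, c) -> y -> Reshape -> z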
+bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
+ const auto binary_it = model->operators.begin() + op_index;
+ Operator* binary_op = binary_it->get();
+ if (binary_op->type != OperatorType::kAdd &&
+ binary_op->type != OperatorType::kMul &&
+ binary_op->type != OperatorType::kSub &&
+ binary_op->type != OperatorType::kDiv &&
+ binary_op->type != OperatorType::kFloorDiv &&
+ binary_op->type != OperatorType::kFloorMod &&
+ binary_op->type != OperatorType::kMinimum &&
+ binary_op->type != OperatorType::kMaximum &&
+ binary_op->type != OperatorType::kLess &&
+ binary_op->type != OperatorType::kLessEqual &&
+ binary_op->type != OperatorType::kGreater &&
+ binary_op->type != OperatorType::kGreaterEqual) {
+ return false;
+ }
+
+ // BINARY OP INPUT CHECKS
+ CHECK_EQ(binary_op->inputs.size(), 2);
+ const bool input_is_const[2] = {
+ IsConstantParameterArray(*model, binary_op->inputs[0]),
+ IsConstantParameterArray(*model, binary_op->inputs[1]),
+ };
+  if (!input_is_const[0] && !input_is_const[1]) {
+    // To limit our scope, we require one constant input, though there's no
+    // reason this transformation wouldn't work with all-variable inputs.
+    return false;
+  }
+ if (input_is_const[0] && input_is_const[1]) {
+ // Both inputs are constants. Leave this for constants propagation.
+ return false;
+ }
+ const int constant_input_idx = input_is_const[0] ? 0 : 1;
+ const int variable_input_idx = input_is_const[0] ? 1 : 0;
+ CHECK(input_is_const[constant_input_idx]);
+ CHECK(!input_is_const[variable_input_idx]);
+
+ const auto& variable_input_array =
+ model->GetArray(binary_op->inputs[variable_input_idx]);
+ if (!variable_input_array.has_shape()) {
+ AddMessageF(
+ "Not moving %s because it's non-constant input shape is not resolved.",
+ LogName(*binary_op));
+ return false;
+ }
+ if (!IsTailOfShape(
+ model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
+ model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
+ // Constant array shape must be the latter part of the variable shape.
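+    // For example, a constant of shape [10] is a tail of a variable input of
+    // shape [1, 5, 10], but a constant of shape [5] is not.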
+ return false;
+ }
+
+ // RESHAPE OP CHECKS
+ auto reshape_it =
+ FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]);
+ if (reshape_it == model->operators.end()) {
+ AddMessageF("Not moving %s because it's variable input is not connected.",
+ LogName(*binary_op));
+ return false;
+ }
+ Operator* reshape_op = reshape_it->get();
+ if (reshape_op->type != OperatorType::kReshape) {
+ AddMessageF("Not moving %s because the preceding %s is not a reshape op",
+ LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+ const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
+ if (!reshape_input_array.has_shape()) {
+ AddMessageF(
+ "Not moving %s because it's non-constant input shape is not resolved "
+ "yet",
+ LogName(*binary_op));
+ return false;
+ }
+ if (!IsTailOfShape(
+ model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
+          model->GetArray(reshape_op->inputs[0]).shape())) {
+    // Constant array shape must also be the latter part of the reshape op's
+    // input shape, which becomes the binary op's output shape after the swap.
+ return false;
+ }
+
+ // EXTRA CHECKS ON CONNECTING ARRAY
+ for (const string& output_array : model->flags.output_arrays()) {
+ if (binary_op->inputs[variable_input_idx] == output_array) {
+      AddMessageF(
+          "Not moving %s because the output of reshape op %s is a model "
+          "output array.",
+          LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+ }
+ int count_ops_consuming_output =
+ CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]);
+ DCHECK_GE(count_ops_consuming_output, 1);
+ if (count_ops_consuming_output > 1) {
+ AddMessageF(
+ "Not moving %s because the output of reshape op %s is consumed by "
+ "another op",
+ LogName(*binary_op), LogName(*reshape_op));
+ return false;
+ }
+
+ // SWAP ORDER OF BINARY AND RESHAPE OPS
+ AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op),
+ LogName(*reshape_op));
+
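+  // Rewire the arrays: if the graph was x -> Reshape -> y -> Binary -> z,
+  // swapping the variable input with the reshape input, and the two outputs
+  // with each other, yields x -> Binary -> y -> Reshape -> z.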
+  // Swap op inputs and outputs
+ std::iter_swap(reshape_op->inputs.begin(),
+ binary_op->inputs.begin() + variable_input_idx);
+ std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin());
+
+ // Swap operator ordering
+ std::iter_swap(binary_it, reshape_it);
+
+ // Clear binary output shape so it will be re-propagated
+ model->GetArray(binary_op->outputs[0]).clear_shape();
+
+ return true;
+}
+
+} // namespace toco