aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc30
1 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
index 7f44c65285..f0d8d924ad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc
@@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) {
//
// Note we are testing for one particular case of a broader set of possible
// binary-reshape op transformations. This transformation could be generalized.
-bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
Operator* binary_op = binary_it->get();
if (binary_op->type != OperatorType::kAdd &&
@@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
- return false;
+ return ::tensorflow::Status::OK();
}
// BINARY OP INPUT CHECKS
@@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (!input_is_const[0] && !input_is_const[1]) {
// To limit our scope, we require one constant input. Though there's no
// reason this transformation wouldn't work with all variable inputs.
- return false;
+ return ::tensorflow::Status::OK();
}
if (input_is_const[0] && input_is_const[1]) {
// Both inputs are constants. Leave this for constants propagation.
- return false;
+ return ::tensorflow::Status::OK();
}
const int constant_input_idx = input_is_const[0] ? 0 : 1;
const int variable_input_idx = input_is_const[0] ? 1 : 0;
@@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because it's non-constant input shape is not resolved.",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
// Constant array shape must be the latter part of the variable shape.
- return false;
+ return ::tensorflow::Status::OK();
}
// RESHAPE OP CHECKS
@@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
if (reshape_it == model->operators.end()) {
AddMessageF("Not moving %s because it's variable input is not connected.",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* reshape_op = reshape_it->get();
if (reshape_op->type != OperatorType::kReshape) {
AddMessageF("Not moving %s because the preceding %s is not a reshape op",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
if (!reshape_input_array.has_shape()) {
@@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because it's non-constant input shape is not resolved "
"yet",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsTailOfShape(
model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
model->GetArray(reshape_op->outputs[0]).shape())) {
// Constant array shape must be the latter part of the binary op output
// shape.
- return false;
+ return ::tensorflow::Status::OK();
}
// EXTRA CHECKS ON CONNECTING ARRAY
@@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not moving %s because the output of reshape op %s is an output op.",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
}
int count_ops_consuming_output =
@@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
"Not moving %s because the output of reshape op %s is consumed by "
"another op",
LogName(*binary_op), LogName(*reshape_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// SWAP ORDER OF BINARY AND RESHAPE OPS
@@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) {
// Clear binary output shape so it will be re-propagated
model->GetArray(binary_op->outputs[0]).clear_shape();
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco