diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc index 7f44c65285..f0d8d924ad 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/move_binary_operator_before_reshape.cc @@ -54,7 +54,10 @@ bool IsTailOfShape(const Shape& tail, const Shape& shape) { // // Note we are testing for one particular case of a broader set of possible // binary-reshape op transformations. This transformation could be generalized. -bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; Operator* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && @@ -69,7 +72,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kLessEqual && binary_op->type != OperatorType::kGreater && binary_op->type != OperatorType::kGreaterEqual) { - return false; + return ::tensorflow::Status::OK(); } // BINARY OP INPUT CHECKS @@ -81,11 +84,11 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { if (!input_is_const[0] && !input_is_const[1]) { // To limit our scope, we require one constant input. Though there's no // reason this transformation wouldn't work with all variable inputs. - return false; + return ::tensorflow::Status::OK(); } if (input_is_const[0] && input_is_const[1]) { // Both inputs are constants. Leave this for constants propagation. - return false; + return ::tensorflow::Status::OK(); } const int constant_input_idx = input_is_const[0] ? 0 : 1; const int variable_input_idx = input_is_const[0] ? 1 : 0; @@ -98,13 +101,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { AddMessageF( "Not moving %s because it's non-constant input shape is not resolved.", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsTailOfShape( model->GetArray(binary_op->inputs[constant_input_idx]).shape(), model->GetArray(binary_op->inputs[variable_input_idx]).shape())) { // Constant array shape must be the latter part of the variable shape. - return false; + return ::tensorflow::Status::OK(); } // RESHAPE OP CHECKS @@ -113,13 +116,13 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { if (reshape_it == model->operators.end()) { AddMessageF("Not moving %s because it's variable input is not connected.", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* reshape_op = reshape_it->get(); if (reshape_op->type != OperatorType::kReshape) { AddMessageF("Not moving %s because the preceding %s is not a reshape op", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]); if (!reshape_input_array.has_shape()) { @@ -127,14 +130,14 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { "Not moving %s because it's non-constant input shape is not resolved " "yet", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsTailOfShape( model->GetArray(binary_op->inputs[constant_input_idx]).shape(), model->GetArray(reshape_op->outputs[0]).shape())) { // Constant array shape must be the latter part of the binary op output // shape. - return false; + return ::tensorflow::Status::OK(); } // EXTRA CHECKS ON CONNECTING ARRAY @@ -143,7 +146,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { AddMessageF( "Not moving %s because the output of reshape op %s is an output op.", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } } int count_ops_consuming_output = @@ -154,7 +157,7 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { "Not moving %s because the output of reshape op %s is consumed by " "another op", LogName(*binary_op), LogName(*reshape_op)); - return false; + return ::tensorflow::Status::OK(); } // SWAP ORDER OF BINARY AND RESHAPE OPS @@ -172,7 +175,8 @@ bool MoveBinaryOperatorBeforeReshape::Run(Model* model, std::size_t op_index) { // Clear binary output shape so it will be re-propagated model->GetArray(binary_op->outputs[0]).clear_shape(); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |