diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc index dcbbead517..0de22b8ff4 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_binary_into_following_affine.cc @@ -150,14 +150,17 @@ void FuseMulOrDivParamsIntoFollowingAffine(Model* model, Operator* following_op, } // namespace -bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseBinaryIntoFollowingAffine::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; auto* binary_op = binary_it->get(); if (binary_op->type != OperatorType::kAdd && binary_op->type != OperatorType::kMul && binary_op->type != OperatorType::kSub && binary_op->type != OperatorType::kDiv) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -175,12 +178,12 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can fuse into a constant. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants // propagation, not for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -192,7 +195,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { if (index_of_constant_input != 1) { AddMessageF("Not fusing %s because the denominator is not constant", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -204,7 +207,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s into the following affine op, because we only know " "how to do so when the constant operand is a scalar", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } } @@ -212,7 +215,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { FusedActivationFunctionType::kNone) { AddMessageF("Not fusing %s because it has a fused activation function", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* following_op = GetOpWithInput(*model, binary_op->outputs[0]); @@ -221,7 +224,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because it is not consumed by exactly one other op", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } if (following_op->type != OperatorType::kConv && @@ -231,14 +234,14 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the following %s is not of one of the supported " "types", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } if (following_op->inputs.size() < 3) { AddMessageF( "Not fusing %s because the following %s does not have a bias vector", LogName(*following_op), LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } const auto& weights = model->GetArray(following_op->inputs[1]); @@ -248,7 +251,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { "Not fusing %s because the following %s has non-constant weights or " "bias arrays", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } // Try to fuse the binary params into the following op's params @@ -260,7 +263,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because the following %s does not use VALID padding", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } } if (following_op->type == OperatorType::kDepthwiseConv) { @@ -269,7 +272,7 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { AddMessageF( "Not fusing %s because the following %s does not use VALID padding", LogName(*binary_op), LogName(*following_op)); - return false; + return ::tensorflow::Status::OK(); } } FuseAddOrSubParamsIntoFollowingAffine(model, following_op, binary_op, @@ -294,7 +297,8 @@ bool FuseBinaryIntoFollowingAffine::Run(Model* model, std::size_t op_index) { model->EraseArray(old_constant_param_name); } model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |