diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc index c5ce3fcd95..88511a7d3c 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc @@ -25,27 +25,30 @@ limitations under the License. namespace toco { -bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { +::tensorflow::Status FuseActivationFunctions::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); if (ac_op->type != OperatorType::kRelu6 && ac_op->type != OperatorType::kRelu1 && ac_op->type != OperatorType::kRelu) { - return false; + return ::tensorflow::Status::OK(); } // Find the op producing the array passed to this activation function Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]); - if (!op) return false; + if (!op) return ::tensorflow::Status::OK(); if (CountTrueOutputs(*model, *op) > 1) { AddMessageF( "Not fusing activation function %s into %s because it has more than " "one consumed output", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(op->outputs[0], ac_op->inputs[0]); @@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function into %s because it is consumed by more " "than 1 other operator", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } if (!IsDiscardableArray(*model, op->outputs[0])) { @@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s into %s because output %s it is not " "discardable", LogName(*ac_op), LogName(*op), op->outputs[0]); - return false; + return ::tensorflow::Status::OK(); } if (op->fused_activation_function != FusedActivationFunctionType::kNone) { @@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s into %s because it already has a " "fused activation function", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } if (!OperatorSupportsFusedActivation(op->type)) { @@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { "Not fusing activation function %s because the %s op doesn't support " "it", LogName(*ac_op), LogName(*op)); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Fusing activation function %s into the preceding %s", @@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) { model->EraseArray(ac_op->inputs[0]); op->outputs[0] = ac_op->outputs[0]; model->operators.erase(ac_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |