path: root/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc')
 tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)
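
At a glance, the change migrates this transformation off the old boolean Run() contract: success or failure now travels through ::tensorflow::Status, and the "did I rewrite the graph" signal moves into a bool* out-parameter. A minimal sketch of the two signatures (declarations only; the GraphTransformation base class is assumed to have been updated in the same sweep):

```cpp
// Old contract: the return value meant "graph modified"; there was no
// channel for reporting errors.
bool Run(Model* model, std::size_t op_index);

// New contract: Status carries success/failure; *modified must be set to
// false on every early exit and to true only after an actual rewrite.
::tensorflow::Status Run(Model* model, std::size_t op_index, bool* modified);
```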
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
index c5ce3fcd95..88511a7d3c 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/fuse_activation_functions.cc
@@ -25,27 +25,30 @@ limitations under the License.
 
 namespace toco {
 
-bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status FuseActivationFunctions::Run(Model* model,
+                                                  std::size_t op_index,
+                                                  bool* modified) {
+  *modified = false;
   const auto ac_it = model->operators.begin() + op_index;
   const auto* ac_op = ac_it->get();
   if (ac_op->type != OperatorType::kRelu6 &&
       ac_op->type != OperatorType::kRelu1 &&
       ac_op->type != OperatorType::kRelu) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // Find the op producing the array passed to this activation function
   Operator* op = GetOpWithOutput(*model, ac_op->inputs[0]);
-  if (!op) return false;
+  if (!op) return ::tensorflow::Status::OK();
 
   if (CountTrueOutputs(*model, *op) > 1) {
     AddMessageF(
         "Not fusing activation function %s into %s because it has more than "
         "one consumed output",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   CHECK_EQ(op->outputs[0], ac_op->inputs[0]);
@@ -57,7 +60,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s into %s because it is consumed "
         "by more than 1 other operator",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (!IsDiscardableArray(*model, op->outputs[0])) {
@@ -65,7 +68,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s into %s because output %s is not "
         "discardable",
         LogName(*ac_op), LogName(*op), op->outputs[0]);
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (op->fused_activation_function != FusedActivationFunctionType::kNone) {
@@ -73,7 +76,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s into %s because it already has a "
         "fused activation function",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   if (!OperatorSupportsFusedActivation(op->type)) {
@@ -81,7 +84,7 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
         "Not fusing activation function %s because the %s op doesn't support "
         "it",
         LogName(*ac_op), LogName(*op));
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   AddMessageF("Fusing activation function %s into the preceding %s",
@@ -98,7 +101,8 @@ bool FuseActivationFunctions::Run(Model* model, std::size_t op_index) {
   model->EraseArray(ac_op->inputs[0]);
   op->outputs[0] = ac_op->outputs[0];
   model->operators.erase(ac_it);
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
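
For context on how the new out-parameter is meant to be consumed, here is a hypothetical driver sketch. It is not the actual toco transformation runner; ApplyAt and the trimmed-down GraphTransformation interface below are illustrative only, assuming the signature introduced by this patch.

```cpp
#include <cstddef>

#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace toco {

class Model;  // Defined in tensorflow/contrib/lite/toco/model.h.

// Trimmed-down interface mirroring the signature introduced by this patch.
class GraphTransformation {
 public:
  virtual ~GraphTransformation() {}
  virtual ::tensorflow::Status Run(Model* model, std::size_t op_index,
                                   bool* modified) = 0;
};

// Hypothetical helper: applies one transformation at one op, propagating
// errors immediately and accumulating the "graph changed" signal.
::tensorflow::Status ApplyAt(GraphTransformation* transformation, Model* model,
                             std::size_t op_index, bool* changed_anything) {
  bool modified = false;
  TF_RETURN_IF_ERROR(transformation->Run(model, op_index, &modified));
  if (modified) {
    *changed_anything = true;
  }
  return ::tensorflow::Status::OK();
}

}  // namespace toco
```

Carrying the flag in an out-parameter rather than folding it into the Status keeps the migration mechanical: every `return false;` becomes `return ::tensorflow::Status::OK();`, with `*modified` already false from the assignment at the top of Run().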