aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc20
1 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
index cf17c49b10..9c1ed2b732 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
@@ -26,20 +26,21 @@ limitations under the License.
namespace toco {
-bool PropagateActivationFunctionIntoConstants::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status PropagateActivationFunctionIntoConstants::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function.
auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!src_op) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Ensure the src_op is not used without the activation function applied.
@@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
src_op_input = src_op->inputs[0];
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]);
@@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is not "
"constant",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the array we'll be working with and ensure it's a compatible type.
@@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is "
"non-float data",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
auto& const_array_data =
const_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
@@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
}
default:
LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op);
- return false;
+ return ::tensorflow::Status::OK();
}
const_array_data[i] = new_value;
}
AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op),
LogName(*src_op), src_op_input);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco