Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc')
-rw-r--r--  tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc  19
1 file changed, 11 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
index b90a156a0d..c11fee4dc9 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc
@@ -43,13 +43,15 @@ limitations under the License.
 
 namespace toco {
 
-bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index,
+                                        bool* modified) {
+  *modified = false;
   const auto add_op_it = model->operators.begin() + op_index;
   const auto* add_op = add_op_it->get();
   if (add_op == nullptr || add_op->type != OperatorType::kAdd ||
       add_op->inputs.size() != 2 ||
       add_op->fused_activation_function != FusedActivationFunctionType::kNone) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]);
@@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
       relu_input_op->inputs.size() != 1 ||
       relu_input_op->fused_activation_function !=
           FusedActivationFunctionType::kNone) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   // TODO(ycling): Both Add and Mul are commutative. Support the case where
@@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
   if (mul_op == nullptr || mul_op->type != OperatorType::kMul ||
       mul_op->inputs.size() != 2 ||
       mul_op->fused_activation_function != FusedActivationFunctionType::kNone) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto neg_alpha_tensor_name = mul_op->inputs[0];
@@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
 
   if (relu_neg_input_op == nullptr ||
       relu_neg_input_op->inputs.size() != 1) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const Operator* final_input_op;
@@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
         relu_neg_input_op->type != OperatorType::kRelu ||
         relu_neg_input_op->fused_activation_function !=
             FusedActivationFunctionType::kNone) {
-      return false;
+      return ::tensorflow::Status::OK();
     }
     final_input_op = neg_input_op;
   }
 
   if (relu_input_op->inputs[0] != final_input_op->inputs[0]) {
-    return false;
+    return ::tensorflow::Status::OK();
   }
 
   const auto input_tensor_name = relu_input_op->inputs[0];
@@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) {
   // intermediate tensors aren't used by other ops, those will be removed by
   // other graph transformation rules.
   model->operators.erase(FindOp(*model, add_op));
-  return true;
+  *modified = true;
+  return ::tensorflow::Status::OK();
 }
 
 }  // namespace toco
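
This diff changes IdentifyPRelu::Run from returning bool ("did this pass modify the graph?") to returning ::tensorflow::Status, with the modification flag moved into a bool* out-parameter. Every early-exit path that previously returned false (PRelu pattern not matched at this op) now returns Status::OK() with *modified left false, so a non-OK Status is reserved for genuine errors. Below is a minimal caller sketch under that contract; the helper name RunIdentifyPReluOnce and the surrounding driver shape are illustrative assumptions, not part of this diff.

#include <cstddef>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/core/lib/core/status.h"

namespace toco {

// Hypothetical helper: run IdentifyPRelu once at op_index, propagating
// errors and reporting whether the Relu/Mul/Neg/Add subgraph was fused.
::tensorflow::Status RunIdentifyPReluOnce(Model* model, std::size_t op_index,
                                          bool* changed) {
  IdentifyPRelu transformation;
  bool modified = false;
  // After this change, a non-OK Status means a real failure; *modified ==
  // false simply means the op at op_index is not the root Add of a PRelu
  // pattern.
  const ::tensorflow::Status status =
      transformation.Run(model, op_index, &modified);
  if (!status.ok()) return status;
  *changed = modified;
  return ::tensorflow::Status::OK();
}

}  // namespace toco

The design benefit is that the old bool return conflated "pattern not found" with "something went wrong"; splitting the two lets a transformation driver stop on errors while still iterating cheaply past non-matching ops.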