diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc index b90a156a0d..c11fee4dc9 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_prelu.cc @@ -43,13 +43,15 @@ limitations under the License. namespace toco { -bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyPRelu::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; const auto add_op_it = model->operators.begin() + op_index; const auto* add_op = add_op_it->get(); if (add_op == nullptr || add_op->type != OperatorType::kAdd || add_op->inputs.size() != 2 || add_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } const auto* relu_input_op = GetOpWithOutput(*model, add_op->inputs[0]); @@ -57,7 +59,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { relu_input_op->inputs.size() != 1 || relu_input_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } // TODO(ycling): Both Add and Mul are commutative. Support the case where @@ -66,7 +68,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { if (mul_op == nullptr || mul_op->type != OperatorType::kMul || mul_op->inputs.size() != 2 || mul_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } const auto neg_alpha_tensor_name = mul_op->inputs[0]; @@ -75,7 +77,7 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { if (relu_neg_input_op == nullptr || relu_neg_input_op->inputs.size() != 1) { - return false; + return ::tensorflow::Status::OK(); } const Operator* final_input_op; @@ -92,13 +94,13 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { relu_neg_input_op->type != OperatorType::kRelu || relu_neg_input_op->fused_activation_function != FusedActivationFunctionType::kNone) { - return false; + return ::tensorflow::Status::OK(); } final_input_op = neg_input_op; } if (relu_input_op->inputs[0] != final_input_op->inputs[0]) { - return false; + return ::tensorflow::Status::OK(); } const auto input_tensor_name = relu_input_op->inputs[0]; @@ -128,7 +130,8 @@ bool IdentifyPRelu::Run(Model* model, std::size_t op_index) { // intermediate tensors aren't used by other ops, those will be removed by // other graph transformation rules. model->operators.erase(FindOp(*model, add_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |