diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc | 17 |
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index 94820a0166..51d0629362 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, } } // namespace -bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index, + bool* modified) { + *modified = false; // Follow sequences of min+max and max+min. First get the leading op. const auto op_it = model->operators.begin() + op_index; const auto* op_0 = op_it->get(); if (op_0->type != OperatorType::kMinimum && op_0->type != OperatorType::kMaximum) { - return false; + return ::tensorflow::Status::OK(); } // Get the paired op and ensure it's the counter to the first. @@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { (op_1->type != OperatorType::kMinimum && op_1->type != OperatorType::kMaximum) || op_0->type == op_1->type) { - return false; + return ::tensorflow::Status::OK(); } const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1; const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1; if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) { - return false; + return ::tensorflow::Status::OK(); } if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) { - return false; + return ::tensorflow::Status::OK(); } // Get the original input to the min+max pair. @@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { int max_scalar_input_index = GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f); if (min_scalar_input_index == -1 || max_scalar_input_index == -1) { - return false; + return ::tensorflow::Status::OK(); } int op_0_scalar_input_index = op_0 == min_op ? min_scalar_input_index : max_scalar_input_index; @@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, op_0)); model->operators.erase(FindOperator(model, op_1)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |