aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc17
1 files changed, 10 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
index 94820a0166..51d0629362 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc
@@ -56,13 +56,15 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op,
}
} // namespace
-bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyRelu1::Run(Model* model, std::size_t op_index,
+ bool* modified) {
+ *modified = false;
// Follow sequences of min+max and max+min. First get the leading op.
const auto op_it = model->operators.begin() + op_index;
const auto* op_0 = op_it->get();
if (op_0->type != OperatorType::kMinimum &&
op_0->type != OperatorType::kMaximum) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the paired op and ensure it's the counter to the first.
@@ -71,17 +73,17 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
(op_1->type != OperatorType::kMinimum &&
op_1->type != OperatorType::kMaximum) ||
op_0->type == op_1->type) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* min_op = op_0->type == OperatorType::kMinimum ? op_0 : op_1;
const auto* max_op = op_0->type == OperatorType::kMaximum ? op_0 : op_1;
if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the original input to the min+max pair.
@@ -90,7 +92,7 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
int max_scalar_input_index =
GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f);
if (min_scalar_input_index == -1 || max_scalar_input_index == -1) {
- return false;
+ return ::tensorflow::Status::OK();
}
int op_0_scalar_input_index =
op_0 == min_op ? min_scalar_input_index : max_scalar_input_index;
@@ -111,7 +113,8 @@ bool IdentifyRelu1::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, op_0));
model->operators.erase(FindOperator(model, op_1));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco