aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc22
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index b78efd7fc3..78f60f52fb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -39,7 +39,10 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator(
}
} // namespace
-bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status IdentifyL2Normalization::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto div_it = model->operators.begin() + op_index;
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
@@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
} else if (div_or_mul_op->type == OperatorType::kMul) {
expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(div_or_mul_op->inputs.size(), 2);
Operator* op_producing_div_or_mul_input[2] = {
@@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
if (!op_producing_div_or_mul_input[1] ||
op_producing_div_or_mul_input[1]->type !=
expected_op_type_producing_div_or_mul_input) {
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1];
CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1);
Operator* op_producing_sqrt_or_rsqrt_input =
GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]);
if (!op_producing_sqrt_or_rsqrt_input) {
- return false;
+ return ::tensorflow::Status::OK();
}
// There may be an Add or a Maximum here, adding or clamping to a "small"
@@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
" because the operator producing the input to the square root, %s,"
", does not match the expected pattern",
LogName(*op_producing_sqrt_or_rsqrt_input));
- return false;
+ return ::tensorflow::Status::OK();
}
}
@@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
LogName(*sum_op));
- return false;
+ return ::tensorflow::Status::OK();
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
@@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",
LogName(*square_op));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(square_op->inputs.size(), 1);
@@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
"Giving up trying to identify L2Normalization subgraph: %s does not "
"take the same input as the Mul/Div node",
LogName(*square_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Create and emplace the new L2Normalization
@@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op));
model->EraseArray(div_or_mul_op->inputs[1]);
model->operators.erase(FindOperator(model, div_or_mul_op));
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco