diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-10-09 11:38:15 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 11:48:46 -0700 |
commit | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch) | |
tree | d2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc | |
parent | 931353c5f79c2d419afb3a5ecac59184c5558351 (diff) |
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc index b78efd7fc3..78f60f52fb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc @@ -39,7 +39,10 @@ std::vector<std::unique_ptr<Operator>>::iterator FindOperator( } } // namespace -bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { +::tensorflow::Status IdentifyL2Normalization::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto div_it = model->operators.begin() + op_index; const auto* div_or_mul_op = div_it->get(); OperatorType expected_op_type_producing_div_or_mul_input; @@ -48,7 +51,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { } else if (div_or_mul_op->type == OperatorType::kMul) { expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt; } else { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(div_or_mul_op->inputs.size(), 2); Operator* op_producing_div_or_mul_input[2] = { @@ -58,14 +61,14 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { if (!op_producing_div_or_mul_input[1] || op_producing_div_or_mul_input[1]->type != expected_op_type_producing_div_or_mul_input) { - return false; + return ::tensorflow::Status::OK(); } Operator* sqrt_or_rsqrt_op = op_producing_div_or_mul_input[1]; CHECK_EQ(sqrt_or_rsqrt_op->inputs.size(), 1); Operator* op_producing_sqrt_or_rsqrt_input = GetOpWithOutput(*model, sqrt_or_rsqrt_op->inputs[0]); if (!op_producing_sqrt_or_rsqrt_input) { - return false; + return ::tensorflow::Status::OK(); } // There may be an Add or a Maximum here, adding or clamping to a "small" @@ -105,7 +108,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { " because the operator producing the input to the square root, %s," ", does not match the expected pattern", LogName(*op_producing_sqrt_or_rsqrt_input)); - return false; + return ::tensorflow::Status::OK(); } } @@ -116,7 +119,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: " "expected Sum op, got %s", LogName(*sum_op)); - return false; + return ::tensorflow::Status::OK(); } Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]); @@ -125,7 +128,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: " "expected Square op, got %s", LogName(*square_op)); - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(square_op->inputs.size(), 1); @@ -135,7 +138,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { "Giving up trying to identify L2Normalization subgraph: %s does not " "take the same input as the Mul/Div node", LogName(*square_op)); - return false; + return ::tensorflow::Status::OK(); } // Create and emplace the new L2Normalization @@ -162,7 +165,8 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) { model->operators.erase(FindOperator(model, sqrt_or_rsqrt_op)); model->EraseArray(div_or_mul_op->inputs[1]); model->operators.erase(FindOperator(model, div_or_mul_op)); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |