aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc12
1 files changed, 5 insertions, 7 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
index 419a0776a6..b78efd7fc3 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_l2_normalization.cc
@@ -44,10 +44,9 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
const auto* div_or_mul_op = div_it->get();
OperatorType expected_op_type_producing_div_or_mul_input;
if (div_or_mul_op->type == OperatorType::kDiv) {
- expected_op_type_producing_div_or_mul_input = OperatorType::kTensorFlowSqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kSqrt;
} else if (div_or_mul_op->type == OperatorType::kMul) {
- expected_op_type_producing_div_or_mul_input =
- OperatorType::kTensorFlowRsqrt;
+ expected_op_type_producing_div_or_mul_input = OperatorType::kRsqrt;
} else {
return false;
}
@@ -75,8 +74,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* add_op = nullptr;
Operator* op_producing_add_input = nullptr;
if (op_producing_sqrt_or_rsqrt_input->type == OperatorType::kAdd ||
- op_producing_sqrt_or_rsqrt_input->type ==
- OperatorType::kTensorFlowMaximum) {
+ op_producing_sqrt_or_rsqrt_input->type == OperatorType::kMaximum) {
add_op = op_producing_sqrt_or_rsqrt_input;
bool add_can_be_removed = false;
CHECK_EQ(op_producing_sqrt_or_rsqrt_input->inputs.size(), 2);
@@ -113,7 +111,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
Operator* sum_op =
add_op ? op_producing_add_input : op_producing_sqrt_or_rsqrt_input;
- if (sum_op->type != OperatorType::kTensorFlowSum) {
+ if (sum_op->type != OperatorType::kSum) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Sum op, got %s",
@@ -122,7 +120,7 @@ bool IdentifyL2Normalization::Run(Model* model, std::size_t op_index) {
}
Operator* square_op = GetOpWithOutput(*model, sum_op->inputs[0]);
- if (square_op->type != OperatorType::kTensorFlowSquare) {
+ if (square_op->type != OperatorType::kSquare) {
AddMessageF(
"Giving up trying to identify L2Normalization subgraph: "
"expected Square op, got %s",