diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc index cf17c49b10..9c1ed2b732 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc @@ -26,20 +26,21 @@ limitations under the License. namespace toco { -bool PropagateActivationFunctionIntoConstants::Run(Model* model, - std::size_t op_index) { +::tensorflow::Status PropagateActivationFunctionIntoConstants::Run( + Model* model, std::size_t op_index, bool* modified) { + *modified = false; const auto ac_it = model->operators.begin() + op_index; const auto* ac_op = ac_it->get(); if (ac_op->type != OperatorType::kRelu6 && ac_op->type != OperatorType::kRelu1 && ac_op->type != OperatorType::kRelu) { - return false; + return ::tensorflow::Status::OK(); } // Find the op producing the array passed to this activation function. auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]); if (!src_op) { - return false; + return ::tensorflow::Status::OK(); } // Ensure the src_op is not used without the activation function applied. @@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, src_op_input = src_op->inputs[0]; break; default: - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]); @@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, "Not propagating activation function %s into %s:%s because it is not " "constant", LogName(*ac_op), LogName(*src_op), src_op_input); - return false; + return ::tensorflow::Status::OK(); } // Get the array we'll be working with and ensure it's a compatible type. @@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, "Not propagating activation function %s into %s:%s because it is " "non-float data", LogName(*ac_op), LogName(*src_op), src_op_input); - return false; + return ::tensorflow::Status::OK(); } auto& const_array_data = const_array.GetMutableBuffer<ArrayDataType::kFloat>().data; @@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model, } default: LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op); - return false; + return ::tensorflow::Status::OK(); } const_array_data[i] = new_value; } AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op), LogName(*src_op), src_op_input); - return RemoveTrivialPassthroughOp(this, model, op_index); + *modified = RemoveTrivialPassthroughOp(this, model, op_index); + return ::tensorflow::Status::OK(); } } // namespace toco |