diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 5364eebbc9..3034c1b1eb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) { return true; } -bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto unary_it = model->operators.begin() + op_index; const auto* unary_op = unary_it->get(); // Test for unary ops of types that we know how to resolve. @@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { case OperatorType::kRelu: break; default: - return false; + return ::tensorflow::Status::OK(); } // Check if the input is a constant parameter. if (!IsConstantParameterArray(*model, unary_op->inputs[0])) { - return false; + return ::tensorflow::Status::OK(); } // if the unary op involves a tensor required by a rnn state, ignore it for (const auto& rnn_state : model->flags.rnn_states()) { if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) { - return false; + return ::tensorflow::Status::OK(); } if (unary_op->inputs[0] == rnn_state.state_array()) { - return false; + return ::tensorflow::Status::OK(); } } auto& output_array = model->GetArray(unary_op->outputs[0]); if (!output_array.has_shape()) { // Yield until the output array dims have been resolved. - return false; + return ::tensorflow::Status::OK(); } // At the moment we don't want to care about fused activation functions. @@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { "Not resolving constant %s " " because it has a fused activation function", LogName(*unary_op)); - return false; + return ::tensorflow::Status::OK(); } // The min-max is only copied for ops that copy data without arithmetic. @@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { "Not resolving constant %s because we currently only support casting " "to float", LogName(*unary_op)); - return false; + return ::tensorflow::Status::OK(); } if (cast_op->src_data_type != input_array.buffer->type) { AddMessageF( @@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { } } else { if (input_array.buffer->type != ArrayDataType::kFloat) { - return false; + return ::tensorflow::Status::OK(); } input_float_data = &(input_array.GetBuffer<ArrayDataType::kFloat>().data); } @@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs"; if (!IsConstantParameterArray(*model, unary_op->inputs[1])) { AddMessageF("Axis input is non-constant"); - return false; + return ::tensorflow::Status::OK(); } auto& axis_array = model->GetArray(unary_op->inputs[1]); CHECK(axis_array.data_type == ArrayDataType::kInt32); @@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { default: LOG(FATAL) << "Unsupported activation function " << LogName(*unary_op); - return false; + return ::tensorflow::Status::OK(); } output_float_data[i] = new_value; } @@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*unary_op)); model->operators.erase(unary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |