aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc28
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
index 5364eebbc9..3034c1b1eb 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc
@@ -112,7 +112,10 @@ bool CopyMinMaxFromFirstInput(const Operator& op, Model* model) {
return true;
}
-bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantUnaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto unary_it = model->operators.begin() + op_index;
const auto* unary_op = unary_it->get();
// Test for unary ops of types that we know how to resolve.
@@ -133,28 +136,28 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
case OperatorType::kRelu:
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
// Check if the input is a constant parameter.
if (!IsConstantParameterArray(*model, unary_op->inputs[0])) {
- return false;
+ return ::tensorflow::Status::OK();
}
// if the unary op involves a tensor required by a rnn state, ignore it
for (const auto& rnn_state : model->flags.rnn_states()) {
if (unary_op->inputs[0] == rnn_state.back_edge_source_array()) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (unary_op->inputs[0] == rnn_state.state_array()) {
- return false;
+ return ::tensorflow::Status::OK();
}
}
auto& output_array = model->GetArray(unary_op->outputs[0]);
if (!output_array.has_shape()) {
// Yield until the output array dims have been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@@ -166,7 +169,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s "
" because it has a fused activation function",
LogName(*unary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// The min-max is only copied for ops that copy data without arithmetic.
@@ -187,7 +190,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
"Not resolving constant %s because we currently only support casting "
"to float",
LogName(*unary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
if (cast_op->src_data_type != input_array.buffer->type) {
AddMessageF(
@@ -197,7 +200,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
}
} else {
if (input_array.buffer->type != ArrayDataType::kFloat) {
- return false;
+ return ::tensorflow::Status::OK();
}
input_float_data = &(input_array.GetBuffer<ArrayDataType::kFloat>().data);
}
@@ -239,7 +242,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
CHECK_EQ(unary_op->inputs.size(), 2) << "Sum needs 2 inputs";
if (!IsConstantParameterArray(*model, unary_op->inputs[1])) {
AddMessageF("Axis input is non-constant");
- return false;
+ return ::tensorflow::Status::OK();
}
auto& axis_array = model->GetArray(unary_op->inputs[1]);
CHECK(axis_array.data_type == ArrayDataType::kInt32);
@@ -336,7 +339,7 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported activation function "
<< LogName(*unary_op);
- return false;
+ return ::tensorflow::Status::OK();
}
output_float_data[i] = new_value;
}
@@ -351,7 +354,8 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*unary_op));
model->operators.erase(unary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco