aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc20
1 files changed, 11 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
index cf17c49b10..9c1ed2b732 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_activation_function_into_constants.cc
@@ -26,20 +26,21 @@ limitations under the License.
namespace toco {
-bool PropagateActivationFunctionIntoConstants::Run(Model* model,
- std::size_t op_index) {
+::tensorflow::Status PropagateActivationFunctionIntoConstants::Run(
+ Model* model, std::size_t op_index, bool* modified) {
+ *modified = false;
const auto ac_it = model->operators.begin() + op_index;
const auto* ac_op = ac_it->get();
if (ac_op->type != OperatorType::kRelu6 &&
ac_op->type != OperatorType::kRelu1 &&
ac_op->type != OperatorType::kRelu) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Find the op producing the array passed to this activation function.
auto* src_op = GetOpWithOutput(*model, ac_op->inputs[0]);
if (!src_op) {
- return false;
+ return ::tensorflow::Status::OK();
}
// Ensure the src_op is not used without the activation function applied.
@@ -57,7 +58,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
src_op_input = src_op->inputs[0];
break;
default:
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(src_op->outputs[0], ac_op->inputs[0]);
@@ -69,7 +70,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is not "
"constant",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
// Get the array we'll be working with and ensure it's a compatible type.
@@ -79,7 +80,7 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
"Not propagating activation function %s into %s:%s because it is "
"non-float data",
LogName(*ac_op), LogName(*src_op), src_op_input);
- return false;
+ return ::tensorflow::Status::OK();
}
auto& const_array_data =
const_array.GetMutableBuffer<ArrayDataType::kFloat>().data;
@@ -108,14 +109,15 @@ bool PropagateActivationFunctionIntoConstants::Run(Model* model,
}
default:
LOG(FATAL) << "Unsupported activation function " << LogName(*ac_op);
- return false;
+ return ::tensorflow::Status::OK();
}
const_array_data[i] = new_value;
}
AddMessageF("Propagated activation function %s into %s:%s", LogName(*ac_op),
LogName(*src_op), src_op_input);
- return RemoveTrivialPassthroughOp(this, model, op_index);
+ *modified = RemoveTrivialPassthroughOp(this, model, op_index);
+ return ::tensorflow::Status::OK();
}
} // namespace toco