aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index f7e5aa6609..586f546a30 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model,
}
} // namespace
-bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto binary_it = model->operators.begin() + op_index;
const auto* binary_op = binary_it->get();
// Test for binary ops of types that we know how to resolve
@@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
binary_op->type != OperatorType::kLessEqual &&
binary_op->type != OperatorType::kGreater &&
binary_op->type != OperatorType::kGreaterEqual) {
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(binary_op->inputs.size(), 2);
@@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
const auto& input1_array = model->GetArray(binary_op->inputs[1]);
// Check if both inputs are constant parameters.
if (!input0_array.buffer || !input1_array.buffer) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(binary_op->outputs[0]);
// Yield until the output array dims have been resolved.
if (!output_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// At the moment we don't want to care about fused activation functions.
@@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF(
"Not resolving constant %s because it has a fused activation function",
LogName(*binary_op));
- return false;
+ return ::tensorflow::Status::OK();
}
// Check that input data types agree.
@@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
AddMessageF("Resolved constant %s to the equivalent constant array",
LogName(*binary_op));
model->operators.erase(binary_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco