diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc index f7e5aa6609..586f546a30 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc @@ -188,7 +188,10 @@ void EvaluateBinaryOperatorOnConstantInputs(Model* model, } } // namespace -bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantBinaryOperator::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto binary_it = model->operators.begin() + op_index; const auto* binary_op = binary_it->get(); // Test for binary ops of types that we know how to resolve @@ -204,7 +207,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { binary_op->type != OperatorType::kLessEqual && binary_op->type != OperatorType::kGreater && binary_op->type != OperatorType::kGreaterEqual) { - return false; + return ::tensorflow::Status::OK(); } CHECK_EQ(binary_op->inputs.size(), 2); @@ -212,13 +215,13 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { const auto& input1_array = model->GetArray(binary_op->inputs[1]); // Check if both inputs are constant parameters. if (!input0_array.buffer || !input1_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } auto& output_array = model->GetArray(binary_op->outputs[0]); // Yield until the output array dims have been resolved. if (!output_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // At the moment we don't want to care about fused activation functions. @@ -229,7 +232,7 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF( "Not resolving constant %s because it has a fused activation function", LogName(*binary_op)); - return false; + return ::tensorflow::Status::OK(); } // Check that input data types agree. @@ -253,7 +256,8 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) { AddMessageF("Resolved constant %s to the equivalent constant array", LogName(*binary_op)); model->operators.erase(binary_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |