diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc index 8f2c1f8162..a79779f55d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_batch_normalization.cc @@ -25,10 +25,13 @@ limitations under the License. namespace toco { -bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveBatchNormalization::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto bn_it = model->operators.begin() + op_index; if (bn_it->get()->type != OperatorType::kBatchNormalization) { - return false; + return ::tensorflow::Status::OK(); } const auto* bn_op = static_cast<const BatchNormalizationOperator*>(bn_it->get()); @@ -53,7 +56,7 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { // so we need to exit early if these buffers don't exist (i.e. if the params // haven't yet been resolved as constants). if (!mean_array.buffer || !multiplier_array.buffer || !offset_array.buffer) { - return false; + return ::tensorflow::Status::OK(); } // Create the new Mul, Add operators @@ -142,7 +145,8 @@ bool ResolveBatchNormalization::Run(Model* model, std::size_t op_index) { DCHECK_EQ(bn_it->get(), bn_op); model->operators.erase(bn_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |