diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc index b2b2ea151b..ac94f45321 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) { // Removes a multiplication by array of constant zeros by making the output // array an array of constant zeros and removing the input arrays if they are no // longer needed. -bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveMultiplyByZero::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto mul_it = model->operators.begin() + op_index; auto* mul_op = mul_it->get(); if (mul_op->type != OperatorType::kMul) { - return false; + return ::tensorflow::Status::OK(); } const auto& output_array_name = mul_op->outputs[0]; auto& output_array = model->GetArray(output_array_name); if (!IsDiscardableArray(*model, output_array_name)) { - return false; + return ::tensorflow::Status::OK(); } if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } // Yield if the output shape is not known yet. if (!output_array.has_shape()) { - return false; + return ::tensorflow::Status::OK(); } // This transformation only handles the case where one operand is all 0's and @@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { }; if (!is_input_constant[0] && !is_input_constant[1]) { // Neither input is constant, so nothing we can resolve here. - return false; + return ::tensorflow::Status::OK(); } if (is_input_constant[0] && is_input_constant[1]) { // Both inputs are constants. That's a job for constants propagation, not // for us to handle here. - return false; + return ::tensorflow::Status::OK(); } const int index_of_constant_input = is_input_constant[0] ? 0 : 1; const int index_of_variable_input = is_input_constant[0] ? 1 : 0; @@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer<ArrayDataType::kFloat>().data; if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros<ArrayDataType::kFloat>(&output_array); } break; @@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer<ArrayDataType::kUint8>().data; if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros<ArrayDataType::kUint8>(&output_array); } break; @@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer<ArrayDataType::kInt32>().data; if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros<ArrayDataType::kInt32>(&output_array); } break; @@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { constant_input_array.GetBuffer<ArrayDataType::kInt64>().data; if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>( constant_input_data)) { - return false; + return ::tensorflow::Status::OK(); } FillArrayWithZeros<ArrayDataType::kInt64>(&output_array); } break; default: AddMessageF( "Cannot resolve multiply by 0 because of unsupported data type\n"); - return false; + return ::tensorflow::Status::OK(); } // Erase input arrays to the multiply if no longer used @@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { // Erase the multiply operator. model->operators.erase(mul_it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |