aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc30
1 files changed, 17 insertions, 13 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index b2b2ea151b..ac94f45321 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -51,27 +51,30 @@ void FillArrayWithZeros(Array* array) {
// Removes a multiplication by array of constant zeros by making the output
// array an array of constant zeros and removing the input arrays if they are no
// longer needed.
-bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveMultiplyByZero::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto mul_it = model->operators.begin() + op_index;
auto* mul_op = mul_it->get();
if (mul_op->type != OperatorType::kMul) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
if (!IsDiscardableArray(*model, output_array_name)) {
- return false;
+ return ::tensorflow::Status::OK();
}
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
// Yield if the output shape is not known yet.
if (!output_array.has_shape()) {
- return false;
+ return ::tensorflow::Status::OK();
}
// This transformation only handles the case where one operand is all 0's and
@@ -83,12 +86,12 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
};
if (!is_input_constant[0] && !is_input_constant[1]) {
// Neither input is constant, so nothing we can resolve here.
- return false;
+ return ::tensorflow::Status::OK();
}
if (is_input_constant[0] && is_input_constant[1]) {
// Both inputs are constants. That's a job for constants propagation, not
// for us to handle here.
- return false;
+ return ::tensorflow::Status::OK();
}
const int index_of_constant_input = is_input_constant[0] ? 0 : 1;
const int index_of_variable_input = is_input_constant[0] ? 1 : 0;
@@ -105,7 +108,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kFloat>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kFloat>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kFloat>(&output_array);
} break;
@@ -114,7 +117,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kUint8>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kUint8>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kUint8>(&output_array);
} break;
@@ -123,7 +126,7 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt32>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt32>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt32>(&output_array);
} break;
@@ -132,14 +135,14 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
constant_input_array.GetBuffer<ArrayDataType::kInt64>().data;
if (!AreAllBufferElementsZero<DataType<ArrayDataType::kInt64>>(
constant_input_data)) {
- return false;
+ return ::tensorflow::Status::OK();
}
FillArrayWithZeros<ArrayDataType::kInt64>(&output_array);
} break;
default:
AddMessageF(
"Cannot resolve multiply by 0 because of unsupported data type\n");
- return false;
+ return ::tensorflow::Status::OK();
}
// Erase input arrays to the multiply if no longer used
@@ -149,7 +152,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
// Erase the multiply operator.
model->operators.erase(mul_it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco