diff options
author | 2018-10-09 11:38:15 -0700 | |
---|---|---|
committer | 2018-10-09 11:48:46 -0700 | |
commit | 12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch) | |
tree | d2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc | |
parent | 931353c5f79c2d419afb3a5ecac59184c5558351 (diff) |
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc index a6f665b5f0..fccecef600 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc @@ -22,11 +22,14 @@ limitations under the License. namespace toco { // Resolves a constant reshape operation by copying the buffer. -bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantReshape::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kReshape) { - return false; + return ::tensorflow::Status::OK(); } const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op); @@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { // We require constant inputs. if (!IsConstantParameterArray(*model, op->inputs[0]) || !IsConstantParameterArray(*model, op->inputs[1])) { - return false; + return ::tensorflow::Status::OK(); } auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } const Array& input_array = model->GetArray(op->inputs[0]); @@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { AddMessageF("Constant reshape is non-trivial (%s -> %s)", ShapeToString(input_array.shape()), ShapeToString(output_array.shape())); - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); @@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { default: LOG(FATAL) << "Unsupported data type: " << ArrayDataTypeName(input_array.data_type); - return false; + return ::tensorflow::Status::OK(); } AddMessageF("Resolving constant reshape of %s", LogName(*op)); @@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) { // Erase the operator. model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |