aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-10-09 11:38:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:48:46 -0700
commit12e164d1e7c0b197f06d5d3c2ed26318b89b5e4c (patch)
treed2f0b6ba463baff8e3607575f41d3655762f3d14 /tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
parent931353c5f79c2d419afb3a5ecac59184c5558351 (diff)
Return ::tensorflow::Status in Toco Graph Transformations.
PiperOrigin-RevId: 216392908
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
index a6f665b5f0..fccecef600 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_reshape.cc
@@ -22,11 +22,14 @@ limitations under the License.
namespace toco {
// Resolves a constant reshape operation by copying the buffer.
-bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantReshape::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kReshape) {
- return false;
+ return ::tensorflow::Status::OK();
}
const auto* op = static_cast<const TensorFlowReshapeOperator*>(base_op);
@@ -36,17 +39,17 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// We require constant inputs.
if (!IsConstantParameterArray(*model, op->inputs[0]) ||
!IsConstantParameterArray(*model, op->inputs[1])) {
- return false;
+ return ::tensorflow::Status::OK();
}
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes.
- return false;
+ return ::tensorflow::Status::OK();
}
const Array& input_array = model->GetArray(op->inputs[0]);
@@ -54,7 +57,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
AddMessageF("Constant reshape is non-trivial (%s -> %s)",
ShapeToString(input_array.shape()),
ShapeToString(output_array.shape()));
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@@ -95,7 +98,7 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
default:
LOG(FATAL) << "Unsupported data type: "
<< ArrayDataTypeName(input_array.data_type);
- return false;
+ return ::tensorflow::Status::OK();
}
AddMessageF("Resolving constant reshape of %s", LogName(*op));
@@ -112,7 +115,8 @@ bool ResolveConstantReshape::Run(Model* model, std::size_t op_index) {
// Erase the operator.
model->operators.erase(it);
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco