diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc | 28 |
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc index b35c3e19c4..869dfae98e 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -86,11 +86,14 @@ bool Slice(SliceOperator const& op, Array const& input_array, } // namespace -bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantSlice::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kSlice) { - return false; + return ::tensorflow::Status::OK(); } const SliceOperator* op = static_cast<const SliceOperator*>(base_op); @@ -99,49 +102,49 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes. - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes. - return false; + return ::tensorflow::Status::OK(); } if (op->begin.empty() || op->size.empty()) { // Attributes have not resolved yet. - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); switch (output_array.data_type) { case ArrayDataType::kFloat: if (!Slice<ArrayDataType::kFloat>(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kUint8: if (!Slice<ArrayDataType::kUint8>(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt32: if (!Slice<ArrayDataType::kInt32>(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; case ArrayDataType::kInt64: if (!Slice<ArrayDataType::kInt64>(*op, input_array, &output_array)) { - return false; + return ::tensorflow::Status::OK(); } break; default: @@ -159,7 +162,8 @@ bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { // Erase the operator model->operators.erase(it); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |