diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 8853ed87e6..99c5a64662 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, } // anonymous namespace -bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { +::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model, + std::size_t op_index, + bool* modified) { + *modified = false; const auto it = model->operators.begin() + op_index; const auto* base_op = it->get(); if (base_op->type != OperatorType::kStridedSlice) { - return false; + return ::tensorflow::Status::OK(); } const StridedSliceOperator* op = @@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { auto& output_array = model->GetArray(op->outputs[0]); if (output_array.data_type == ArrayDataType::kNone) { // Yield until the output type has been set by PropagateArrayDataTypes - return false; + return ::tensorflow::Status::OK(); } if (!output_array.has_shape()) { // Yield until the output shape has been set by PropagateFixedShapes - return false; + return ::tensorflow::Status::OK(); } if (op->start_indices.empty() || op->stop_indices.empty() || op->strides.empty()) { // Attributes have not resolved yet. - return false; + return ::tensorflow::Status::OK(); } const auto& input_array = model->GetArray(op->inputs[0]); if (!input_array.has_shape()) { // Yield until the value shape has been resolved. - return false; + return ::tensorflow::Status::OK(); } if (!IsConstantParameterArray(*model, op->inputs[0])) { // Yield until the value is constant. - return false; + return ::tensorflow::Status::OK(); } CHECK(!output_array.buffer); @@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) { DeleteOpAndArraysIfUnused(model, it->get()); - return true; + *modified = true; + return ::tensorflow::Status::OK(); } } // namespace toco |