aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 8853ed87e6..99c5a64662 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -103,11 +103,14 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
} // anonymous namespace
-bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveConstantStridedSlice::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto it = model->operators.begin() + op_index;
const auto* base_op = it->get();
if (base_op->type != OperatorType::kStridedSlice) {
- return false;
+ return ::tensorflow::Status::OK();
}
const StridedSliceOperator* op =
@@ -117,28 +120,28 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
auto& output_array = model->GetArray(op->outputs[0]);
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
- return false;
+ return ::tensorflow::Status::OK();
}
if (!output_array.has_shape()) {
// Yield until the output shape has been set by PropagateFixedShapes
- return false;
+ return ::tensorflow::Status::OK();
}
if (op->start_indices.empty() || op->stop_indices.empty() ||
op->strides.empty()) {
// Attributes have not resolved yet.
- return false;
+ return ::tensorflow::Status::OK();
}
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until the value shape has been resolved.
- return false;
+ return ::tensorflow::Status::OK();
}
if (!IsConstantParameterArray(*model, op->inputs[0])) {
// Yield until the value is constant.
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK(!output_array.buffer);
@@ -164,7 +167,8 @@ bool ResolveConstantStridedSlice::Run(Model* model, std::size_t op_index) {
DeleteOpAndArraysIfUnused(model, it->get());
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco