aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc32
1 files changed, 20 insertions, 12 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
index 65132d7d1e..f54f5b42a1 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_strided_slice_attributes.cc
@@ -37,40 +37,47 @@ int PadAttributeArray(Array* attribute_array, std::vector<int> pad_values,
return mask;
}
-bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
+::tensorflow::Status ResolveStridedSliceAttributes::Run(Model* model,
+ std::size_t op_index,
+ bool* modified) {
+ *modified = false;
const auto slice_it = model->operators.begin() + op_index;
auto* slice_op = slice_it->get();
- if (slice_op->type != OperatorType::kStridedSlice) return false;
+ if (slice_op->type != OperatorType::kStridedSlice)
+ return ::tensorflow::Status::OK();
auto* op = static_cast<StridedSliceOperator*>(slice_op);
if (!op->start_indices.empty()) {
// We have already resolved these attributes
- return false;
+ return ::tensorflow::Status::OK();
}
CHECK_EQ(op->inputs.size(), 4);
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// We require the dimensionality of the input to pad the indices
- return false;
+ return ::tensorflow::Status::OK();
}
auto& start_array = model->GetArray(op->inputs[1]);
- if (!start_array.has_shape()) return false;
+ if (!start_array.has_shape()) return ::tensorflow::Status::OK();
if (toco::RequiredBufferSizeForShape(start_array.shape()) > 4) {
// Only 1-4D arrays are supported for now.
- return false;
+ return ::tensorflow::Status::OK();
}
auto& stop_array = model->GetArray(op->inputs[2]);
- if (!stop_array.has_shape()) return false;
+ if (!stop_array.has_shape()) return ::tensorflow::Status::OK();
auto& stride_array = model->GetArray(op->inputs[3]);
- if (!stride_array.has_shape()) return false;
+ if (!stride_array.has_shape()) return ::tensorflow::Status::OK();
- if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[2])) return false;
- if (!IsConstantParameterArray(*model, op->inputs[3])) return false;
+ if (!IsConstantParameterArray(*model, op->inputs[1]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[2]))
+ return ::tensorflow::Status::OK();
+ if (!IsConstantParameterArray(*model, op->inputs[3]))
+ return ::tensorflow::Status::OK();
int num_input_axes = input_array.shape().dimensions_count();
int start_indices_size = start_array.shape().dims(0);
@@ -112,6 +119,7 @@ bool ResolveStridedSliceAttributes::Run(Model* model, std::size_t op_index) {
op->stop_indices = stop_array.GetBuffer<ArrayDataType::kInt32>().data;
op->strides = stride_array.GetBuffer<ArrayDataType::kInt32>().data;
- return true;
+ *modified = true;
+ return ::tensorflow::Status::OK();
}
} // namespace toco