aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc19
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 9d8bd4fc39..8853ed87e6 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
std::vector<int> src_coord(num_input_axes);
std::vector<int> stop_for_axis(num_input_axes);
+ const auto strided_slice_params =
+ tflite::strided_slice::BuildStridedSliceParams(
+ op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices,
+ op.stop_indices, op.strides);
+
for (int axis = 0; axis < num_input_axes; axis++) {
- int start = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
- axis);
- src_coord[axis] = start;
+ int start_index = tflite::strided_slice::StartForAxis(
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
+ src_coord[axis] = start_index;
stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
- op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
- input_shape.dims().data(), axis, start);
+ strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
+ start_index);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(
- op.begin_mask, op.start_indices, op.strides,
- input_shape.dims().data(), axis);
+ strided_slice_params, ToRuntimeShape(input_shape), axis);
carry = true;
} else {
carry = false;