diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc index 9d8bd4fc39..8853ed87e6 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc @@ -52,14 +52,18 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>(); std::vector<int> src_coord(num_input_axes); std::vector<int> stop_for_axis(num_input_axes); + const auto strided_slice_params = + tflite::strided_slice::BuildStridedSliceParams( + op.begin_mask, op.end_mask, op.shrink_axis_mask, op.start_indices, + op.stop_indices, op.strides); + for (int axis = 0; axis < num_input_axes; axis++) { - int start = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(), - axis); - src_coord[axis] = start; + int start_index = tflite::strided_slice::StartForAxis( + strided_slice_params, ToRuntimeShape(input_array.shape()), axis); + src_coord[axis] = start_index; stop_for_axis[axis] = tflite::strided_slice::StopForAxis( - op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides, - input_shape.dims().data(), axis, start); + strided_slice_params, ToRuntimeShape(input_array.shape()), axis, + start_index); } // In order to handle any number (N) of dimensions, we copy elements one by @@ -86,8 +90,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array, if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) { // Reset axis and set carry src_coord[axis] = tflite::strided_slice::StartForAxis( - op.begin_mask, op.start_indices, op.strides, - input_shape.dims().data(), axis); + strided_slice_params, ToRuntimeShape(input_shape), axis); carry = true; } else { carry = false; |