aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
index 6ee231465f..9d8bd4fc39 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_strided_slice.cc
@@ -38,6 +38,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
CHECK_EQ(op.new_axis_mask, 0);
int num_input_axes = op.start_indices.size();
+ CHECK_EQ(num_input_axes, op.start_indices.size());
CHECK_EQ(num_input_axes, op.stop_indices.size());
CHECK_EQ(num_input_axes, op.strides.size());
@@ -49,11 +50,16 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
// Initialize source coordinate
Shape const& input_shape = input_array.shape();
Buffer<Type> const& input_buffer = input_array.GetBuffer<Type>();
- std::vector<int> src_coord(op.start_indices.size());
+ std::vector<int> src_coord(num_input_axes);
+ std::vector<int> stop_for_axis(num_input_axes);
for (int axis = 0; axis < num_input_axes; axis++) {
- src_coord[axis] = tflite::strided_slice::StartForAxis(
+ int start = tflite::strided_slice::StartForAxis(
op.begin_mask, op.start_indices, op.strides, input_shape.dims().data(),
axis);
+ src_coord[axis] = start;
+ stop_for_axis[axis] = tflite::strided_slice::StopForAxis(
+ op.end_mask, op.shrink_axis_mask, op.stop_indices, op.strides,
+ input_shape.dims().data(), axis, start);
}
// In order to handle any number (N) of dimensions, we copy elements one by
@@ -76,9 +82,7 @@ void StridedSlice(StridedSliceOperator const& op, Array const& input_array,
}
// Check if we've overflowed.
- int stop = tflite::strided_slice::StopForAxis(
- op.end_mask, op.stop_indices, op.strides, input_shape.dims().data(),
- axis);
+ int stop = stop_for_axis[axis];
if (tflite::strided_slice::LoopCondition(src_coord[axis], stop, stride)) {
// Reset axis and set carry
src_coord[axis] = tflite::strided_slice::StartForAxis(