aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/strided_slice.cc
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-06-27 17:30:52 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commit9bcbfc1636d98e99d821c1a5292a5fe70663ccdb (patch)
treed4066528dc34ab1e6b0cb3abcbee6ff6165429b0 /tensorflow/contrib/lite/kernels/strided_slice.cc
parente3eac29c3498cfc64bcef5022b67aaf04d8b23da (diff)
Ignore stop indices when shrink_axis_mask is set in tf.lite StridedSlice implementation.
Due to an issue with negative StridedSlice indices in TensorFlow, the end indices can specify degenerate slices when negative indices are used to shrink an axis (e.g. for tf.range(4)[-1], start is -1, end is 0, and stride is 1). This fix works around the issue by ignoring stop indices entirely when an axis is shrinking, since in order to be shrunk the length is by definition 1. Fixes Issue #19260. PiperOrigin-RevId: 202398678
Diffstat (limited to 'tensorflow/contrib/lite/kernels/strided_slice.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/strided_slice.cc27
1 files changed, 19 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
index 725dd8105a..bed2117f9a 100644
--- a/tensorflow/contrib/lite/kernels/strided_slice.cc
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -121,10 +121,19 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
int32_t begin = GetBeginValueAtIndex(op_context, idx);
int32_t end = GetEndValueAtIndex(op_context, idx);
+ // When shrinking an axis, the end position does not matter (and can be
+ // incorrect when negative indexing is used, see Issue #19260). Always use
+ // begin + 1 to generate a length 1 slice, since begin has
+ // already been adjusted for negative indices by GetBeginValueAtIndex.
+ const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
+ if (shrink_axis) {
+ end = begin + 1;
+ }
+
// This is valid for both positive and negative strides
int32_t dim_shape = ceil((end - begin) / static_cast<float>(stride));
dim_shape = dim_shape < 0 ? 0 : dim_shape;
- if (!(op_context->params->shrink_axis_mask & (1 << idx))) {
+ if (!shrink_axis) {
output_shape_vector.push_back(dim_shape);
}
}
@@ -204,13 +213,15 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
int begin_mask =
ReverseMaskBits(op_context.params->begin_mask, op_context.dims);
int end_mask = ReverseMaskBits(op_context.params->end_mask, op_context.dims);
-
-#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
- kernel_type::StridedSlice(GetTensorData<data_type>(op_context.input), \
- GetTensorDims(op_context.input), begin_mask, \
- end_mask, starts, stops, strides, \
- GetTensorData<data_type>(op_context.output), \
- GetTensorDims(op_context.output))
+ int shrink_axis_mask =
+ ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims);
+
+#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \
+ kernel_type::StridedSlice( \
+ GetTensorData<data_type>(op_context.input), \
+ GetTensorDims(op_context.input), begin_mask, end_mask, shrink_axis_mask, \
+ starts, stops, strides, GetTensorData<data_type>(op_context.output), \
+ GetTensorDims(op_context.output))
switch (op_context.input->type) {
case kTfLiteFloat32: