aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
diff options
context:
space:
mode:
authorGravatar RJ Ryan <rjryan@google.com>2018-06-27 17:30:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-27 17:33:28 -0700
commit07f61ee48784c8765006ea6ce6abd467cfe47a9e (patch)
tree1f5bb3694aca5bad9de2b6bc45e86bd01588285a /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
parent0e9d608fe3189ebccff1995c2b1f8a86009bbceb (diff)
Ignore stop indices when shrink_axis_mask is set in tf.lite StridedSlice implementation.
Due to an issue with negative StridedSlice indices in TensorFlow, the end indices can specify degenerate slices when negative indices are used to shrink an axis (e.g. for tf.range(4)[-1], start is -1, end is 0, and stride is 1). This fix works around the issue by ignoring stop indices entirely when an axis is shrinking, since in order to be shrunk the length is by definition 1. Fixes Issue #19260. PiperOrigin-RevId: 202398678
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index cee14b257f..82b3ab96fe 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -1291,8 +1291,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
op->begin_mask, op->start_indices, op->strides,
input_array.shape().dims().data(), axis);
int stop_index = tflite::strided_slice::StopForAxis(
- op->end_mask, op->stop_indices, op->strides,
- input_array.shape().dims().data(), axis);
+ op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides,
+ input_array.shape().dims().data(), axis, start_index);
int dim_size =
ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]);