diff options
author | RJ Ryan <rjryan@google.com> | 2018-06-27 17:30:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-27 17:33:28 -0700 |
commit | 07f61ee48784c8765006ea6ce6abd467cfe47a9e (patch) | |
tree | 1f5bb3694aca5bad9de2b6bc45e86bd01588285a /tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | |
parent | 0e9d608fe3189ebccff1995c2b1f8a86009bbceb (diff) |
Ignore stop indices when shrink_axis_mask is set in tf.lite StridedSlice implementation.
Due to an issue with negative StridedSlice indices in TensorFlow, the end indices can specify degenerate slices when negative indices are used to shrink an axis (e.g. for tf.range(4)[-1], start is -1, end is 0, and stride is 1). This fix works around the issue by ignoring stop indices entirely when an axis is shrinking, since in order to be shrunk the length is by definition 1.
Fixes Issue #19260.
PiperOrigin-RevId: 202398678
Diffstat (limited to 'tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index cee14b257f..82b3ab96fe 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1291,8 +1291,8 @@ void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) { op->begin_mask, op->start_indices, op->strides, input_array.shape().dims().data(), axis); int stop_index = tflite::strided_slice::StopForAxis( - op->end_mask, op->stop_indices, op->strides, - input_array.shape().dims().data(), axis); + op->end_mask, op->shrink_axis_mask, op->stop_indices, op->strides, + input_array.shape().dims().data(), axis, start_index); int dim_size = ceil(static_cast<float>(stop_index - start_index) / op->strides[axis]); |