diff options
author | 2018-08-17 15:02:49 -0700 | |
---|---|---|
committer | 2018-08-17 15:05:49 -0700 | |
commit | dfeee0c45d4127ea0e20fb0a1071c7d59658c169 (patch) | |
tree | b8fa7c0599068e95fa5f6a70b168e08391238b3e /tensorflow/core | |
parent | e2a0700bbf90f0738e43fbc29ec758adc364347c (diff) |
Fix the bug that StridedSlice loses static dimension when dim==0.
PiperOrigin-RevId: 209212830
Diffstat (limited to 'tensorflow/core')
-rw-r--r-- | tensorflow/core/ops/array_ops_test.cc | 18 | ||||
-rw-r--r-- | tensorflow/core/util/strided_slice_op.cc | 2 |
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index c15409a246..03dab390a7 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -1620,6 +1620,24 @@ TEST(ArrayOpsTest, Slice_ShapeFn) { INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]"); } +TEST(ArrayOpsTest, StridedSlice_ShapeFn) { + ShapeInferenceTestOp op("StridedSlice"); + TF_ASSERT_OK(NodeDefBuilder("test", "StridedSlice") + .Input("input", 0, DT_FLOAT) + .Input("begin", 1, DT_INT32) + .Input("end", 2, DT_INT32) + .Input("strides", 3, DT_INT32) + .Attr("shrink_axis_mask", 1) + .Finalize(&op.node_def)); + op.input_tensors.resize(4); + Tensor strides = test::AsTensor<int32>({1}); + op.input_tensors[3] = &strides; + // Slicing on the 0-th dimension. + INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]"); + // Slicing on the 0-th dimension. This time some of the result dimension is 0. + INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]"); +} + TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) { ShapeInferenceTestOp op("StridedSliceGrad"); op.input_tensors.resize(5); diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc index aca60b942d..ad8a44a518 100644 --- a/tensorflow/core/util/strided_slice_op.cc +++ b/tensorflow/core/util/strided_slice_op.cc @@ -326,7 +326,7 @@ Status ValidateStridedSliceOp( // Even if we don't have values for begin or end, we do know that this // dimension covers the whole interval. If we have shape information for // this dimension, that tells us the interval length. - if (dim_i > 0) { + if (dim_i >= 0) { if (stride_i < 0) { interval_length = -dim_i; } else { |