aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-17 15:02:49 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-17 15:05:49 -0700
commitdfeee0c45d4127ea0e20fb0a1071c7d59658c169 (patch)
treeb8fa7c0599068e95fa5f6a70b168e08391238b3e /tensorflow/core
parente2a0700bbf90f0738e43fbc29ec758adc364347c (diff)
Fix the bug that StridedSlice loses static dimension when dim==0.
PiperOrigin-RevId: 209212830
Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/ops/array_ops_test.cc18
-rw-r--r--tensorflow/core/util/strided_slice_op.cc2
2 files changed, 19 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index c15409a246..03dab390a7 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -1620,6 +1620,24 @@ TEST(ArrayOpsTest, Slice_ShapeFn) {
INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
+TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
+ ShapeInferenceTestOp op("StridedSlice");
+ TF_ASSERT_OK(NodeDefBuilder("test", "StridedSlice")
+ .Input("input", 0, DT_FLOAT)
+ .Input("begin", 1, DT_INT32)
+ .Input("end", 2, DT_INT32)
+ .Input("strides", 3, DT_INT32)
+ .Attr("shrink_axis_mask", 1)
+ .Finalize(&op.node_def));
+ op.input_tensors.resize(4);
+ Tensor strides = test::AsTensor<int32>({1});
+ op.input_tensors[3] = &strides;
+ // Slicing on the 0-th dimension.
+ INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
+ // Slicing on the 0-th dimension. This time some of the result dimension is 0.
+ INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
+}
+
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
ShapeInferenceTestOp op("StridedSliceGrad");
op.input_tensors.resize(5);
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index aca60b942d..ad8a44a518 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -326,7 +326,7 @@ Status ValidateStridedSliceOp(
// Even if we don't have values for begin or end, we do know that this
// dimension covers the whole interval. If we have shape information for
// this dimension, that tells us the interval length.
- if (dim_i > 0) {
+ if (dim_i >= 0) {
if (stride_i < 0) {
interval_length = -dim_i;
} else {