aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/slice_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/slice_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/slice_op_test.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py
index f6997e9c61..f415d9e70d 100644
--- a/tensorflow/python/kernel_tests/slice_op_test.py
+++ b/tensorflow/python/kernel_tests/slice_op_test.py
@@ -217,6 +217,17 @@ class SliceTest(test.TestCase):
self.assertEqual(expected_val.shape, slice_t.get_shape())
self.assertEqual(expected_val.shape, slice2_t.get_shape())
+ def testPartialShapeInference(self):
+ z = array_ops.zeros((1, 2, 3))
+ self.assertAllEqual(z.get_shape().as_list(), [1, 2, 3])
+
+ m1 = array_ops.slice(z, [0, 0, 0], [-1, -1, -1])
+ self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3])
+
+ m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1])
+ self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None])
+
+
def _testGradientSlice(self, input_shape, slice_begin, slice_size):
with self.test_session(use_gpu=True):
num_inputs = np.prod(input_shape)