diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/slice_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/slice_op_test.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/slice_op_test.py b/tensorflow/python/kernel_tests/slice_op_test.py index f6997e9c61..f415d9e70d 100644 --- a/tensorflow/python/kernel_tests/slice_op_test.py +++ b/tensorflow/python/kernel_tests/slice_op_test.py @@ -217,6 +217,17 @@ class SliceTest(test.TestCase): self.assertEqual(expected_val.shape, slice_t.get_shape()) self.assertEqual(expected_val.shape, slice2_t.get_shape()) + def testPartialShapeInference(self): + z = array_ops.zeros((1, 2, 3)) + self.assertAllEqual(z.get_shape().as_list(), [1, 2, 3]) + + m1 = array_ops.slice(z, [0, 0, 0], [-1, -1, -1]) + self.assertAllEqual(m1.get_shape().as_list(), [1, 2, 3]) + + m2 = array_ops.slice(z, [0, 0, 0], [constant_op.constant(1) + 0, 2, -1]) + self.assertAllEqual(m2.get_shape().as_list(), [None, 2, None]) + + def _testGradientSlice(self, input_shape, slice_begin, slice_size): with self.test_session(use_gpu=True): num_inputs = np.prod(input_shape) |