diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/dynamic_stitch_op_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/dynamic_stitch_op_test.py | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py index c4d4ce780b..49b9569e2b 100644 --- a/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_stitch_op_test.py @@ -104,6 +104,27 @@ class DynamicStitchTestBase(object): # Dimension 0 is max(flatten(indices))+1. self.assertEqual([8, 2], stitched_t.get_shape().as_list()) + def testZeroSizeTensor(self): + with self.test_session(use_gpu=True): + indices = [ + constant_op.constant([0, 4, 7]), + constant_op.constant([1, 6]), + constant_op.constant([2, 3, 5]), + array_ops.zeros([0], dtype=dtypes.int32) + ] + data = [ + constant_op.constant([[0, 1], [40, 41], [70, 71]]), + constant_op.constant([[10, 11], [60, 61]]), + constant_op.constant([[20, 21], [30, 31], [50, 51]]), + array_ops.zeros([0, 2], dtype=dtypes.int32) + ] + stitched_t = self.stitch_op(indices, data) + stitched_val = stitched_t.eval() + self.assertAllEqual([[0, 1], [10, 11], [20, 21], [30, 31], [40, 41], + [50, 51], [60, 61], [70, 71]], stitched_val) + # Dimension 0 is max(flatten(indices))+1. + self.assertEqual([8, 2], stitched_t.get_shape().as_list()) + def testHigherRank(self): with self.test_session(use_gpu=True) as sess: indices = [ |