diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/shape_ops_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/shape_ops_test.py | 18 |
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py index 52cf904528..a9fc699b21 100644 --- a/tensorflow/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/python/kernel_tests/shape_ops_test.py @@ -411,14 +411,16 @@ class TileTest(test.TestCase): self.assertEqual(7, result) def testSimple(self): - with self.test_session(): - inp = np.random.rand(4, 1).astype(np.float32) - a = constant_op.constant(inp) - tiled = array_ops.tile(a, [1, 4]) - result = tiled.eval() - self.assertEqual(result.shape, (4, 4)) - self.assertEqual([4, 4], tiled.get_shape()) - self.assertTrue((result == np.tile(inp, (1, 4))).all()) + # multiples could be int32 or int64 + for dtype in [dtypes.int32, dtypes.int64]: + with self.test_session(use_gpu=True): + inp = np.random.rand(4, 1).astype(np.float32) + a = constant_op.constant(inp) + tiled = array_ops.tile(a, constant_op.constant([1, 4], dtype=dtype)) + result = tiled.eval() + self.assertEqual(result.shape, (4, 4)) + self.assertEqual([4, 4], tiled.get_shape()) + self.assertTrue((result == np.tile(inp, (1, 4))).all()) def testIdentityTileAndGrad(self): with self.test_session(): |