aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/shape_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/shape_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py18
1 files changed, 10 insertions, 8 deletions
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 52cf904528..a9fc699b21 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -411,14 +411,16 @@ class TileTest(test.TestCase):
self.assertEqual(7, result)
def testSimple(self):
- with self.test_session():
- inp = np.random.rand(4, 1).astype(np.float32)
- a = constant_op.constant(inp)
- tiled = array_ops.tile(a, [1, 4])
- result = tiled.eval()
- self.assertEqual(result.shape, (4, 4))
- self.assertEqual([4, 4], tiled.get_shape())
- self.assertTrue((result == np.tile(inp, (1, 4))).all())
+ # multiples could be int32 or int64
+ for dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(use_gpu=True):
+ inp = np.random.rand(4, 1).astype(np.float32)
+ a = constant_op.constant(inp)
+ tiled = array_ops.tile(a, constant_op.constant([1, 4], dtype=dtype))
+ result = tiled.eval()
+ self.assertEqual(result.shape, (4, 4))
+ self.assertEqual([4, 4], tiled.get_shape())
+ self.assertTrue((result == np.tile(inp, (1, 4))).all())
def testIdentityTileAndGrad(self):
with self.test_session():