aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/shape_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/shape_ops_test.py')
-rw-r--r--tensorflow/python/kernel_tests/shape_ops_test.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/shape_ops_test.py b/tensorflow/python/kernel_tests/shape_ops_test.py
index 97d61d52af..52cf904528 100644
--- a/tensorflow/python/kernel_tests/shape_ops_test.py
+++ b/tensorflow/python/kernel_tests/shape_ops_test.py
@@ -504,16 +504,16 @@ class TileTest(test.TestCase):
with self.assertRaises(ValueError):
array_ops.tile(a, [[2, 3], [3, 4]]).eval()
- def _RunAndVerifyResult(self, use_gpu):
+ def _RunAndVerifyResult(self, rank, use_gpu):
with self.test_session(use_gpu=use_gpu):
- # Random dims of rank 5
- input_shape = np.random.randint(1, 4, size=5)
+ # Random dims of given rank
+ input_shape = np.random.randint(1, 4, size=rank)
inp = np.random.rand(*input_shape).astype("f")
a = constant_op.constant(
[float(x) for x in inp.ravel(order="C")],
shape=input_shape,
dtype=dtypes.float32)
- multiples = np.random.randint(1, 4, size=5).astype(np.int32)
+ multiples = np.random.randint(1, 4, size=rank).astype(np.int32)
tiled = array_ops.tile(a, multiples)
result = tiled.eval()
self.assertTrue((np.array(multiples) * np.array(inp.shape) == np.array(
@@ -522,10 +522,16 @@ class TileTest(test.TestCase):
self.assertShapeEqual(result, tiled)
def testRandom(self):
+ # test low rank, like 5
for _ in range(5):
- self._RunAndVerifyResult(use_gpu=False)
+ self._RunAndVerifyResult(5, use_gpu=False)
for _ in range(5):
- self._RunAndVerifyResult(use_gpu=True)
+ self._RunAndVerifyResult(5, use_gpu=True)
+ # test high rank, like 10
+ for _ in range(5):
+ self._RunAndVerifyResult(10, use_gpu=False)
+ for _ in range(5):
+ self._RunAndVerifyResult(10, use_gpu=True)
def testGradientSimpleReduction(self):
with self.test_session():