aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py')
-rw-r--r--tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
index d8cf549cf7..08584dcd65 100644
--- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
+++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
@@ -21,7 +21,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -111,6 +113,23 @@ class TensorPoolTest(test.TestCase):
self.assertEqual(len(outs), len(input_values))
self.assertEqual(outs[1] - outs[0], 1)
+ def test_pool_preserves_shape(self):
+ t = constant_op.constant(1)
+ input_values = [[t, t, t], (t, t), t]
+ output_values = tensor_pool(input_values, pool_size=5)
+ print('stuff: ', output_values)
+ # Overall shape.
+ self.assertIsInstance(output_values, list)
+ self.assertEqual(3, len(output_values))
+ # Shape of first element.
+ self.assertIsInstance(output_values[0], list)
+ self.assertEqual(3, len(output_values[0]))
+ # Shape of second element.
+ self.assertIsInstance(output_values[1], tuple)
+ self.assertEqual(2, len(output_values[1]))
+ # Shape of third element.
+ self.assertIsInstance(output_values[2], ops.Tensor)
+
if __name__ == '__main__':
test.main()