aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-23 12:19:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-23 12:25:40 -0700
commit049fc23966eeef02a0945ddb80ae5f40592b90c1 (patch)
tree63c958a2042888293ea8c42c0d332874d661f360 /tensorflow/contrib/gan
parentc26b95e707ed2304e2e50f51f6751c9b9cb87f1c (diff)
Extend random pool to work with arbitrarily nested tensor structures.
PiperOrigin-RevId: 205703156
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD3
-rw-r--r--tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py37
-rw-r--r--tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py19
3 files changed, 38 insertions, 21 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 781e4ae4d7..7e6cb72485 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -257,12 +257,15 @@ py_library(
py_test(
name = "random_tensor_pool_test",
srcs = ["python/features/python/random_tensor_pool_test.py"],
+ shard_count = 6,
srcs_version = "PY2AND3",
deps = [
":random_tensor_pool",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
index 9e4ec59e70..ca2d724b49 100644
--- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
+++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
@@ -36,16 +36,15 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.util import nest
__all__ = [
'tensor_pool',
]
-def _to_tuple(x):
- if isinstance(x, (list, tuple)):
- return tuple(x)
- return (x,)
+def _to_list(x):
+ return [x] if isinstance(x, ops.Tensor) else list(x)
def tensor_pool(input_values,
@@ -63,8 +62,8 @@ def tensor_pool(input_values,
`pool_size` = 0 or `pooling_probability` = 0.
Args:
- input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read
- values to be pooled.
+ input_values: An arbitrarily nested structure of `tf.Tensors`, from which to
+ read values to be pooled.
pool_size: An integer specifying the maximum size of the pool. Defaults to
50.
pooling_probability: A float `Tensor` specifying the probability of getting
@@ -72,9 +71,10 @@ def tensor_pool(input_values,
name: A string prefix for the name scope for all tensorflow ops.
Returns:
- A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx
- `input_values`) which is with given probability either the `input_values` or
- a randomly chosen sample that was previously inserted in the pool.
+ A nested structure of `Tensor` objects with the same structure as
+ `input_values`. With the given probability, the Tensor values are either the
+ same as in `input_values` or a randomly chosen sample that was previously
+ inserted in the pool.
Raises:
ValueError: If `pool_size` is negative.
@@ -86,11 +86,10 @@ def tensor_pool(input_values,
return input_values
original_input_values = input_values
- input_values = _to_tuple(input_values)
+ input_values = nest.flatten(input_values)
- with ops.name_scope(
- '{}_pool_queue'.format(name),
- values=input_values + (pooling_probability,)):
+ with ops.name_scope('{}_pool_queue'.format(name),
+ values=input_values + [pooling_probability]):
pool_queue = data_flow_ops.RandomShuffleQueue(
capacity=pool_size,
min_after_dequeue=0,
@@ -112,10 +111,10 @@ def tensor_pool(input_values,
def _get_input_value_pooled():
enqueue_op = pool_queue.enqueue(input_values)
with ops.control_dependencies([enqueue_op]):
- return tuple(array_ops.identity(v) for v in input_values)
+ return [array_ops.identity(v) for v in input_values]
def _get_random_pool_value_and_enqueue_input():
- dequeue_values = _to_tuple(pool_queue.dequeue())
+ dequeue_values = _to_list(pool_queue.dequeue())
with ops.control_dependencies(dequeue_values):
enqueue_op = pool_queue.enqueue(input_values)
with ops.control_dependencies([enqueue_op]):
@@ -124,7 +123,7 @@ def tensor_pool(input_values,
return control_flow_ops.cond(prob, lambda: dequeue_values,
lambda: input_values)
- output_values = _to_tuple(control_flow_ops.cond(
+ output_values = _to_list(control_flow_ops.cond(
pool_queue.size() < pool_size, _get_input_value_pooled,
_get_random_pool_value_and_enqueue_input))
@@ -132,8 +131,4 @@ def tensor_pool(input_values,
for input_value, output_value in zip(input_values, output_values):
output_value.set_shape(input_value.shape)
- if isinstance(original_input_values, list):
- return list(output_values)
- elif isinstance(original_input_values, tuple):
- return output_values
- return output_values[0]
+ return nest.pack_sequence_as(original_input_values, output_values)
diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
index d8cf549cf7..08584dcd65 100644
--- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
+++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py
@@ -21,7 +21,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -111,6 +113,23 @@ class TensorPoolTest(test.TestCase):
self.assertEqual(len(outs), len(input_values))
self.assertEqual(outs[1] - outs[0], 1)
+ def test_pool_preserves_shape(self):
+ t = constant_op.constant(1)
+ input_values = [[t, t, t], (t, t), t]
+ output_values = tensor_pool(input_values, pool_size=5)
+ print('stuff: ', output_values)
+ # Overall shape.
+ self.assertIsInstance(output_values, list)
+ self.assertEqual(3, len(output_values))
+ # Shape of first element.
+ self.assertIsInstance(output_values[0], list)
+ self.assertEqual(3, len(output_values[0]))
+ # Shape of second element.
+ self.assertIsInstance(output_values[1], tuple)
+ self.assertEqual(2, len(output_values[1]))
+ # Shape of third element.
+ self.assertIsInstance(output_values[2], ops.Tensor)
+
if __name__ == '__main__':
test.main()