aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py')
-rw-r--r--tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py37
1 files changed, 16 insertions, 21 deletions
diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
index 9e4ec59e70..ca2d724b49 100644
--- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
+++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py
@@ -36,16 +36,15 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.util import nest
__all__ = [
'tensor_pool',
]
-def _to_tuple(x):
- if isinstance(x, (list, tuple)):
- return tuple(x)
- return (x,)
+def _to_list(x):
+ return [x] if isinstance(x, ops.Tensor) else list(x)
def tensor_pool(input_values,
@@ -63,8 +62,8 @@ def tensor_pool(input_values,
`pool_size` = 0 or `pooling_probability` = 0.
Args:
- input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read
- values to be pooled.
+ input_values: An arbitrarily nested structure of `tf.Tensors`, from which to
+ read values to be pooled.
pool_size: An integer specifying the maximum size of the pool. Defaults to
50.
pooling_probability: A float `Tensor` specifying the probability of getting
@@ -72,9 +71,10 @@ def tensor_pool(input_values,
name: A string prefix for the name scope for all tensorflow ops.
Returns:
- A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx
- `input_values`) which is with given probability either the `input_values` or
- a randomly chosen sample that was previously inserted in the pool.
+ A nested structure of `Tensor` objects with the same structure as
+ `input_values`. With the given probability, the Tensor values are either the
+ same as in `input_values` or a randomly chosen sample that was previously
+ inserted in the pool.
Raises:
ValueError: If `pool_size` is negative.
@@ -86,11 +86,10 @@ def tensor_pool(input_values,
return input_values
original_input_values = input_values
- input_values = _to_tuple(input_values)
+ input_values = nest.flatten(input_values)
- with ops.name_scope(
- '{}_pool_queue'.format(name),
- values=input_values + (pooling_probability,)):
+ with ops.name_scope('{}_pool_queue'.format(name),
+ values=input_values + [pooling_probability]):
pool_queue = data_flow_ops.RandomShuffleQueue(
capacity=pool_size,
min_after_dequeue=0,
@@ -112,10 +111,10 @@ def tensor_pool(input_values,
def _get_input_value_pooled():
enqueue_op = pool_queue.enqueue(input_values)
with ops.control_dependencies([enqueue_op]):
- return tuple(array_ops.identity(v) for v in input_values)
+ return [array_ops.identity(v) for v in input_values]
def _get_random_pool_value_and_enqueue_input():
- dequeue_values = _to_tuple(pool_queue.dequeue())
+ dequeue_values = _to_list(pool_queue.dequeue())
with ops.control_dependencies(dequeue_values):
enqueue_op = pool_queue.enqueue(input_values)
with ops.control_dependencies([enqueue_op]):
@@ -124,7 +123,7 @@ def tensor_pool(input_values,
return control_flow_ops.cond(prob, lambda: dequeue_values,
lambda: input_values)
- output_values = _to_tuple(control_flow_ops.cond(
+ output_values = _to_list(control_flow_ops.cond(
pool_queue.size() < pool_size, _get_input_value_pooled,
_get_random_pool_value_and_enqueue_input))
@@ -132,8 +131,4 @@ def tensor_pool(input_values,
for input_value, output_value in zip(input_values, output_values):
output_value.set_shape(input_value.shape)
- if isinstance(original_input_values, list):
- return list(output_values)
- elif isinstance(original_input_values, tuple):
- return output_values
- return output_values[0]
+ return nest.pack_sequence_as(original_input_values, output_values)