diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-23 12:19:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-23 12:25:40 -0700 |
commit | 049fc23966eeef02a0945ddb80ae5f40592b90c1 (patch) | |
tree | 63c958a2042888293ea8c42c0d332874d661f360 /tensorflow/contrib/gan | |
parent | c26b95e707ed2304e2e50f51f6751c9b9cb87f1c (diff) |
Extend random pool to work with arbitrarily nested tensor structures.
PiperOrigin-RevId: 205703156
Diffstat (limited to 'tensorflow/contrib/gan')
3 files changed, 38 insertions, 21 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 781e4ae4d7..7e6cb72485 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -257,12 +257,15 @@ py_library( py_test( name = "random_tensor_pool_test", srcs = ["python/features/python/random_tensor_pool_test.py"], + shard_count = 6, srcs_version = "PY2AND3", deps = [ ":random_tensor_pool", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 9e4ec59e70..ca2d724b49 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -36,16 +36,15 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops +from tensorflow.python.util import nest __all__ = [ 'tensor_pool', ] -def _to_tuple(x): - if isinstance(x, (list, tuple)): - return tuple(x) - return (x,) +def _to_list(x): + return [x] if isinstance(x, ops.Tensor) else list(x) def tensor_pool(input_values, @@ -63,8 +62,8 @@ def tensor_pool(input_values, `pool_size` = 0 or `pooling_probability` = 0. Args: - input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read - values to be pooled. + input_values: An arbitrarily nested structure of `tf.Tensors`, from which to + read values to be pooled. pool_size: An integer specifying the maximum size of the pool. Defaults to 50. pooling_probability: A float `Tensor` specifying the probability of getting @@ -72,9 +71,10 @@ def tensor_pool(input_values, name: A string prefix for the name scope for all tensorflow ops. Returns: - A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx - `input_values`) which is with given probability either the `input_values` or - a randomly chosen sample that was previously inserted in the pool. + A nested structure of `Tensor` objects with the same structure as + `input_values`. With the given probability, the Tensor values are either the + same as in `input_values` or a randomly chosen sample that was previously + inserted in the pool. Raises: ValueError: If `pool_size` is negative. @@ -86,11 +86,10 @@ def tensor_pool(input_values, return input_values original_input_values = input_values - input_values = _to_tuple(input_values) + input_values = nest.flatten(input_values) - with ops.name_scope( - '{}_pool_queue'.format(name), - values=input_values + (pooling_probability,)): + with ops.name_scope('{}_pool_queue'.format(name), + values=input_values + [pooling_probability]): pool_queue = data_flow_ops.RandomShuffleQueue( capacity=pool_size, min_after_dequeue=0, @@ -112,10 +111,10 @@ def tensor_pool(input_values, def _get_input_value_pooled(): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): - return tuple(array_ops.identity(v) for v in input_values) + return [array_ops.identity(v) for v in input_values] def _get_random_pool_value_and_enqueue_input(): - dequeue_values = _to_tuple(pool_queue.dequeue()) + dequeue_values = _to_list(pool_queue.dequeue()) with ops.control_dependencies(dequeue_values): enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): @@ -124,7 +123,7 @@ def tensor_pool(input_values, return control_flow_ops.cond(prob, lambda: dequeue_values, lambda: input_values) - output_values = _to_tuple(control_flow_ops.cond( + output_values = _to_list(control_flow_ops.cond( pool_queue.size() < pool_size, _get_input_value_pooled, _get_random_pool_value_and_enqueue_input)) @@ -132,8 +131,4 @@ def tensor_pool(input_values, for input_value, output_value in zip(input_values, output_values): output_value.set_shape(input_value.shape) - if isinstance(original_input_values, list): - return list(output_values) - elif isinstance(original_input_values, tuple): - return output_values - return output_values[0] + return nest.pack_sequence_as(original_input_values, output_values) diff --git a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index d8cf549cf7..08584dcd65 100644 --- a/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -21,7 +21,9 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -111,6 +113,23 @@ class TensorPoolTest(test.TestCase): self.assertEqual(len(outs), len(input_values)) self.assertEqual(outs[1] - outs[0], 1) + def test_pool_preserves_shape(self): + t = constant_op.constant(1) + input_values = [[t, t, t], (t, t), t] + output_values = tensor_pool(input_values, pool_size=5) + print('stuff: ', output_values) + # Overall shape. + self.assertIsInstance(output_values, list) + self.assertEqual(3, len(output_values)) + # Shape of first element. + self.assertIsInstance(output_values[0], list) + self.assertEqual(3, len(output_values[0])) + # Shape of second element. + self.assertIsInstance(output_values[1], tuple) + self.assertEqual(2, len(output_values[1])) + # Shape of third element. + self.assertIsInstance(output_values[2], ops.Tensor) + if __name__ == '__main__': test.main() |