From 6b1c2cc8306322976d0738f4b799760dce29b23b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 1 Dec 2016 14:24:51 -0800 Subject: Add ability to conditionally batch in 'tf.train.batch`. Change: 140778005 --- .../contrib/training/python/training/bucket_ops.py | 2 +- tensorflow/python/ops/io_ops.py | 6 +- tensorflow/python/training/input.py | 567 ++++++++++++++++----- tensorflow/python/training/input_test.py | 310 +++++++++++ 4 files changed, 765 insertions(+), 120 deletions(-) diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 3f397d2401..199764e6b9 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -152,7 +152,7 @@ def bucket(tensors, with ops.name_scope(name, "bucket", tensor_list) as name: tensor_list = _validate_bucket(tensor_list) (tensor_list, sparse_info) = _store_sparse_tensors( - tensor_list, enqueue_many=False) + tensor_list, enqueue_many=False, keep_input=True) # Round-trip batch_size to a tensor, and possibly back batch_size = ops.convert_to_tensor( diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 99f992ff5f..e7af6bfe2d 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -131,12 +131,16 @@ single subgraph producing examples but you want to run it in *N* threads (where you increase *N* until it can keep the queue full). Use [`batch_join`](#batch_join) or [`shuffle_batch_join`](#shuffle_batch_join) if you have *N* different subgraphs producing examples to batch and you -want them run by *N* threads. +want them run by *N* threads. Use `maybe_*` to enqueue conditionally. @@batch +@@maybe_batch @@batch_join +@@maybe_batch_join @@shuffle_batch +@@maybe_shuffle_batch @@shuffle_batch_join +@@maybe_shuffle_batch_join """ from __future__ import absolute_import diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py index 4fdda96860..67dc66ea4c 100644 --- a/tensorflow/python/training/input.py +++ b/tensorflow/python/training/input.py @@ -407,7 +407,8 @@ def _as_original_type(original_tensors, tensor_list): return tensor_list -def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None): +def _store_sparse_tensors(tensor_list, enqueue_many, keep_input, + shared_map_ops=None): """Store SparseTensors for feeding into batch, etc. If `shared_map_ops` is provided, the underlying `SparseTensorsMap` objects @@ -425,6 +426,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None): Args: tensor_list: List of `Tensor` and `SparseTensor` objects. enqueue_many: Python `Boolean`. + keep_input: Bool tensor. If False, don't store. shared_map_ops: (optional) List of `Operation` objects from a previous call to `_store_sparse_tensors`. If not `None`, the op types should be one of `AddSparseToTensorsMap` or `AddManySparseToTensorsMap` in the @@ -443,40 +445,64 @@ def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None): rank = t.shape.get_shape().with_rank(1)[0] if enqueue_many: rank -= 1 - # If a shared map_op was provided, use that. Otherwise use the name of + # If a shared map_op was provided, use that. Otherwise use the name of # the operation used to store the SparseTensor. 
return _SparseMetaData( sparse=True, map_op=map_op or storing_op, rank=rank) def _maybe_store(t, shared_map_op): + """Store Sparse tensor, if necessary.""" if not isinstance(t, sparse_tensor.SparseTensor): return t map_op_name = shared_map_op.name if shared_map_op else None - return (_store_many_sparse(t, shared_name=map_op_name) if enqueue_many - else _store_sparse(t, shared_name=map_op_name)) + def _maybe_store_sparse(t, map_op_name, keep_input): + return control_flow_ops.cond( + keep_input, + lambda: _store_sparse(t, shared_name=map_op_name), + lambda: constant_op.constant(-1, dtypes.int64)) + def _maybe_store_many_sparse(t, map_op_name, keep_input): + out_tensor = control_flow_ops.cond( + keep_input, + lambda: _store_many_sparse(t, shared_name=map_op_name), + lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64)) + out_tensor.set_shape([None]) # necessary when t.ndims is unknown + return out_tensor + store_f = _maybe_store_many_sparse if enqueue_many else _maybe_store_sparse + return store_f(t, map_op_name, keep_input) stored_list = [ _maybe_store(t, shared_map_op) for t, shared_map_op in zip(tensor_list, maybe_shared_map_ops)] + # Since the output of `_store{_many}_sparse is wrapped in a tf.cond `Merge`, + # we can't just get the Op of the resulting tensor. + def _sparse_op(stored): + for input_tensor in stored.op.inputs: + if input_tensor.op.type in ("AddSparseToTensorsMap", + "AddManySparseToTensorsMap"): + return input_tensor.op + # If there was no sparse input, then the original stored Tensor wasn't + # sparse and we can just return the original Tensor's Op. + return stored.op sparse_info_list = [ - _sparse_meta_data(t, stored.op, shared_map_op) + _sparse_meta_data(t, _sparse_op(stored), shared_map_op) for t, stored, shared_map_op in zip(tensor_list, stored_list, maybe_shared_map_ops)] - # expand dims of stored tensors by 1 for proper enqueue shape + # Expand dims of stored tensors by 1 for proper enqueue shape stored_list = [ array_ops.expand_dims(s, [-1]) if s_info.sparse else s for s, s_info in zip(stored_list, sparse_info_list)] return stored_list, sparse_info_list -def _store_sparse_tensors_join(tensor_list_list, enqueue_many): +def _store_sparse_tensors_join(tensor_list_list, enqueue_many, keep_input): """Store SparseTensors for feeding into batch_join, etc.""" (s0, sparse_info_list) = _store_sparse_tensors( - tensor_list_list[0], enqueue_many) + tensor_list_list[0], enqueue_many, keep_input) stored_list_list = [s0] for tensor_list in tensor_list_list[1:]: s, sparse_info_candidate = _store_sparse_tensors( - tensor_list, enqueue_many, [st.map_op for st in sparse_info_list]) + tensor_list, enqueue_many, keep_input, + [st.map_op for st in sparse_info_list]) if sparse_info_list != sparse_info_candidate: raise ValueError("Inconsistent SparseTensors list: %s vs. 
%s" % (tensor_list_list[0], tensor_list)) @@ -498,8 +524,7 @@ def _restore_sparse_tensors(stored_list, sparse_info_list): sparse_handles=array_ops.squeeze(s, [1]), rank=(info.rank + 1).value) if info.sparse else s - for (s, info) - in zip(stored_list, sparse_info_list)] + for (s, info) in zip(stored_list, sparse_info_list)] return tensors if received_sequence else tensors[0] @@ -518,6 +543,12 @@ def _validate_join(tensor_list_list): return tensor_list_list +def _validate_tensor_or_none(tensor_or_none): + if tensor_or_none is not None: + return ops.convert_to_tensor(tensor_or_none) + return tensor_or_none + + def _dtypes(tensor_list_list): all_types = [[t.dtype for t in tl] for tl in tensor_list_list] types = all_types[0] @@ -571,19 +602,35 @@ def _shapes(tensor_list_list, shapes, enqueue_many): return shapes -def _enqueue_join(queue, tensor_list_list, enqueue_many): +def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input): + """Enqueue `tensor_list_list` in `queue`.""" if enqueue_many: - enqueue_ops = [queue.enqueue_many(tl) for tl in tensor_list_list] + enqueue_fn = queue.enqueue_many + else: + enqueue_fn = queue.enqueue + if keep_input is None: + enqueue_ops = [enqueue_fn(tl) for tl in tensor_list_list] else: - enqueue_ops = [queue.enqueue(tl) for tl in tensor_list_list] + enqueue_ops = [control_flow_ops.cond( + keep_input, + lambda: enqueue_fn(tl), + control_flow_ops.no_op) for tl in tensor_list_list] queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops)) -def _enqueue(queue, tensor_list, threads, enqueue_many): +def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input): + """Enqueue `tensor_list` in `queue`.""" if enqueue_many: - enqueue_ops = [queue.enqueue_many(tensor_list)] * threads + enqueue_fn = queue.enqueue_many else: - enqueue_ops = [queue.enqueue(tensor_list)] * threads + enqueue_fn = queue.enqueue + if keep_input is None: + enqueue_ops = [enqueue_fn(tensor_list)] * threads + else: + enqueue_ops = [control_flow_ops.cond( + keep_input, + lambda: enqueue_fn(tensor_list), + control_flow_ops.no_op)] * threads queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops)) @@ -592,6 +639,144 @@ def _which_queue(dynamic_pad): else data_flow_ops.FIFOQueue) +def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32, + enqueue_many=False, shapes=None, dynamic_pad=False, + allow_smaller_final_batch=False, shared_name=None, + name=None): + """Helper function for `batch` and `maybe_batch`.""" + tensor_list = _as_tensor_list(tensors) + with ops.name_scope(name, "batch", list(tensor_list) + [keep_input]) as name: + tensor_list = _validate(tensor_list) + keep_input = _validate_tensor_or_none(keep_input) + (tensor_list, sparse_info) = _store_sparse_tensors( + tensor_list, enqueue_many, keep_input) + types = _dtypes([tensor_list]) + shapes = _shapes([tensor_list], shapes, enqueue_many) + # TODO(josh11b,mrry): Switch to BatchQueue once it is written. + queue = _which_queue(dynamic_pad)( + capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name) + _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input) + summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity), + math_ops.cast(queue.size(), dtypes.float32) * + (1. 
/ capacity)) + + if allow_smaller_final_batch: + dequeued = queue.dequeue_up_to(batch_size, name=name) + else: + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _restore_sparse_tensors(dequeued, sparse_info) + return _as_original_type(tensors, dequeued) + + +# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be +# a multiple of len(tensor_list_list)?) parameter, to address the use +# case where you want more parallelism than you can support different +# readers (either because you don't have that many files or can't +# read that many files in parallel due to the number of seeks required). +# Once this is done, batch() can be written as a call to batch_join(). +def _batch_join(tensors_list, batch_size, keep_input, capacity=32, + enqueue_many=False, shapes=None, dynamic_pad=False, + allow_smaller_final_batch=False, shared_name=None, name=None): + """Helper function for `batch_join` and `maybe_batch_join`.""" + tensor_list_list = _as_tensor_list_list(tensors_list) + with ops.name_scope(name, "batch_join", + _flatten(tensor_list_list) + [keep_input]) as name: + tensor_list_list = _validate_join(tensor_list_list) + keep_input = _validate_tensor_or_none(keep_input) + tensor_list_list, sparse_info = _store_sparse_tensors_join( + tensor_list_list, enqueue_many, keep_input) + types = _dtypes(tensor_list_list) + shapes = _shapes(tensor_list_list, shapes, enqueue_many) + # TODO(josh11b,mrry): Switch to BatchQueue once it is written. + queue = _which_queue(dynamic_pad)( + capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name) + _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input) + summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity), + math_ops.cast(queue.size(), dtypes.float32) * + (1. / capacity)) + + if allow_smaller_final_batch: + dequeued = queue.dequeue_up_to(batch_size, name=name) + else: + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _restore_sparse_tensors(dequeued, sparse_info) + # tensors_list was validated to not be empty. + return _as_original_type(tensors_list[0], dequeued) + + +def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, + keep_input, num_threads=1, seed=None, enqueue_many=False, + shapes=None, allow_smaller_final_batch=False, + shared_name=None, name=None): + """Helper function for `shuffle_batch` and `maybe_shuffle_batch`.""" + tensor_list = _as_tensor_list(tensors) + with ops.name_scope(name, "shuffle_batch", + list(tensor_list) + [keep_input]) as name: + tensor_list = _validate(tensor_list) + keep_input = _validate_tensor_or_none(keep_input) + tensor_list, sparse_info = _store_sparse_tensors( + tensor_list, enqueue_many, keep_input) + types = _dtypes([tensor_list]) + shapes = _shapes([tensor_list], shapes, enqueue_many) + queue = data_flow_ops.RandomShuffleQueue( + capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, + dtypes=types, shapes=shapes, shared_name=shared_name) + _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input) + full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), + dtypes.float32) * + (1. / (capacity - min_after_dequeue))) + # Note that name contains a '/' at the end so we intentionally do not place + # a '/' after %s below. 
+ summary_name = ( + "queue/%sfraction_over_%d_of_%d_full" % + (name, min_after_dequeue, capacity - min_after_dequeue)) + summary.scalar(summary_name, full) + + if allow_smaller_final_batch: + dequeued = queue.dequeue_up_to(batch_size, name=name) + else: + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _restore_sparse_tensors(dequeued, sparse_info) + return _as_original_type(tensors, dequeued) + + +def _shuffle_batch_join(tensors_list, batch_size, capacity, + min_after_dequeue, keep_input, seed=None, + enqueue_many=False, shapes=None, + allow_smaller_final_batch=False, shared_name=None, + name=None): + """Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`.""" + tensor_list_list = _as_tensor_list_list(tensors_list) + with ops.name_scope(name, "shuffle_batch_join", + _flatten(tensor_list_list) + [keep_input]) as name: + tensor_list_list = _validate_join(tensor_list_list) + keep_input = _validate_tensor_or_none(keep_input) + tensor_list_list, sparse_info = _store_sparse_tensors_join( + tensor_list_list, enqueue_many, keep_input) + types = _dtypes(tensor_list_list) + shapes = _shapes(tensor_list_list, shapes, enqueue_many) + queue = data_flow_ops.RandomShuffleQueue( + capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, + dtypes=types, shapes=shapes, shared_name=shared_name) + _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input) + full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), + dtypes.float32) * + (1. / (capacity - min_after_dequeue))) + # Note that name contains a '/' at the end so we intentionally do not place + # a '/' after %s below. + summary_name = ( + "queue/%sfraction_over_%d_of_%d_full" % + (name, min_after_dequeue, capacity - min_after_dequeue)) + summary.scalar(summary_name, full) + + if allow_smaller_final_batch: + dequeued = queue.dequeue_up_to(batch_size, name=name) + else: + dequeued = queue.dequeue_many(batch_size, name=name) + dequeued = _restore_sparse_tensors(dequeued, sparse_info) + # tensors_list was validated to not be empty. + return _as_original_type(tensors_list[0], dequeued) + # Batching functions ---------------------------------------------------------- @@ -671,35 +856,69 @@ def batch(tensors, batch_size, num_threads=1, capacity=32, ValueError: If the `shapes` are not specified, and cannot be inferred from the elements of `tensors`. """ - tensor_list = _as_tensor_list(tensors) - with ops.name_scope(name, "batch", tensor_list) as name: - tensor_list = _validate(tensor_list) - (tensor_list, sparse_info) = _store_sparse_tensors( - tensor_list, enqueue_many) - types = _dtypes([tensor_list]) - shapes = _shapes([tensor_list], shapes, enqueue_many) - # TODO(josh11b,mrry): Switch to BatchQueue once it is written. - queue = _which_queue(dynamic_pad)( - capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name) - _enqueue(queue, tensor_list, num_threads, enqueue_many) - summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity), - math_ops.cast(queue.size(), dtypes.float32) * - (1. 
/ capacity)) + return _batch( + tensors, + batch_size, + keep_input=True, + num_threads=num_threads, + capacity=capacity, + enqueue_many=enqueue_many, + shapes=shapes, + dynamic_pad=dynamic_pad, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) + + +def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32, + enqueue_many=False, shapes=None, dynamic_pad=False, + allow_smaller_final_batch=False, shared_name=None, name=None): + """Conditionally creates batches of tensors based on `keep_input`. + + See docstring in `batch` for more details. - if allow_smaller_final_batch: - dequeued = queue.dequeue_up_to(batch_size, name=name) - else: - dequeued = queue.dequeue_many(batch_size, name=name) - dequeued = _restore_sparse_tensors(dequeued, sparse_info) - return _as_original_type(tensors, dequeued) + Args: + tensors: The list or dictionary of tensors to enqueue. + keep_input: A `bool` scalar Tensor. This tensor controls whether the input + is added to the queue or not. If it evaluates `True`, then `tensors` are + added to the queue; otherwise they are dropped. This tensor essentially + acts as a filtering mechanism. + batch_size: The new batch size pulled from the queue. + num_threads: The number of threads enqueuing `tensors`. + capacity: An integer. The maximum number of elements in the queue. + enqueue_many: Whether each tensor in `tensors` is a single example. + shapes: (Optional) The shapes for each example. Defaults to the + inferred shapes for `tensors`. + dynamic_pad: Boolean. Allow variable dimensions in input shapes. + The given dimensions are padded upon dequeue so that tensors within a + batch have the same shapes. + allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final + batch to be smaller if there are insufficient items left in the queue. + shared_name: (Optional). If set, this queue will be shared under the given + name across multiple sessions. + name: (Optional) A name for the operations. + + Returns: + A list or dictionary of tensors with the same types as `tensors`. + + Raises: + ValueError: If the `shapes` are not specified, and cannot be + inferred from the elements of `tensors`. + """ + return _batch( + tensors, + batch_size, + keep_input, + num_threads=num_threads, + capacity=capacity, + enqueue_many=enqueue_many, + shapes=shapes, + dynamic_pad=dynamic_pad, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) -# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be -# a multiple of len(tensor_list_list)?) parameter, to address the use -# case where you want more parallelism than you can support different -# readers (either because you don't have that many files or can't -# read that many files in parallel due to the number of seeks required). -# Once this is done, batch() can be written as a call to batch_join(). def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, shapes=None, dynamic_pad=False, allow_smaller_final_batch=False, shared_name=None, name=None): @@ -784,28 +1003,67 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False, ValueError: If the `shapes` are not specified, and cannot be inferred from the elements of `tensor_list_list`. 
""" - tensor_list_list = _as_tensor_list_list(tensors_list) - with ops.name_scope(name, "batch_join", _flatten(tensor_list_list)) as name: - tensor_list_list = _validate_join(tensor_list_list) - tensor_list_list, sparse_info = _store_sparse_tensors_join( - tensor_list_list, enqueue_many) - types = _dtypes(tensor_list_list) - shapes = _shapes(tensor_list_list, shapes, enqueue_many) - # TODO(josh11b,mrry): Switch to BatchQueue once it is written. - queue = _which_queue(dynamic_pad)( - capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name) - _enqueue_join(queue, tensor_list_list, enqueue_many) - summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity), - math_ops.cast(queue.size(), dtypes.float32) * - (1. / capacity)) + return _batch_join( + tensors_list, + batch_size, + keep_input=True, + capacity=capacity, + enqueue_many=enqueue_many, + shapes=shapes, + dynamic_pad=dynamic_pad, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) + + +def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32, + enqueue_many=False, shapes=None, dynamic_pad=False, + allow_smaller_final_batch=False, shared_name=None, + name=None): + """Runs a list of tensors to conditionally fill a queue to create batches. + + See docstring in `batch_join` for more details. - if allow_smaller_final_batch: - dequeued = queue.dequeue_up_to(batch_size, name=name) - else: - dequeued = queue.dequeue_many(batch_size, name=name) - dequeued = _restore_sparse_tensors(dequeued, sparse_info) - # tensors_list was validated to not be empty. - return _as_original_type(tensors_list[0], dequeued) + Args: + tensors_list: A list of tuples or dictionaries of tensors to enqueue. + keep_input: A `bool` scalar Tensor. This tensor controls whether the input + is added to the queue or not. If it evaluates `True`, then `tensors` are + added to the queue; otherwise they are dropped. This tensor essentially + acts as a filtering mechanism. + batch_size: An integer. The new batch size pulled from the queue. + capacity: An integer. The maximum number of elements in the queue. + enqueue_many: Whether each tensor in `tensor_list_list` is a single + example. + shapes: (Optional) The shapes for each example. Defaults to the + inferred shapes for `tensor_list_list[i]`. + dynamic_pad: Boolean. Allow variable dimensions in input shapes. + The given dimensions are padded upon dequeue so that tensors within a + batch have the same shapes. + allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final + batch to be smaller if there are insufficient items left in the queue. + shared_name: (Optional) If set, this queue will be shared under the given + name across multiple sessions. + name: (Optional) A name for the operations. + + Returns: + A list or dictionary of tensors with the same number and types as + `tensors_list[i]`. + + Raises: + ValueError: If the `shapes` are not specified, and cannot be + inferred from the elements of `tensor_list_list`. + """ + return _batch_join( + tensors_list, + batch_size, + keep_input, + capacity=capacity, + enqueue_many=enqueue_many, + shapes=shapes, + dynamic_pad=dynamic_pad, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, @@ -890,33 +1148,71 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, ValueError: If the `shapes` are not specified, and cannot be inferred from the elements of `tensors`. 
""" - tensor_list = _as_tensor_list(tensors) - with ops.name_scope(name, "shuffle_batch", tensor_list) as name: - tensor_list = _validate(tensor_list) - tensor_list, sparse_info = _store_sparse_tensors( - tensor_list, enqueue_many) - types = _dtypes([tensor_list]) - shapes = _shapes([tensor_list], shapes, enqueue_many) - queue = data_flow_ops.RandomShuffleQueue( - capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, - dtypes=types, shapes=shapes, shared_name=shared_name) - _enqueue(queue, tensor_list, num_threads, enqueue_many) - full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), - dtypes.float32) * - (1. / (capacity - min_after_dequeue))) - # Note that name contains a '/' at the end so we intentionally do not place - # a '/' after %s below. - summary_name = ( - "queue/%sfraction_over_%d_of_%d_full" % - (name, min_after_dequeue, capacity - min_after_dequeue)) - summary.scalar(summary_name, full) + return _shuffle_batch( + tensors, + batch_size, + capacity, + min_after_dequeue, + keep_input=True, + num_threads=num_threads, + seed=seed, + enqueue_many=enqueue_many, + shapes=shapes, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) + + +def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, + keep_input, num_threads=1, seed=None, + enqueue_many=False, shapes=None, + allow_smaller_final_batch=False, shared_name=None, + name=None): + """Creates batches by randomly shuffling conditionally-enqueued tensors. + + See docstring in `shuffle_batch` for more details. - if allow_smaller_final_batch: - dequeued = queue.dequeue_up_to(batch_size, name=name) - else: - dequeued = queue.dequeue_many(batch_size, name=name) - dequeued = _restore_sparse_tensors(dequeued, sparse_info) - return _as_original_type(tensors, dequeued) + Args: + tensors: The list or dictionary of tensors to enqueue. + batch_size: The new batch size pulled from the queue. + capacity: An integer. The maximum number of elements in the queue. + min_after_dequeue: Minimum number elements in the queue after a + dequeue, used to ensure a level of mixing of elements. + keep_input: A `bool` scalar Tensor. This tensor controls whether the input + is added to the queue or not. If it evaluates `True`, then `tensors` are + added to the queue; otherwise they are dropped. This tensor essentially + acts as a filtering mechanism. + num_threads: The number of threads enqueuing `tensor_list`. + seed: Seed for the random shuffling within the queue. + enqueue_many: Whether each tensor in `tensor_list` is a single example. + shapes: (Optional) The shapes for each example. Defaults to the + inferred shapes for `tensor_list`. + allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final + batch to be smaller if there are insufficient items left in the queue. + shared_name: (Optional) If set, this queue will be shared under the given + name across multiple sessions. + name: (Optional) A name for the operations. + + Returns: + A list or dictionary of tensors with the types as `tensors`. + + Raises: + ValueError: If the `shapes` are not specified, and cannot be + inferred from the elements of `tensors`. 
+ """ + return _shuffle_batch( + tensors, + batch_size, + capacity, + min_after_dequeue, + keep_input, + num_threads=num_threads, + seed=seed, + enqueue_many=enqueue_many, + shapes=shapes, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) def shuffle_batch_join(tensors_list, batch_size, capacity, @@ -993,32 +1289,67 @@ def shuffle_batch_join(tensors_list, batch_size, capacity, ValueError: If the `shapes` are not specified, and cannot be inferred from the elements of `tensors_list`. """ - tensor_list_list = _as_tensor_list_list(tensors_list) - with ops.name_scope(name, "shuffle_batch_join", - _flatten(tensor_list_list)) as name: - tensor_list_list = _validate_join(tensor_list_list) - tensor_list_list, sparse_info = _store_sparse_tensors_join( - tensor_list_list, enqueue_many) - types = _dtypes(tensor_list_list) - shapes = _shapes(tensor_list_list, shapes, enqueue_many) - queue = data_flow_ops.RandomShuffleQueue( - capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed, - dtypes=types, shapes=shapes, shared_name=shared_name) - _enqueue_join(queue, tensor_list_list, enqueue_many) - full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue), - dtypes.float32) * - (1. / (capacity - min_after_dequeue))) - # Note that name contains a '/' at the end so we intentionally do not place - # a '/' after %s below. - summary_name = ( - "queue/%sfraction_over_%d_of_%d_full" % - (name, min_after_dequeue, capacity - min_after_dequeue)) - summary.scalar(summary_name, full) + return _shuffle_batch_join( + tensors_list, + batch_size, + capacity, + min_after_dequeue, + keep_input=True, + seed=seed, + enqueue_many=enqueue_many, + shapes=shapes, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) + + +def maybe_shuffle_batch_join(tensors_list, batch_size, capacity, + min_after_dequeue, keep_input, seed=None, + enqueue_many=False, shapes=None, + allow_smaller_final_batch=False, shared_name=None, + name=None): + """Create batches by randomly shuffling conditionally-enqueued tensors. + + See docstring in `shuffle_batch_join` for more details. - if allow_smaller_final_batch: - dequeued = queue.dequeue_up_to(batch_size, name=name) - else: - dequeued = queue.dequeue_many(batch_size, name=name) - dequeued = _restore_sparse_tensors(dequeued, sparse_info) - # tensors_list was validated to not be empty. - return _as_original_type(tensors_list[0], dequeued) + Args: + tensors_list: A list of tuples or dictionaries of tensors to enqueue. + batch_size: An integer. The new batch size pulled from the queue. + capacity: An integer. The maximum number of elements in the queue. + min_after_dequeue: Minimum number elements in the queue after a + dequeue, used to ensure a level of mixing of elements. + keep_input: A `bool` scalar Tensor. If provided, this tensor controls + whether the input is added to the queue or not. If it evaluates `True`, + then `tensors_list` are added to the queue; otherwise they are dropped. + This tensor essentially acts as a filtering mechanism. + seed: Seed for the random shuffling within the queue. + enqueue_many: Whether each tensor in `tensor_list_list` is a single + example. + shapes: (Optional) The shapes for each example. Defaults to the + inferred shapes for `tensors_list[i]`. + allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final + batch to be smaller if there are insufficient items left in the queue. + shared_name: (optional). 
If set, this queue will be shared under the given + name across multiple sessions. + name: (Optional) A name for the operations. + + Returns: + A list or dictionary of tensors with the same number and types as + `tensors_list[i]`. + + Raises: + ValueError: If the `shapes` are not specified, and cannot be + inferred from the elements of `tensors_list`. + """ + return _shuffle_batch_join( + tensors_list, + batch_size, + capacity, + min_after_dequeue, + keep_input, + seed=seed, + enqueue_many=enqueue_many, + shapes=shapes, + allow_smaller_final_batch=allow_smaller_final_batch, + shared_name=shared_name, + name=name) diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py index 8f3470fc55..8087136c99 100644 --- a/tensorflow/python/training/input_test.py +++ b/tensorflow/python/training/input_test.py @@ -733,6 +733,83 @@ class BatchTest(tf.test.TestCase): x = tf.train.batch({"c": [12, 12]}, batch_size=8) self.assertAllEqual((8, 2), x["c"].get_shape().as_list()) + def _testKeepInputHelper(self, num_threads, enqueue_many): + with self.test_session() as sess: + batch_size = 5 + num_batches = 4 + examples = tf.Variable(0) + counter = examples.count_up_to(num_batches * batch_size * 2) + sparse_counter = tf.SparseTensor( + indices=tf.zeros([1, 1], dtype=tf.int64), + values=tf.stack([tf.cast(counter, tf.float32)]), + shape=[1]) + to_batch = [counter, sparse_counter, "string"] + if enqueue_many: + to_batch = tf.train.batch(to_batch, 1) + keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2))) + batched = tf.train.maybe_batch( + to_batch, keep_input, batch_size, num_threads=num_threads, + enqueue_many=enqueue_many) + tf.initialize_all_variables().run() + tf.initialize_local_variables().run() + threads = tf.train.start_queue_runners() + + for _ in range(num_batches): + results = sess.run(batched) + self.assertAllEqual([0] * batch_size, np.mod(results[0], 2)) + self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2)) + self.assertAllEqual([b"string"] * batch_size, results[2]) + + # Reached the limit. 
+ with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batched) + for thread in threads: + thread.join() + + def testSingleThreadKeepInput(self): + self._testKeepInputHelper(1, False) + + def testSingleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(1, True) + + def testMultipleThreadKeepInput(self): + self._testKeepInputHelper(5, False) + + def testMultipleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(5, True) + + def testMaybeBatchedSparseTensorInferredShape(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2) + self.assertAllEqual((2,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_batch( + [sparse], keep_input=True, batch_size=2, enqueue_many=True) + self.assertAllEqual((1,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2) + self.assertIs(None, batched.shape.get_shape().num_elements()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_batch( + [sparse], keep_input=True, batch_size=2, enqueue_many=True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + class BatchJoinTest(tf.test.TestCase): @@ -1125,6 +1202,85 @@ class BatchJoinTest(tf.test.TestCase): x = tf.train.batch_join([{"c": [12, 12]}], batch_size=8) self.assertAllEqual((8, 2), x["c"].get_shape().as_list()) + def _testKeepInputHelper(self, num_threads, enqueue_many): + with self.test_session() as sess: + batch_size = 5 + num_batches = 4 + examples = tf.Variable(0) + counter = examples.count_up_to(num_batches * batch_size * 2) + sparse_counter = tf.SparseTensor( + indices=tf.zeros([1, 1], dtype=tf.int64), + values=tf.stack([tf.cast(counter, tf.float32)]), + shape=[1]) + to_batch = [counter, sparse_counter, "string"] + if enqueue_many: + to_batch = tf.train.batch(to_batch, 1) + keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2))) + batched = tf.train.maybe_batch_join( + [to_batch] * num_threads, keep_input, batch_size, + enqueue_many=enqueue_many) + tf.initialize_all_variables().run() + tf.initialize_local_variables().run() + threads = tf.train.start_queue_runners() + + for _ in range(num_batches): + results = sess.run(batched) + self.assertAllEqual([0] * batch_size, np.mod(results[0], 2),) + self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2),) + self.assertAllEqual([b"string"] * batch_size, results[2]) + + # Reached the limit. 
+ with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batched) + for thread in threads: + thread.join() + + def testSingleThreadKeepInput(self): + self._testKeepInputHelper(1, False) + + def testSingleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(1, True) + + def testMultipleThreadKeepInput(self): + self._testKeepInputHelper(5, False) + + def testMultipleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(5, True) + + def testMaybeBatchedSparseTensorInferredShape(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_batch_join( + [[sparse]], keep_input=True, batch_size=2) + self.assertAllEqual((2,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_batch_join( + [[sparse]], keep_input=True, batch_size=2, enqueue_many=True) + self.assertAllEqual((1,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_batch_join( + [[sparse]], keep_input=True, batch_size=2) + self.assertIs(None, batched.shape.get_shape().num_elements()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_batch_join( + [[sparse]], keep_input=True, batch_size=2, enqueue_many=True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + class ShuffleBatchTest(tf.test.TestCase): @@ -1351,6 +1507,83 @@ class ShuffleBatchTest(tf.test.TestCase): "s: 'SHARED_NAME_XYZ'", batched[0].op.inputs[0].op.node_def.attr["shared_name"]) + def _testKeepInputHelper(self, num_threads, enqueue_many): + with self.test_session() as sess: + batch_size = 5 + num_batches = 4 + examples = tf.Variable(0) + counter = examples.count_up_to(num_batches * batch_size * 2) + sparse_counter = tf.SparseTensor( + indices=tf.zeros([1, 1], dtype=tf.int64), + values=tf.stack([tf.cast(counter, tf.float32)]), + shape=[1]) + to_batch = [counter, sparse_counter, "string"] + if enqueue_many: + to_batch = tf.train.batch(to_batch, 1) + keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2))) + batched = tf.train.maybe_shuffle_batch( + to_batch, batch_size, 10, 1, keep_input, num_threads=num_threads, + enqueue_many=enqueue_many) + tf.initialize_all_variables().run() + tf.initialize_local_variables().run() + threads = tf.train.start_queue_runners() + + for _ in range(num_batches): + results = sess.run(batched) + self.assertAllEqual([0] * batch_size, np.mod(results[0], 2)) + self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2)) + self.assertAllEqual([b"string"] * batch_size, results[2]) + + # Reached the limit. 
+ with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batched) + for thread in threads: + thread.join() + + def testSingleThreadKeepInput(self): + self._testKeepInputHelper(1, False) + + def testSingleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(1, True) + + def testMultipleThreadKeepInput(self): + self._testKeepInputHelper(5, False) + + def testMultipleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(5, True) + + def testMaybeBatchedSparseTensorInferredShape(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True) + self.assertAllEqual((2,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_shuffle_batch( + [sparse], 2, 10, 1, True, enqueue_many=True) + self.assertAllEqual((1,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_shuffle_batch( + [sparse], 2, 10, 1, True, enqueue_many=True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + class ShuffleBatchJoinTest(tf.test.TestCase): @@ -1581,6 +1814,83 @@ class ShuffleBatchJoinTest(tf.test.TestCase): "s: 'SHARED_NAME_XYZ'", batched[0].op.inputs[0].op.node_def.attr["shared_name"]) + def _testKeepInputHelper(self, num_threads, enqueue_many): + with self.test_session() as sess: + batch_size = 5 + num_batches = 4 + examples = tf.Variable(0) + counter = examples.count_up_to(num_batches * batch_size * 2) + sparse_counter = tf.SparseTensor( + indices=tf.zeros([1, 1], dtype=tf.int64), + values=tf.stack([tf.cast(counter, tf.float32)]), + shape=[1]) + to_batch = [counter, sparse_counter, "string"] + if enqueue_many: + to_batch = tf.train.batch(to_batch, 1) + keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2))) + batched = tf.train.maybe_shuffle_batch_join( + [to_batch] * num_threads, batch_size, 10, 1, keep_input, + enqueue_many=enqueue_many) + tf.initialize_all_variables().run() + tf.initialize_local_variables().run() + threads = tf.train.start_queue_runners() + + for _ in range(num_batches): + results = sess.run(batched) + self.assertAllEqual([0] * batch_size, np.mod(results[0], 2)) + self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2)) + self.assertAllEqual([b"string"] * batch_size, results[2]) + + # Reached the limit. 
+ with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batched) + for thread in threads: + thread.join() + + def testSingleThreadKeepInput(self): + self._testKeepInputHelper(1, False) + + def testSingleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(1, True) + + def testMultipleThreadKeepInput(self): + self._testKeepInputHelper(5, False) + + def testMultipleThreadKeepInputEnqueueMany(self): + self._testKeepInputHelper(5, True) + + def testMaybeBatchedSparseTensorInferredShape(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True) + self.assertAllEqual((2,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self): + sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1]) + self.assertAllEqual((1,), sparse.shape.get_shape().as_list()) + batched = tf.train.maybe_shuffle_batch_join( + [[sparse]], 2, 10, 1, True, enqueue_many=True) + self.assertAllEqual((1,), batched.shape.get_shape().as_list()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + + def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self): + sparse = tf.SparseTensor( + indices=tf.placeholder(tf.int64), + values=tf.placeholder(tf.float32), + shape=tf.placeholder(tf.int64)) + self.assertIs(None, sparse.shape.get_shape().num_elements()) + batched = tf.train.maybe_shuffle_batch_join( + [[sparse]], 2, 10, 1, True, enqueue_many=True) + self.assertIs(None, batched.shape.get_shape().num_elements()) + if __name__ == "__main__": tf.test.main() -- cgit v1.2.3
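A minimal usage sketch of the `tf.train.maybe_batch` API added by this change, mirroring the pattern used in the new tests rather than prescribing an official recipe; the counter, limit, and thread values below are illustrative only. `keep_input` is a scalar bool Tensor evaluated at enqueue time: when it is `True` the candidate example is added to the batching queue, otherwise it is silently dropped.

    import numpy as np
    import tensorflow as tf

    batch_size = 5
    num_batches = 4

    # A counter that raises OutOfRangeError once exhausted, as in the tests.
    examples = tf.Variable(0)
    counter = examples.count_up_to(num_batches * batch_size * 2)

    # Scalar bool Tensor: True enqueues this example, False drops it.
    keep_input = tf.equal(0, tf.mod(counter, 2))

    # Keep only even counter values; each dequeued batch has batch_size of them.
    batched = tf.train.maybe_batch(
        [counter], keep_input, batch_size, num_threads=2)

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        sess.run(tf.initialize_local_variables())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            for _ in range(num_batches):
                values = sess.run(batched)[0]
                assert np.all(values % 2 == 0)  # only even counters were enqueued
        except tf.errors.OutOfRangeError:
            pass  # input exhausted
        finally:
            coord.request_stop()
            coord.join(threads)

The `maybe_batch_join`, `maybe_shuffle_batch`, and `maybe_shuffle_batch_join` variants introduced here take the same `keep_input` argument and filter in the same way. Note that in this change `keep_input` is a scalar, so with `enqueue_many=True` it keeps or drops the entire enqueued slice rather than filtering individual rows.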