author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-12-01 14:24:51 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-12-01 15:11:07 -0800
commit    6b1c2cc8306322976d0738f4b799760dce29b23b (patch)
tree      d135115373c3b2b984f0234bd2df5a6e1d018d62
parent    c169af2a452a7908aea664d5297c4f7c88811eb6 (diff)
Add ability to conditionally batch in `tf.train.batch`.
Change: 140778005
-rw-r--r--  tensorflow/contrib/training/python/training/bucket_ops.py    2
-rw-r--r--  tensorflow/python/ops/io_ops.py                              6
-rw-r--r--  tensorflow/python/training/input.py                        567
-rw-r--r--  tensorflow/python/training/input_test.py                   310
4 files changed, 765 insertions(+), 120 deletions(-)
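
As a rough illustration of the new API (the producer below is only a stand-in; any input-pipeline tensor would do), elements are enqueued only when `keep_input` evaluates to `True`:

    import tensorflow as tf

    # Stand-in producer: scalar int32 values 0..99. Only multiples of 3 pass
    # the keep_input predicate, so only they ever reach the batch queue.
    value = tf.train.range_input_producer(limit=100).dequeue()
    keep_input = tf.equal(tf.mod(value, 3), 0)
    batch = tf.train.maybe_batch([value], keep_input, batch_size=8)
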
diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py
index 3f397d2401..199764e6b9 100644
--- a/tensorflow/contrib/training/python/training/bucket_ops.py
+++ b/tensorflow/contrib/training/python/training/bucket_ops.py
@@ -152,7 +152,7 @@ def bucket(tensors,
with ops.name_scope(name, "bucket", tensor_list) as name:
tensor_list = _validate_bucket(tensor_list)
(tensor_list, sparse_info) = _store_sparse_tensors(
- tensor_list, enqueue_many=False)
+ tensor_list, enqueue_many=False, keep_input=True)
# Round-trip batch_size to a tensor, and possibly back
batch_size = ops.convert_to_tensor(
diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py
index 99f992ff5f..e7af6bfe2d 100644
--- a/tensorflow/python/ops/io_ops.py
+++ b/tensorflow/python/ops/io_ops.py
@@ -131,12 +131,16 @@ single subgraph producing examples but you want to run it in *N* threads
(where you increase *N* until it can keep the queue full). Use
[`batch_join`](#batch_join) or [`shuffle_batch_join`](#shuffle_batch_join)
if you have *N* different subgraphs producing examples to batch and you
-want them run by *N* threads.
+want them run by *N* threads. Use `maybe_*` to enqueue conditionally.
@@batch
+@@maybe_batch
@@batch_join
+@@maybe_batch_join
@@shuffle_batch
+@@maybe_shuffle_batch
@@shuffle_batch_join
+@@maybe_shuffle_batch_join
"""
from __future__ import absolute_import
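
A comparable sketch for one of the shuffling variants listed above (again with a stand-in producer; the argument values are illustrative, not recommendations):

    import tensorflow as tf

    # Shuffle-batch only the even values; odd values are dropped before enqueue.
    value = tf.train.range_input_producer(limit=100).dequeue()
    keep = tf.equal(tf.mod(value, 2), 0)
    shuffled = tf.train.maybe_shuffle_batch(
        [value], batch_size=16, capacity=64, min_after_dequeue=32,
        keep_input=keep)
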
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 4fdda96860..67dc66ea4c 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -407,7 +407,8 @@ def _as_original_type(original_tensors, tensor_list):
return tensor_list
-def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None):
+def _store_sparse_tensors(tensor_list, enqueue_many, keep_input,
+ shared_map_ops=None):
"""Store SparseTensors for feeding into batch, etc.
If `shared_map_ops` is provided, the underlying `SparseTensorsMap` objects
@@ -425,6 +426,7 @@ def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None):
Args:
tensor_list: List of `Tensor` and `SparseTensor` objects.
enqueue_many: Python `Boolean`.
+ keep_input: Bool tensor. If False, don't store.
shared_map_ops: (optional) List of `Operation` objects from a previous
call to `_store_sparse_tensors`. If not `None`, the op types should be
one of `AddSparseToTensorsMap` or `AddManySparseToTensorsMap` in the
@@ -443,40 +445,64 @@ def _store_sparse_tensors(tensor_list, enqueue_many, shared_map_ops=None):
rank = t.shape.get_shape().with_rank(1)[0]
if enqueue_many:
rank -= 1
- # If a shared map_op was provided, use that. Otherwise use the name of
+ # If a shared map_op was provided, use that. Otherwise use the name of
# the operation used to store the SparseTensor.
return _SparseMetaData(
sparse=True, map_op=map_op or storing_op, rank=rank)
def _maybe_store(t, shared_map_op):
+ """Store Sparse tensor, if necessary."""
if not isinstance(t, sparse_tensor.SparseTensor):
return t
map_op_name = shared_map_op.name if shared_map_op else None
- return (_store_many_sparse(t, shared_name=map_op_name) if enqueue_many
- else _store_sparse(t, shared_name=map_op_name))
+ def _maybe_store_sparse(t, map_op_name, keep_input):
+ return control_flow_ops.cond(
+ keep_input,
+ lambda: _store_sparse(t, shared_name=map_op_name),
+ lambda: constant_op.constant(-1, dtypes.int64))
+ def _maybe_store_many_sparse(t, map_op_name, keep_input):
+ out_tensor = control_flow_ops.cond(
+ keep_input,
+ lambda: _store_many_sparse(t, shared_name=map_op_name),
+ lambda: -1 * array_ops.ones(array_ops.shape(t)[0:1], dtypes.int64))
+ out_tensor.set_shape([None]) # necessary when t.ndims is unknown
+ return out_tensor
+ store_f = _maybe_store_many_sparse if enqueue_many else _maybe_store_sparse
+ return store_f(t, map_op_name, keep_input)
stored_list = [
_maybe_store(t, shared_map_op) for t, shared_map_op
in zip(tensor_list, maybe_shared_map_ops)]
+  # Since the output of `_store{_many}_sparse` is wrapped in a tf.cond `Merge`,
+ # we can't just get the Op of the resulting tensor.
+ def _sparse_op(stored):
+ for input_tensor in stored.op.inputs:
+ if input_tensor.op.type in ("AddSparseToTensorsMap",
+ "AddManySparseToTensorsMap"):
+ return input_tensor.op
+ # If there was no sparse input, then the original stored Tensor wasn't
+ # sparse and we can just return the original Tensor's Op.
+ return stored.op
sparse_info_list = [
- _sparse_meta_data(t, stored.op, shared_map_op)
+ _sparse_meta_data(t, _sparse_op(stored), shared_map_op)
for t, stored, shared_map_op
in zip(tensor_list, stored_list, maybe_shared_map_ops)]
- # expand dims of stored tensors by 1 for proper enqueue shape
+ # Expand dims of stored tensors by 1 for proper enqueue shape
stored_list = [
array_ops.expand_dims(s, [-1]) if s_info.sparse else s
for s, s_info in zip(stored_list, sparse_info_list)]
return stored_list, sparse_info_list
-def _store_sparse_tensors_join(tensor_list_list, enqueue_many):
+def _store_sparse_tensors_join(tensor_list_list, enqueue_many, keep_input):
"""Store SparseTensors for feeding into batch_join, etc."""
(s0, sparse_info_list) = _store_sparse_tensors(
- tensor_list_list[0], enqueue_many)
+ tensor_list_list[0], enqueue_many, keep_input)
stored_list_list = [s0]
for tensor_list in tensor_list_list[1:]:
s, sparse_info_candidate = _store_sparse_tensors(
- tensor_list, enqueue_many, [st.map_op for st in sparse_info_list])
+ tensor_list, enqueue_many, keep_input,
+ [st.map_op for st in sparse_info_list])
if sparse_info_list != sparse_info_candidate:
raise ValueError("Inconsistent SparseTensors list: %s vs. %s"
% (tensor_list_list[0], tensor_list))
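
The conditional storing above relies on a sentinel handle: when `keep_input` is `False`, nothing is added to the `SparseTensorsMap` and a handle of -1 is produced instead, so both `tf.cond` branches return an `int64` of the same shape. A toy distillation of that pattern (the names here are illustrative stand-ins, not the private helpers themselves):

    import tensorflow as tf

    keep_input = tf.placeholder(tf.bool, shape=[])   # scalar predicate
    real_handle = tf.constant(7, tf.int64)           # stands in for a stored handle

    # Both branches yield an int64 scalar, so downstream shapes stay consistent;
    # -1 marks "nothing was stored" when keep_input is False.
    handle = tf.cond(keep_input,
                     lambda: real_handle,
                     lambda: tf.constant(-1, tf.int64))
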
@@ -498,8 +524,7 @@ def _restore_sparse_tensors(stored_list, sparse_info_list):
sparse_handles=array_ops.squeeze(s, [1]),
rank=(info.rank + 1).value)
if info.sparse else s
- for (s, info)
- in zip(stored_list, sparse_info_list)]
+ for (s, info) in zip(stored_list, sparse_info_list)]
return tensors if received_sequence else tensors[0]
@@ -518,6 +543,12 @@ def _validate_join(tensor_list_list):
return tensor_list_list
+def _validate_tensor_or_none(tensor_or_none):
+ if tensor_or_none is not None:
+ return ops.convert_to_tensor(tensor_or_none)
+ return tensor_or_none
+
+
def _dtypes(tensor_list_list):
all_types = [[t.dtype for t in tl] for tl in tensor_list_list]
types = all_types[0]
@@ -571,19 +602,35 @@ def _shapes(tensor_list_list, shapes, enqueue_many):
return shapes
-def _enqueue_join(queue, tensor_list_list, enqueue_many):
+def _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input):
+ """Enqueue `tensor_list_list` in `queue`."""
if enqueue_many:
- enqueue_ops = [queue.enqueue_many(tl) for tl in tensor_list_list]
+ enqueue_fn = queue.enqueue_many
+ else:
+ enqueue_fn = queue.enqueue
+ if keep_input is None:
+ enqueue_ops = [enqueue_fn(tl) for tl in tensor_list_list]
else:
- enqueue_ops = [queue.enqueue(tl) for tl in tensor_list_list]
+ enqueue_ops = [control_flow_ops.cond(
+ keep_input,
+ lambda: enqueue_fn(tl),
+ control_flow_ops.no_op) for tl in tensor_list_list]
queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
-def _enqueue(queue, tensor_list, threads, enqueue_many):
+def _enqueue(queue, tensor_list, threads, enqueue_many, keep_input):
+ """Enqueue `tensor_list` in `queue`."""
if enqueue_many:
- enqueue_ops = [queue.enqueue_many(tensor_list)] * threads
+ enqueue_fn = queue.enqueue_many
else:
- enqueue_ops = [queue.enqueue(tensor_list)] * threads
+ enqueue_fn = queue.enqueue
+ if keep_input is None:
+ enqueue_ops = [enqueue_fn(tensor_list)] * threads
+ else:
+ enqueue_ops = [control_flow_ops.cond(
+ keep_input,
+ lambda: enqueue_fn(tensor_list),
+ control_flow_ops.no_op)] * threads
queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
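
Both enqueue helpers now gate the enqueue op itself on `keep_input` via `tf.cond`, with a no-op in the false branch. A self-contained sketch of that pattern (queue size, dtype, and the producer are arbitrary choices for illustration):

    import tensorflow as tf

    value = tf.train.range_input_producer(limit=100).dequeue()
    keep_input = tf.equal(tf.mod(value, 2), 0)
    queue = tf.FIFOQueue(capacity=32, dtypes=[tf.int32], shapes=[[]])

    # When keep_input is False the cond runs tf.no_op, so nothing is enqueued.
    enqueue_op = tf.cond(keep_input,
                         lambda: queue.enqueue([value]),
                         tf.no_op)
    tf.train.add_queue_runner(tf.train.QueueRunner(queue, [enqueue_op]))
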
@@ -592,6 +639,144 @@ def _which_queue(dynamic_pad):
else data_flow_ops.FIFOQueue)
+def _batch(tensors, batch_size, keep_input, num_threads=1, capacity=32,
+ enqueue_many=False, shapes=None, dynamic_pad=False,
+ allow_smaller_final_batch=False, shared_name=None,
+ name=None):
+ """Helper function for `batch` and `maybe_batch`."""
+ tensor_list = _as_tensor_list(tensors)
+ with ops.name_scope(name, "batch", list(tensor_list) + [keep_input]) as name:
+ tensor_list = _validate(tensor_list)
+ keep_input = _validate_tensor_or_none(keep_input)
+ (tensor_list, sparse_info) = _store_sparse_tensors(
+ tensor_list, enqueue_many, keep_input)
+ types = _dtypes([tensor_list])
+ shapes = _shapes([tensor_list], shapes, enqueue_many)
+ # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+ queue = _which_queue(dynamic_pad)(
+ capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
+ _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
+ summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity),
+ math_ops.cast(queue.size(), dtypes.float32) *
+ (1. / capacity))
+
+ if allow_smaller_final_batch:
+ dequeued = queue.dequeue_up_to(batch_size, name=name)
+ else:
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _restore_sparse_tensors(dequeued, sparse_info)
+ return _as_original_type(tensors, dequeued)
+
+
+# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
+# a multiple of len(tensor_list_list)?) parameter, to address the use
+# case where you want more parallelism than you can support different
+# readers (either because you don't have that many files or can't
+# read that many files in parallel due to the number of seeks required).
+# Once this is done, batch() can be written as a call to batch_join().
+def _batch_join(tensors_list, batch_size, keep_input, capacity=32,
+ enqueue_many=False, shapes=None, dynamic_pad=False,
+ allow_smaller_final_batch=False, shared_name=None, name=None):
+ """Helper function for `batch_join` and `maybe_batch_join`."""
+ tensor_list_list = _as_tensor_list_list(tensors_list)
+ with ops.name_scope(name, "batch_join",
+ _flatten(tensor_list_list) + [keep_input]) as name:
+ tensor_list_list = _validate_join(tensor_list_list)
+ keep_input = _validate_tensor_or_none(keep_input)
+ tensor_list_list, sparse_info = _store_sparse_tensors_join(
+ tensor_list_list, enqueue_many, keep_input)
+ types = _dtypes(tensor_list_list)
+ shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+ # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
+ queue = _which_queue(dynamic_pad)(
+ capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
+ _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
+ summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity),
+ math_ops.cast(queue.size(), dtypes.float32) *
+ (1. / capacity))
+
+ if allow_smaller_final_batch:
+ dequeued = queue.dequeue_up_to(batch_size, name=name)
+ else:
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _restore_sparse_tensors(dequeued, sparse_info)
+ # tensors_list was validated to not be empty.
+ return _as_original_type(tensors_list[0], dequeued)
+
+
+def _shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
+ keep_input, num_threads=1, seed=None, enqueue_many=False,
+ shapes=None, allow_smaller_final_batch=False,
+ shared_name=None, name=None):
+ """Helper function for `shuffle_batch` and `maybe_shuffle_batch`."""
+ tensor_list = _as_tensor_list(tensors)
+ with ops.name_scope(name, "shuffle_batch",
+ list(tensor_list) + [keep_input]) as name:
+ tensor_list = _validate(tensor_list)
+ keep_input = _validate_tensor_or_none(keep_input)
+ tensor_list, sparse_info = _store_sparse_tensors(
+ tensor_list, enqueue_many, keep_input)
+ types = _dtypes([tensor_list])
+ shapes = _shapes([tensor_list], shapes, enqueue_many)
+ queue = data_flow_ops.RandomShuffleQueue(
+ capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+ dtypes=types, shapes=shapes, shared_name=shared_name)
+ _enqueue(queue, tensor_list, num_threads, enqueue_many, keep_input)
+ full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
+ dtypes.float32) *
+ (1. / (capacity - min_after_dequeue)))
+ # Note that name contains a '/' at the end so we intentionally do not place
+ # a '/' after %s below.
+ summary_name = (
+ "queue/%sfraction_over_%d_of_%d_full" %
+ (name, min_after_dequeue, capacity - min_after_dequeue))
+ summary.scalar(summary_name, full)
+
+ if allow_smaller_final_batch:
+ dequeued = queue.dequeue_up_to(batch_size, name=name)
+ else:
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _restore_sparse_tensors(dequeued, sparse_info)
+ return _as_original_type(tensors, dequeued)
+
+
+def _shuffle_batch_join(tensors_list, batch_size, capacity,
+ min_after_dequeue, keep_input, seed=None,
+ enqueue_many=False, shapes=None,
+ allow_smaller_final_batch=False, shared_name=None,
+ name=None):
+ """Helper function for `shuffle_batch_join` and `maybe_shuffle_batch_join`."""
+ tensor_list_list = _as_tensor_list_list(tensors_list)
+ with ops.name_scope(name, "shuffle_batch_join",
+ _flatten(tensor_list_list) + [keep_input]) as name:
+ tensor_list_list = _validate_join(tensor_list_list)
+ keep_input = _validate_tensor_or_none(keep_input)
+ tensor_list_list, sparse_info = _store_sparse_tensors_join(
+ tensor_list_list, enqueue_many, keep_input)
+ types = _dtypes(tensor_list_list)
+ shapes = _shapes(tensor_list_list, shapes, enqueue_many)
+ queue = data_flow_ops.RandomShuffleQueue(
+ capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
+ dtypes=types, shapes=shapes, shared_name=shared_name)
+ _enqueue_join(queue, tensor_list_list, enqueue_many, keep_input)
+ full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
+ dtypes.float32) *
+ (1. / (capacity - min_after_dequeue)))
+ # Note that name contains a '/' at the end so we intentionally do not place
+ # a '/' after %s below.
+ summary_name = (
+ "queue/%sfraction_over_%d_of_%d_full" %
+ (name, min_after_dequeue, capacity - min_after_dequeue))
+ summary.scalar(summary_name, full)
+
+ if allow_smaller_final_batch:
+ dequeued = queue.dequeue_up_to(batch_size, name=name)
+ else:
+ dequeued = queue.dequeue_many(batch_size, name=name)
+ dequeued = _restore_sparse_tensors(dequeued, sparse_info)
+ # tensors_list was validated to not be empty.
+ return _as_original_type(tensors_list[0], dequeued)
+
# Batching functions ----------------------------------------------------------
@@ -671,35 +856,69 @@ def batch(tensors, batch_size, num_threads=1, capacity=32,
ValueError: If the `shapes` are not specified, and cannot be
inferred from the elements of `tensors`.
"""
- tensor_list = _as_tensor_list(tensors)
- with ops.name_scope(name, "batch", tensor_list) as name:
- tensor_list = _validate(tensor_list)
- (tensor_list, sparse_info) = _store_sparse_tensors(
- tensor_list, enqueue_many)
- types = _dtypes([tensor_list])
- shapes = _shapes([tensor_list], shapes, enqueue_many)
- # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
- queue = _which_queue(dynamic_pad)(
- capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
- _enqueue(queue, tensor_list, num_threads, enqueue_many)
- summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity),
- math_ops.cast(queue.size(), dtypes.float32) *
- (1. / capacity))
+ return _batch(
+ tensors,
+ batch_size,
+ keep_input=True,
+ num_threads=num_threads,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ dynamic_pad=dynamic_pad,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
+
+
+def maybe_batch(tensors, keep_input, batch_size, num_threads=1, capacity=32,
+ enqueue_many=False, shapes=None, dynamic_pad=False,
+ allow_smaller_final_batch=False, shared_name=None, name=None):
+ """Conditionally creates batches of tensors based on `keep_input`.
+
+ See docstring in `batch` for more details.
- if allow_smaller_final_batch:
- dequeued = queue.dequeue_up_to(batch_size, name=name)
- else:
- dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _restore_sparse_tensors(dequeued, sparse_info)
- return _as_original_type(tensors, dequeued)
+ Args:
+ tensors: The list or dictionary of tensors to enqueue.
+ keep_input: A `bool` scalar Tensor. This tensor controls whether the input
+ is added to the queue or not. If it evaluates `True`, then `tensors` are
+ added to the queue; otherwise they are dropped. This tensor essentially
+ acts as a filtering mechanism.
+ batch_size: The new batch size pulled from the queue.
+ num_threads: The number of threads enqueuing `tensors`.
+ capacity: An integer. The maximum number of elements in the queue.
+ enqueue_many: Whether each tensor in `tensors` is a single example.
+ shapes: (Optional) The shapes for each example. Defaults to the
+ inferred shapes for `tensors`.
+ dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+ The given dimensions are padded upon dequeue so that tensors within a
+ batch have the same shapes.
+ allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+ batch to be smaller if there are insufficient items left in the queue.
+ shared_name: (Optional). If set, this queue will be shared under the given
+ name across multiple sessions.
+ name: (Optional) A name for the operations.
+
+ Returns:
+ A list or dictionary of tensors with the same types as `tensors`.
+
+ Raises:
+ ValueError: If the `shapes` are not specified, and cannot be
+ inferred from the elements of `tensors`.
+ """
+ return _batch(
+ tensors,
+ batch_size,
+ keep_input,
+ num_threads=num_threads,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ dynamic_pad=dynamic_pad,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
-# TODO(josh11b): Add a thread_multiplier or num_threads (that has to be
-# a multiple of len(tensor_list_list)?) parameter, to address the use
-# case where you want more parallelism than you can support different
-# readers (either because you don't have that many files or can't
-# read that many files in parallel due to the number of seeks required).
-# Once this is done, batch() can be written as a call to batch_join().
def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
shapes=None, dynamic_pad=False, allow_smaller_final_batch=False,
shared_name=None, name=None):
@@ -784,28 +1003,67 @@ def batch_join(tensors_list, batch_size, capacity=32, enqueue_many=False,
ValueError: If the `shapes` are not specified, and cannot be
inferred from the elements of `tensor_list_list`.
"""
- tensor_list_list = _as_tensor_list_list(tensors_list)
- with ops.name_scope(name, "batch_join", _flatten(tensor_list_list)) as name:
- tensor_list_list = _validate_join(tensor_list_list)
- tensor_list_list, sparse_info = _store_sparse_tensors_join(
- tensor_list_list, enqueue_many)
- types = _dtypes(tensor_list_list)
- shapes = _shapes(tensor_list_list, shapes, enqueue_many)
- # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
- queue = _which_queue(dynamic_pad)(
- capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
- _enqueue_join(queue, tensor_list_list, enqueue_many)
- summary.scalar("queue/%s/fraction_of_%d_full" % (queue.name, capacity),
- math_ops.cast(queue.size(), dtypes.float32) *
- (1. / capacity))
+ return _batch_join(
+ tensors_list,
+ batch_size,
+ keep_input=True,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ dynamic_pad=dynamic_pad,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
+
+
+def maybe_batch_join(tensors_list, keep_input, batch_size, capacity=32,
+ enqueue_many=False, shapes=None, dynamic_pad=False,
+ allow_smaller_final_batch=False, shared_name=None,
+ name=None):
+ """Runs a list of tensors to conditionally fill a queue to create batches.
+
+ See docstring in `batch_join` for more details.
- if allow_smaller_final_batch:
- dequeued = queue.dequeue_up_to(batch_size, name=name)
- else:
- dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _restore_sparse_tensors(dequeued, sparse_info)
- # tensors_list was validated to not be empty.
- return _as_original_type(tensors_list[0], dequeued)
+ Args:
+ tensors_list: A list of tuples or dictionaries of tensors to enqueue.
+ keep_input: A `bool` scalar Tensor. This tensor controls whether the input
+ is added to the queue or not. If it evaluates `True`, then `tensors` are
+ added to the queue; otherwise they are dropped. This tensor essentially
+ acts as a filtering mechanism.
+ batch_size: An integer. The new batch size pulled from the queue.
+ capacity: An integer. The maximum number of elements in the queue.
+ enqueue_many: Whether each tensor in `tensor_list_list` is a single
+ example.
+ shapes: (Optional) The shapes for each example. Defaults to the
+ inferred shapes for `tensor_list_list[i]`.
+ dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+ The given dimensions are padded upon dequeue so that tensors within a
+ batch have the same shapes.
+ allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+ batch to be smaller if there are insufficient items left in the queue.
+ shared_name: (Optional) If set, this queue will be shared under the given
+ name across multiple sessions.
+ name: (Optional) A name for the operations.
+
+ Returns:
+ A list or dictionary of tensors with the same number and types as
+ `tensors_list[i]`.
+
+ Raises:
+ ValueError: If the `shapes` are not specified, and cannot be
+ inferred from the elements of `tensor_list_list`.
+ """
+ return _batch_join(
+ tensors_list,
+ batch_size,
+ keep_input,
+ capacity=capacity,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ dynamic_pad=dynamic_pad,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
@@ -890,33 +1148,71 @@ def shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
ValueError: If the `shapes` are not specified, and cannot be
inferred from the elements of `tensors`.
"""
- tensor_list = _as_tensor_list(tensors)
- with ops.name_scope(name, "shuffle_batch", tensor_list) as name:
- tensor_list = _validate(tensor_list)
- tensor_list, sparse_info = _store_sparse_tensors(
- tensor_list, enqueue_many)
- types = _dtypes([tensor_list])
- shapes = _shapes([tensor_list], shapes, enqueue_many)
- queue = data_flow_ops.RandomShuffleQueue(
- capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
- dtypes=types, shapes=shapes, shared_name=shared_name)
- _enqueue(queue, tensor_list, num_threads, enqueue_many)
- full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
- dtypes.float32) *
- (1. / (capacity - min_after_dequeue)))
- # Note that name contains a '/' at the end so we intentionally do not place
- # a '/' after %s below.
- summary_name = (
- "queue/%sfraction_over_%d_of_%d_full" %
- (name, min_after_dequeue, capacity - min_after_dequeue))
- summary.scalar(summary_name, full)
+ return _shuffle_batch(
+ tensors,
+ batch_size,
+ capacity,
+ min_after_dequeue,
+ keep_input=True,
+ num_threads=num_threads,
+ seed=seed,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
+
+
+def maybe_shuffle_batch(tensors, batch_size, capacity, min_after_dequeue,
+ keep_input, num_threads=1, seed=None,
+ enqueue_many=False, shapes=None,
+ allow_smaller_final_batch=False, shared_name=None,
+ name=None):
+ """Creates batches by randomly shuffling conditionally-enqueued tensors.
+
+ See docstring in `shuffle_batch` for more details.
- if allow_smaller_final_batch:
- dequeued = queue.dequeue_up_to(batch_size, name=name)
- else:
- dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _restore_sparse_tensors(dequeued, sparse_info)
- return _as_original_type(tensors, dequeued)
+ Args:
+ tensors: The list or dictionary of tensors to enqueue.
+ batch_size: The new batch size pulled from the queue.
+ capacity: An integer. The maximum number of elements in the queue.
+    min_after_dequeue: Minimum number of elements in the queue after a
+ dequeue, used to ensure a level of mixing of elements.
+ keep_input: A `bool` scalar Tensor. This tensor controls whether the input
+ is added to the queue or not. If it evaluates `True`, then `tensors` are
+ added to the queue; otherwise they are dropped. This tensor essentially
+ acts as a filtering mechanism.
+ num_threads: The number of threads enqueuing `tensor_list`.
+ seed: Seed for the random shuffling within the queue.
+ enqueue_many: Whether each tensor in `tensor_list` is a single example.
+ shapes: (Optional) The shapes for each example. Defaults to the
+ inferred shapes for `tensor_list`.
+ allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+ batch to be smaller if there are insufficient items left in the queue.
+ shared_name: (Optional) If set, this queue will be shared under the given
+ name across multiple sessions.
+ name: (Optional) A name for the operations.
+
+ Returns:
+    A list or dictionary of tensors with the same types as `tensors`.
+
+ Raises:
+ ValueError: If the `shapes` are not specified, and cannot be
+ inferred from the elements of `tensors`.
+ """
+ return _shuffle_batch(
+ tensors,
+ batch_size,
+ capacity,
+ min_after_dequeue,
+ keep_input,
+ num_threads=num_threads,
+ seed=seed,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
def shuffle_batch_join(tensors_list, batch_size, capacity,
@@ -993,32 +1289,67 @@ def shuffle_batch_join(tensors_list, batch_size, capacity,
ValueError: If the `shapes` are not specified, and cannot be
inferred from the elements of `tensors_list`.
"""
- tensor_list_list = _as_tensor_list_list(tensors_list)
- with ops.name_scope(name, "shuffle_batch_join",
- _flatten(tensor_list_list)) as name:
- tensor_list_list = _validate_join(tensor_list_list)
- tensor_list_list, sparse_info = _store_sparse_tensors_join(
- tensor_list_list, enqueue_many)
- types = _dtypes(tensor_list_list)
- shapes = _shapes(tensor_list_list, shapes, enqueue_many)
- queue = data_flow_ops.RandomShuffleQueue(
- capacity=capacity, min_after_dequeue=min_after_dequeue, seed=seed,
- dtypes=types, shapes=shapes, shared_name=shared_name)
- _enqueue_join(queue, tensor_list_list, enqueue_many)
- full = (math_ops.cast(math_ops.maximum(0, queue.size() - min_after_dequeue),
- dtypes.float32) *
- (1. / (capacity - min_after_dequeue)))
- # Note that name contains a '/' at the end so we intentionally do not place
- # a '/' after %s below.
- summary_name = (
- "queue/%sfraction_over_%d_of_%d_full" %
- (name, min_after_dequeue, capacity - min_after_dequeue))
- summary.scalar(summary_name, full)
+ return _shuffle_batch_join(
+ tensors_list,
+ batch_size,
+ capacity,
+ min_after_dequeue,
+ keep_input=True,
+ seed=seed,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
+
+
+def maybe_shuffle_batch_join(tensors_list, batch_size, capacity,
+ min_after_dequeue, keep_input, seed=None,
+ enqueue_many=False, shapes=None,
+ allow_smaller_final_batch=False, shared_name=None,
+ name=None):
+ """Create batches by randomly shuffling conditionally-enqueued tensors.
+
+ See docstring in `shuffle_batch_join` for more details.
- if allow_smaller_final_batch:
- dequeued = queue.dequeue_up_to(batch_size, name=name)
- else:
- dequeued = queue.dequeue_many(batch_size, name=name)
- dequeued = _restore_sparse_tensors(dequeued, sparse_info)
- # tensors_list was validated to not be empty.
- return _as_original_type(tensors_list[0], dequeued)
+ Args:
+ tensors_list: A list of tuples or dictionaries of tensors to enqueue.
+ batch_size: An integer. The new batch size pulled from the queue.
+ capacity: An integer. The maximum number of elements in the queue.
+    min_after_dequeue: Minimum number of elements in the queue after a
+ dequeue, used to ensure a level of mixing of elements.
+ keep_input: A `bool` scalar Tensor. If provided, this tensor controls
+ whether the input is added to the queue or not. If it evaluates `True`,
+ then `tensors_list` are added to the queue; otherwise they are dropped.
+ This tensor essentially acts as a filtering mechanism.
+ seed: Seed for the random shuffling within the queue.
+ enqueue_many: Whether each tensor in `tensor_list_list` is a single
+ example.
+ shapes: (Optional) The shapes for each example. Defaults to the
+ inferred shapes for `tensors_list[i]`.
+ allow_smaller_final_batch: (Optional) Boolean. If `True`, allow the final
+ batch to be smaller if there are insufficient items left in the queue.
+    shared_name: (Optional) If set, this queue will be shared under the given
+ name across multiple sessions.
+ name: (Optional) A name for the operations.
+
+ Returns:
+ A list or dictionary of tensors with the same number and types as
+ `tensors_list[i]`.
+
+ Raises:
+ ValueError: If the `shapes` are not specified, and cannot be
+ inferred from the elements of `tensors_list`.
+ """
+ return _shuffle_batch_join(
+ tensors_list,
+ batch_size,
+ capacity,
+ min_after_dequeue,
+ keep_input,
+ seed=seed,
+ enqueue_many=enqueue_many,
+ shapes=shapes,
+ allow_smaller_final_batch=allow_smaller_final_batch,
+ shared_name=shared_name,
+ name=name)
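
For the `_join` variants, the first argument is a list of tensor lists (one per producing subgraph) sharing a single `keep_input`; in the shuffling variant, `keep_input` follows `min_after_dequeue` positionally. A rough sketch mirroring the tests below, which replicate one producing subgraph:

    import tensorflow as tf

    value = tf.train.range_input_producer(limit=100).dequeue()
    keep = tf.equal(tf.mod(value, 2), 0)

    # Two copies of the same producing subgraph, gated by the same predicate.
    batched = tf.train.maybe_batch_join([[value]] * 2, keep, batch_size=8)

    # Shuffling variant; note keep_input comes after min_after_dequeue.
    shuffled = tf.train.maybe_shuffle_batch_join(
        [[value]] * 2, batch_size=8, capacity=64, min_after_dequeue=32,
        keep_input=keep)
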
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index 8f3470fc55..8087136c99 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -733,6 +733,83 @@ class BatchTest(tf.test.TestCase):
x = tf.train.batch({"c": [12, 12]}, batch_size=8)
self.assertAllEqual((8, 2), x["c"].get_shape().as_list())
+ def _testKeepInputHelper(self, num_threads, enqueue_many):
+ with self.test_session() as sess:
+ batch_size = 5
+ num_batches = 4
+ examples = tf.Variable(0)
+ counter = examples.count_up_to(num_batches * batch_size * 2)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.zeros([1, 1], dtype=tf.int64),
+ values=tf.stack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ to_batch = [counter, sparse_counter, "string"]
+ if enqueue_many:
+ to_batch = tf.train.batch(to_batch, 1)
+ keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2)))
+ batched = tf.train.maybe_batch(
+ to_batch, keep_input, batch_size, num_threads=num_threads,
+ enqueue_many=enqueue_many)
+ tf.initialize_all_variables().run()
+ tf.initialize_local_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for _ in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
+ self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
+ self.assertAllEqual([b"string"] * batch_size, results[2])
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testSingleThreadKeepInput(self):
+ self._testKeepInputHelper(1, False)
+
+ def testSingleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(1, True)
+
+ def testMultipleThreadKeepInput(self):
+ self._testKeepInputHelper(5, False)
+
+ def testMultipleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(5, True)
+
+ def testMaybeBatchedSparseTensorInferredShape(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2)
+ self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_batch(
+ [sparse], keep_input=True, batch_size=2, enqueue_many=True)
+ self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_batch([sparse], keep_input=True, batch_size=2)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_batch(
+ [sparse], keep_input=True, batch_size=2, enqueue_many=True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
class BatchJoinTest(tf.test.TestCase):
@@ -1125,6 +1202,85 @@ class BatchJoinTest(tf.test.TestCase):
x = tf.train.batch_join([{"c": [12, 12]}], batch_size=8)
self.assertAllEqual((8, 2), x["c"].get_shape().as_list())
+ def _testKeepInputHelper(self, num_threads, enqueue_many):
+ with self.test_session() as sess:
+ batch_size = 5
+ num_batches = 4
+ examples = tf.Variable(0)
+ counter = examples.count_up_to(num_batches * batch_size * 2)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.zeros([1, 1], dtype=tf.int64),
+ values=tf.stack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ to_batch = [counter, sparse_counter, "string"]
+ if enqueue_many:
+ to_batch = tf.train.batch(to_batch, 1)
+ keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2)))
+ batched = tf.train.maybe_batch_join(
+ [to_batch] * num_threads, keep_input, batch_size,
+ enqueue_many=enqueue_many)
+ tf.initialize_all_variables().run()
+ tf.initialize_local_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for _ in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual([0] * batch_size, np.mod(results[0], 2),)
+ self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2),)
+ self.assertAllEqual([b"string"] * batch_size, results[2])
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testSingleThreadKeepInput(self):
+ self._testKeepInputHelper(1, False)
+
+ def testSingleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(1, True)
+
+ def testMultipleThreadKeepInput(self):
+ self._testKeepInputHelper(5, False)
+
+ def testMultipleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(5, True)
+
+ def testMaybeBatchedSparseTensorInferredShape(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_batch_join(
+ [[sparse]], keep_input=True, batch_size=2)
+ self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_batch_join(
+ [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
+ self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_batch_join(
+ [[sparse]], keep_input=True, batch_size=2)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_batch_join(
+ [[sparse]], keep_input=True, batch_size=2, enqueue_many=True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
class ShuffleBatchTest(tf.test.TestCase):
@@ -1351,6 +1507,83 @@ class ShuffleBatchTest(tf.test.TestCase):
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
+ def _testKeepInputHelper(self, num_threads, enqueue_many):
+ with self.test_session() as sess:
+ batch_size = 5
+ num_batches = 4
+ examples = tf.Variable(0)
+ counter = examples.count_up_to(num_batches * batch_size * 2)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.zeros([1, 1], dtype=tf.int64),
+ values=tf.stack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ to_batch = [counter, sparse_counter, "string"]
+ if enqueue_many:
+ to_batch = tf.train.batch(to_batch, 1)
+ keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2)))
+ batched = tf.train.maybe_shuffle_batch(
+ to_batch, batch_size, 10, 1, keep_input, num_threads=num_threads,
+ enqueue_many=enqueue_many)
+ tf.initialize_all_variables().run()
+ tf.initialize_local_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for _ in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
+ self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
+ self.assertAllEqual([b"string"] * batch_size, results[2])
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testSingleThreadKeepInput(self):
+ self._testKeepInputHelper(1, False)
+
+ def testSingleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(1, True)
+
+ def testMultipleThreadKeepInput(self):
+ self._testKeepInputHelper(5, False)
+
+ def testMultipleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(5, True)
+
+ def testMaybeBatchedSparseTensorInferredShape(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True)
+ self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_shuffle_batch(
+ [sparse], 2, 10, 1, True, enqueue_many=True)
+ self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_shuffle_batch([sparse], 2, 10, 1, True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_shuffle_batch(
+ [sparse], 2, 10, 1, True, enqueue_many=True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
class ShuffleBatchJoinTest(tf.test.TestCase):
@@ -1581,6 +1814,83 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
+ def _testKeepInputHelper(self, num_threads, enqueue_many):
+ with self.test_session() as sess:
+ batch_size = 5
+ num_batches = 4
+ examples = tf.Variable(0)
+ counter = examples.count_up_to(num_batches * batch_size * 2)
+ sparse_counter = tf.SparseTensor(
+ indices=tf.zeros([1, 1], dtype=tf.int64),
+ values=tf.stack([tf.cast(counter, tf.float32)]),
+ shape=[1])
+ to_batch = [counter, sparse_counter, "string"]
+ if enqueue_many:
+ to_batch = tf.train.batch(to_batch, 1)
+ keep_input = tf.squeeze(tf.equal(0, tf.mod(to_batch[0], 2)))
+ batched = tf.train.maybe_shuffle_batch_join(
+ [to_batch] * num_threads, batch_size, 10, 1, keep_input,
+ enqueue_many=enqueue_many)
+ tf.initialize_all_variables().run()
+ tf.initialize_local_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ for _ in range(num_batches):
+ results = sess.run(batched)
+ self.assertAllEqual([0] * batch_size, np.mod(results[0], 2))
+ self.assertAllEqual([0] * batch_size, np.mod(results[1].values, 2))
+ self.assertAllEqual([b"string"] * batch_size, results[2])
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
+ def testSingleThreadKeepInput(self):
+ self._testKeepInputHelper(1, False)
+
+ def testSingleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(1, True)
+
+ def testMultipleThreadKeepInput(self):
+ self._testKeepInputHelper(5, False)
+
+ def testMultipleThreadKeepInputEnqueueMany(self):
+ self._testKeepInputHelper(5, True)
+
+ def testMaybeBatchedSparseTensorInferredShape(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
+ self.assertAllEqual((2,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeEnqueueMany(self):
+ sparse = tf.SparseTensor(indices=[[0]], values=[1.0], shape=[1])
+ self.assertAllEqual((1,), sparse.shape.get_shape().as_list())
+ batched = tf.train.maybe_shuffle_batch_join(
+ [[sparse]], 2, 10, 1, True, enqueue_many=True)
+ self.assertAllEqual((1,), batched.shape.get_shape().as_list())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRank(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_shuffle_batch_join([[sparse]], 2, 10, 1, True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
+ def testMaybeBatchedSparseTensorInferredShapeUnknownRankEnqueueMany(self):
+ sparse = tf.SparseTensor(
+ indices=tf.placeholder(tf.int64),
+ values=tf.placeholder(tf.float32),
+ shape=tf.placeholder(tf.int64))
+ self.assertIs(None, sparse.shape.get_shape().num_elements())
+ batched = tf.train.maybe_shuffle_batch_join(
+ [[sparse]], 2, 10, 1, True, enqueue_many=True)
+ self.assertIs(None, batched.shape.get_shape().num_elements())
+
if __name__ == "__main__":
tf.test.main()
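
For completeness, a rough end-to-end sketch that mirrors the `_testKeepInputHelper` pattern above: 40 counter values, of which the 20 even ones form four batches of five before the queue runs dry.

    import tensorflow as tf

    examples = tf.Variable(0)
    counter = examples.count_up_to(40)
    keep_input = tf.equal(tf.mod(counter, 2), 0)   # keep even values only
    batched = tf.train.maybe_batch([counter], keep_input, batch_size=5)

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      sess.run(tf.initialize_local_variables())
      threads = tf.train.start_queue_runners(sess=sess)
      try:
        while True:
          print(sess.run(batched))    # each batch holds even values only
      except tf.errors.OutOfRangeError:
        pass                          # producer exhausted; queue closed
      for t in threads:
        t.join()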