-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue.cc    3
-rw-r--r--  tensorflow/python/training/input.py             78
-rw-r--r--  tensorflow/python/training/input_test.py       107
3 files changed, 173 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
index f660ede290..eba4677b6d 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.cc
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -288,6 +288,9 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
if (!s.ok()) {
return s;
}
+ if (element.NumElements() == 0) {
+ return Status::OK();
+ }
auto element_t = element.tensor<T, NDIMS>();
auto parent_t = parent->tensor<T, NDIMS + 1>();
Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
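
The guard added above returns early when the incoming element has zero elements, since copying an empty element into its slice of the padded parent tensor is a no-op. Below is a minimal sketch of the case it covers, assuming the TF 0.x-era Python API used elsewhere in this diff and the public `tf.PaddingFIFOQueue` wrapper around the kernel touched here; the variable-length component of one element is empty and is simply padded on dequeue:

import tensorflow as tf

# Single-component queue whose elements are 1-D string tensors of unknown length.
q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.string], shapes=[[None]])

enqueue_empty = q.enqueue([tf.constant([], dtype=tf.string)])  # length-0 element
enqueue_full = q.enqueue([tf.constant(["a", "b", "c"])])       # length-3 element
batch = q.dequeue_many(2)  # both elements padded on the right to length 3

with tf.Session() as sess:
    sess.run([enqueue_empty, enqueue_full])
    print(sess.run(batch))  # [[b'' b'' b''], [b'a' b'b' b'c']]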
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index ae2df782a4..74eb48aa0d 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -413,11 +413,33 @@ def _merge_shapes(shape_list, enqueue_many):
def _shapes(tensor_list_list, shapes, enqueue_many):
+ """Calculate and merge the shapes of incoming tensors.
+
+ Args:
+ tensor_list_list: List of tensor lists.
+ shapes: List of shape tuples corresponding to tensors within the lists.
+ enqueue_many: Boolean describing whether tensors will be enqueued as
+ batches or as individual entries.
+
+ Returns:
+ A list of shapes aggregating shape inference info from `tensor_list_list`,
+ or `shapes` if it is not `None`.
+
+ Raises:
+ ValueError: If any of the inferred shapes in `tensor_list_list` lacks a
+ well-defined rank.
+ """
if shapes is None:
- l = len(tensor_list_list[0])
+ len0 = len(tensor_list_list[0])
+
+ for tl in tensor_list_list:
+ for i in xrange(len0):
+ if tl[i].get_shape().ndims is None:
+ raise ValueError("Cannot infer Tensor's rank: %s" % tl[i])
+
shapes = [_merge_shapes(
[tl[i].get_shape().as_list() for tl in tensor_list_list], enqueue_many)
- for i in xrange(l)]
+ for i in xrange(len0)]
return shapes
@@ -437,11 +459,16 @@ def _enqueue(queue, tensor_list, threads, enqueue_many):
queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
+def _which_queue(dynamic_pad):
+ return (data_flow_ops.PaddingFIFOQueue if dynamic_pad
+ else data_flow_ops.FIFOQueue)
+
+
# Batching functions ----------------------------------------------------------
def batch(tensor_list, batch_size, num_threads=1, capacity=32,
- enqueue_many=False, shapes=None,
+ enqueue_many=False, shapes=None, dynamic_pad=False,
shared_name=None, name=None):
"""Creates batches of tensors in `tensor_list`.
@@ -465,10 +492,18 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
this exception, however, if this operation is used in your main thread
you are responsible for catching this yourself.
- *N.B.:* You must ensure that either (i) the `shapes` argument is
- passed, or (ii) all of the tensors in `tensor_list` must have
- fully-defined shapes. `ValueError` will be raised if neither of
- these conditions holds.
+ *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
+ (i) the `shapes` argument is passed, or (ii) all of the tensors in
+ `tensor_list` have fully-defined shapes. `ValueError` will be
+ raised if neither of these conditions holds.
+
+ If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
+ tensors is known, but individual dimensions may have value `None`.
+ In this case, for each enqueue the dimensions with value `None`
+ may have a variable length; upon dequeue, the output tensors will be padded
+ on the right to the maximum shape of the tensors in the current minibatch.
+ For numbers, this padding takes value 0. For strings, this padding is
+ the empty string. See `PaddingFIFOQueue` for more info.
Args:
tensor_list: The list of tensors to enqueue.
@@ -478,6 +513,9 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
enqueue_many: Whether each tensor in `tensor_list` is a single example.
shapes: (Optional) The shapes for each example. Defaults to the
inferred shapes for `tensor_list`.
+ dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+ The given dimensions are padded upon dequeue so that tensors within a
+ batch have the same shapes.
shared_name: (optional). If set, this queue will be shared under the given
name across multiple sessions.
name: (Optional) A name for the operations.
@@ -496,7 +534,7 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
types = _dtypes([tensor_list])
shapes = _shapes([tensor_list], shapes, enqueue_many)
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
- queue = data_flow_ops.FIFOQueue(
+ queue = _which_queue(dynamic_pad)(
capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue(queue, tensor_list, num_threads, enqueue_many)
logging_ops.scalar_summary(
@@ -515,7 +553,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
# read that many files in parallel due to the number of seeks required).
# Once this is done, batch() can be written as a call to batch_join().
def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
- shapes=None, shared_name=None, name=None):
+ shapes=None, dynamic_pad=False,
+ shared_name=None, name=None):
"""Runs a list of tensors to fill a queue to create batches of examples.
Enqueues a different list of tensors in different threads.
@@ -548,10 +587,18 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
this exception, however, if this operation is used in your main thread
you are responsible for catching this yourself.
- *N.B.:* You must ensure that either (i) the `shapes` argument is
- passed, or (ii) all of the tensors in `tensor_list_list` must have
- fully-defined shapes. `ValueError` will be raised if neither of
- these conditions holds.
+ *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
+ (i) the `shapes` argument is passed, or (ii) all of the tensors in
+ `tensor_list_list` have fully-defined shapes. `ValueError` will be
+ raised if neither of these conditions holds.
+
+ If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
+ tensors is known, but individual dimensions may have value `None`.
+ In this case, for each enqueue the dimensions with value `None`
+ may have a variable length; upon dequeue, the output tensors will be padded
+ on the right to the maximum shape of the tensors in the current minibatch.
+ For numbers, this padding takes value 0. For strings, this padding is
+ the empty string. See `PaddingFIFOQueue` for more info.
Args:
tensor_list_list: A list of tuples of tensors to enqueue.
@@ -561,6 +608,9 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
example.
shapes: (Optional) The shapes for each example. Defaults to the
inferred shapes for `tensor_list_list[i]`.
+ dynamic_pad: Boolean. Allow variable dimensions in input shapes.
+ The given dimensions are padded upon dequeue so that tensors within a
+ batch have the same shapes.
shared_name: (Optional) If set, this queue will be shared under the given
name across multiple sessions.
name: (Optional) A name for the operations.
@@ -580,7 +630,7 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
types = _dtypes(tensor_list_list)
shapes = _shapes(tensor_list_list, shapes, enqueue_many)
# TODO(josh11b,mrry): Switch to BatchQueue once it is written.
- queue = data_flow_ops.FIFOQueue(
+ queue = _which_queue(dynamic_pad)(
capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
_enqueue_join(queue, tensor_list_list, enqueue_many)
logging_ops.scalar_summary(
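
A minimal usage sketch of the `dynamic_pad` behavior documented in the docstrings above, assuming the TF 0.x-era API used throughout this diff (`tf.pack`, `tf.tile`, `tf.train.batch`). Each enqueued element is a 1-D string tensor whose length varies per step; with `dynamic_pad=True` the rank must be known but the dimension may be `None`, and each dequeued minibatch is padded on the right with empty strings to the length of its longest element:

import tensorflow as tf

# A 1-D string tensor whose length changes on every enqueue (rank known, dim None).
length = tf.random_uniform([], minval=1, maxval=5, dtype=tf.int32)
words = tf.tile(["word"], tf.pack([length]))

# With dynamic_pad=False this call would raise ValueError, because the shape of
# `words` is not fully defined; with dynamic_pad=True a PaddingFIFOQueue is used
# and each minibatch is padded to its own maximum length.
batched = tf.train.batch([words], batch_size=4, dynamic_pad=True)

with tf.Session() as sess:
    threads = tf.train.start_queue_runners(sess=sess)
    print(sess.run(batched))  # each row padded with b"" to the batch's max length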
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index b265c6e3c4..667ce43b7f 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -402,6 +402,35 @@ class BatchTest(tf.test.TestCase):
for thread in threads:
thread.join()
+ def testOneThreadDynamicPad(self):
+ with self.test_session() as sess:
+ batch_size = 10
+ num_batches = 3
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_batches * batch_size)
+ string = tf.tile(["string"], tf.to_int32(tf.pack([counter])))
+ tf.initialize_all_variables().run()
+ batched = tf.train.batch(
+ [counter, string], batch_size=batch_size, dynamic_pad=True)
+ threads = tf.train.start_queue_runners()
+
+ for i in range(num_batches):
+ results = sess.run(batched)
+ expected_results = np.arange(i * batch_size, (i + 1) * batch_size)
+ max_len = expected_results[-1]
+ self.assertAllEqual(results[0], expected_results)
+ expected_strings = [
+ [b"string"] * rep + [b""] * (max_len - rep)
+ for rep in expected_results]
+ self.assertAllEqual(results[1], expected_strings)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
def testOneThreadEnqueueMany(self):
with self.test_session() as sess:
batch_size = 10
@@ -491,6 +520,12 @@ class BatchTest(tf.test.TestCase):
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
+ def testCannotInferRankError(self):
+ with self.test_session():
+ x = tf.placeholder(dtype=tf.int64)
+ with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
+ tf.train.batch([x], batch_size=2)
+
class BatchJoinTest(tf.test.TestCase):
@@ -561,6 +596,70 @@ class BatchJoinTest(tf.test.TestCase):
for thread in threads:
thread.join()
+ def testTwoThreadsDynamicPad(self):
+ with self.test_session() as sess:
+ # Two threads, the first generates (0..69, ["a"] * 1..70).
+ num_a = 70
+ zero64 = tf.constant(0, dtype=tf.int64)
+ examples = tf.Variable(zero64)
+ counter = examples.count_up_to(num_a)
+
+ # The second generates (99, ["b"] * 99) 90 times and then stops.
+ num_b = 90
+ ninety_nine = tf.train.limit_epochs(
+ tf.constant(99, dtype=tf.int64), num_b)
+
+ # These get joined together and grouped into batches of 5.
+ batch_size = 5
+ a = tf.tile(["a"], tf.to_int32(tf.pack([counter + 1])))
+ b = tf.tile(["b"], tf.to_int32(tf.pack([ninety_nine])))
+ batched = tf.train.batch_join(
+ [[counter, a],
+ [ninety_nine, b]],
+ batch_size=batch_size, dynamic_pad=True)
+ tf.initialize_all_variables().run()
+ threads = tf.train.start_queue_runners()
+
+ # Should see the "a" and "b" threads mixed together.
+ all_a = []
+ count_string_a = []
+ seen_b = 0
+ saw_both = 0
+ num_batches = (num_a + num_b) // batch_size
+ for i in range(num_batches):
+ results = sess.run(batched)
+ tf.logging.info("Batch %d: %s", i, results[0])
+ self.assertEqual(len(results[0]), batch_size)
+ self.assertEqual(len(results[1]), batch_size)
+ for s in results[1]:
+ if s[0] == b"b":
+ self.assertAllEqual(s, [b"b"] * 99)
+ else:
+ count_string_a.append(sum(x == b"a" for x in s))
+ which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
+ which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
+ self.assertEqual(len(which_a) + len(which_b), batch_size)
+ if which_a and which_b: saw_both += 1
+ all_a.extend([results[0][i] for i in which_a])
+ seen_b += len(which_b)
+ self.assertAllEqual([99] * len(which_b),
+ [results[0][i] for i in which_b])
+
+ # Some minimum level of mixing of the results of both threads.
+ self.assertGreater(saw_both, 1)
+
+ # Verify the order of results from "a" were preserved.
+ self.assertAllEqual( # tiled "a" with counter + 1
+ count_string_a, np.arange(num_a) + 1)
+ self.assertAllEqual(all_a, np.arange(num_a))
+ self.assertEqual(seen_b, num_b)
+
+ # Reached the limit.
+ with self.assertRaises(tf.errors.OutOfRangeError):
+ sess.run(batched)
+ for thread in threads:
+ thread.join()
+
def testSharedName(self):
with self.test_session():
batch_size = 10
@@ -576,6 +675,12 @@ class BatchJoinTest(tf.test.TestCase):
"s: 'SHARED_NAME_XYZ'",
batched[0].op.inputs[0].op.node_def.attr["shared_name"])
+ def testCannotInferRankError(self):
+ with self.test_session():
+ x = tf.placeholder(dtype=tf.int64)
+ with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
+ tf.train.batch_join([[x]], batch_size=2)
+
class ShuffleBatchTest(tf.test.TestCase):
@@ -732,7 +837,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
self.assertEqual(len(which_a) + len(which_b), batch_size)
- if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
+ if which_a and which_b: saw_both += 1
all_a.extend([results[0][i] for i in which_a])
seen_b += len(which_b)
self.assertAllEqual([99] * len(which_b),