-rw-r--r--  tensorflow/core/kernels/padding_fifo_queue.cc |   3
-rw-r--r--  tensorflow/python/training/input.py           |  78
-rw-r--r--  tensorflow/python/training/input_test.py      | 107
3 files changed, 173 insertions, 15 deletions
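
The net effect of the change: `tf.train.batch` and `tf.train.batch_join` gain a
`dynamic_pad` argument that swaps the underlying `FIFOQueue` for a
`PaddingFIFOQueue`, so inputs need only a known rank rather than fully-defined
shapes. A minimal usage sketch, distilled from the `testOneThreadDynamicPad`
test in the diff below; the counter/`tf.tile` setup and all names are
illustrative, written against the 0.x-era API this commit targets:

  import tensorflow as tf

  with tf.Session() as sess:
    # A counter 0, 1, 2, ... and a 1-D string tensor whose length tracks it:
    # the rank is known, but the length varies per example.
    examples = tf.Variable(tf.constant(0, dtype=tf.int64))
    counter = examples.count_up_to(30)
    row = tf.tile(["x"], tf.to_int32(tf.pack([counter])))

    # dynamic_pad=True selects a PaddingFIFOQueue under the hood; each
    # dequeued minibatch is right-padded ("" for strings, 0 for numbers)
    # to the longest row in that batch.
    batched = tf.train.batch([counter, row], batch_size=10, dynamic_pad=True)

    tf.initialize_all_variables().run()
    threads = tf.train.start_queue_runners(sess=sess)
    counts, rows = sess.run(batched)  # rows.shape[1] == max(counts)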
diff --git a/tensorflow/core/kernels/padding_fifo_queue.cc b/tensorflow/core/kernels/padding_fifo_queue.cc
index f660ede290..eba4677b6d 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.cc
+++ b/tensorflow/core/kernels/padding_fifo_queue.cc
@@ -288,6 +288,9 @@ Status HandleElementToLargerSlice(const Tensor& element, Tensor* parent,
   if (!s.ok()) {
     return s;
   }
+  if (element.NumElements() == 0) {
+    return Status::OK();
+  }
   auto element_t = element.tensor<T, NDIMS>();
   auto parent_t = parent->tensor<T, NDIMS + 1>();
   Eigen::DSizes<Eigen::DenseIndex, NDIMS + 1> slice_indices;
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index ae2df782a4..74eb48aa0d 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -413,11 +413,33 @@ def _merge_shapes(shape_list, enqueue_many):
 
 
 def _shapes(tensor_list_list, shapes, enqueue_many):
+  """Calculate and merge the shapes of incoming tensors.
+
+  Args:
+    tensor_list_list: List of tensor lists.
+    shapes: List of shape tuples corresponding to tensors within the lists.
+    enqueue_many: Boolean describing whether shapes will be enqueued as
+      batches or individual entries.
+
+  Returns:
+    A list of shapes aggregating shape inference info from `tensor_list_list`,
+    or `shapes` itself if it is not `None`.
+
+  Raises:
+    ValueError: If any of the inferred shapes in `tensor_list_list` lacks a
+      well-defined rank.
+  """
   if shapes is None:
-    l = len(tensor_list_list[0])
+    len0 = len(tensor_list_list[0])
+
+    for tl in tensor_list_list:
+      for i in xrange(len0):
+        if tl[i].get_shape().ndims is None:
+          raise ValueError("Cannot infer Tensor's rank: %s" % tl[i])
+
     shapes = [_merge_shapes(
         [tl[i].get_shape().as_list() for tl in tensor_list_list], enqueue_many)
-              for i in xrange(l)]
+              for i in xrange(len0)]
   return shapes
 
 
@@ -437,11 +459,16 @@ def _enqueue(queue, tensor_list, threads, enqueue_many):
   queue_runner.add_queue_runner(queue_runner.QueueRunner(queue, enqueue_ops))
 
 
+def _which_queue(dynamic_pad):
+  return (data_flow_ops.PaddingFIFOQueue if dynamic_pad
+          else data_flow_ops.FIFOQueue)
+
+
 # Batching functions ----------------------------------------------------------
 
 
 def batch(tensor_list, batch_size, num_threads=1, capacity=32,
-          enqueue_many=False, shapes=None,
+          enqueue_many=False, shapes=None, dynamic_pad=False,
           shared_name=None, name=None):
   """Creates batches of tensors in `tensor_list`.
@@ -465,10 +492,18 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
   this exception, however, if this operation is used in your main thread
   you are responsible for catching this yourself.
 
-  *N.B.:* You must ensure that either (i) the `shapes` argument is
-  passed, or (ii) all of the tensors in `tensor_list` must have
-  fully-defined shapes.  `ValueError` will be raised if neither of
-  these conditions holds.
+  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
+  (i) the `shapes` argument is passed, or (ii) all of the tensors in
+  `tensor_list` have fully-defined shapes.  `ValueError` will be
+  raised if neither of these conditions holds.
+
+  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
+  tensors is known, but individual dimensions may have value `None`.
+  In this case, for each enqueue the dimensions with value `None`
+  may have a variable length; upon dequeue, the output tensors will be padded
+  on the right to the maximum shape of the tensors in the current minibatch.
+  For numbers, this padding takes value 0.  For strings, this padding is
+  the empty string.  See `PaddingFIFOQueue` for more info.
 
   Args:
     tensor_list: The list of tensors to enqueue.
@@ -478,6 +513,9 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
     enqueue_many: Whether each tensor in `tensor_list` is a single example.
     shapes: (Optional) The shapes for each example.  Defaults to the
       inferred shapes for `tensor_list`.
+    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
+      The given dimensions are padded upon dequeue so that tensors within a
+      batch have the same shapes.
     shared_name: (Optional) If set, this queue will be shared under the given
       name across multiple sessions.
     name: (Optional) A name for the operations.
@@ -496,7 +534,7 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
   types = _dtypes([tensor_list])
   shapes = _shapes([tensor_list], shapes, enqueue_many)
   # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
-  queue = data_flow_ops.FIFOQueue(
+  queue = _which_queue(dynamic_pad)(
       capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
   _enqueue(queue, tensor_list, num_threads, enqueue_many)
   logging_ops.scalar_summary(
@@ -515,7 +553,8 @@ def batch(tensor_list, batch_size, num_threads=1, capacity=32,
 # read that many files in parallel due to the number of seeks required).
 # Once this is done, batch() can be written as a call to batch_join().
 def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
-               shapes=None, shared_name=None, name=None):
+               shapes=None, dynamic_pad=False,
+               shared_name=None, name=None):
   """Runs a list of tensors to fill a queue to create batches of examples.
 
   Enqueues a different list of tensors in different threads.
@@ -548,10 +587,18 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
   this exception, however, if this operation is used in your main thread
   you are responsible for catching this yourself.
 
-  *N.B.:* You must ensure that either (i) the `shapes` argument is
-  passed, or (ii) all of the tensors in `tensor_list_list` must have
-  fully-defined shapes.  `ValueError` will be raised if neither of
-  these conditions holds.
+  *N.B.:* If `dynamic_pad` is `False`, you must ensure that either
+  (i) the `shapes` argument is passed, or (ii) all of the tensors in
+  `tensor_list_list` have fully-defined shapes.  `ValueError` will be
+  raised if neither of these conditions holds.
+
+  If `dynamic_pad` is `True`, it is sufficient that the *rank* of the
+  tensors is known, but individual dimensions may have value `None`.
+  In this case, for each enqueue the dimensions with value `None`
+  may have a variable length; upon dequeue, the output tensors will be padded
+  on the right to the maximum shape of the tensors in the current minibatch.
+  For numbers, this padding takes value 0.  For strings, this padding is
+  the empty string.  See `PaddingFIFOQueue` for more info.
 
   Args:
     tensor_list_list: A list of tuples of tensors to enqueue.
@@ -561,6 +608,9 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
       example.
     shapes: (Optional) The shapes for each example.  Defaults to the
       inferred shapes for `tensor_list_list[i]`.
+    dynamic_pad: Boolean.  Allow variable dimensions in input shapes.
+      The given dimensions are padded upon dequeue so that tensors within a
+      batch have the same shapes.
     shared_name: (Optional) If set, this queue will be shared under the given
       name across multiple sessions.
     name: (Optional) A name for the operations.
@@ -580,7 +630,7 @@ def batch_join(tensor_list_list, batch_size, capacity=32, enqueue_many=False,
   types = _dtypes(tensor_list_list)
   shapes = _shapes(tensor_list_list, shapes, enqueue_many)
   # TODO(josh11b,mrry): Switch to BatchQueue once it is written.
-  queue = data_flow_ops.FIFOQueue(
+  queue = _which_queue(dynamic_pad)(
       capacity=capacity, dtypes=types, shapes=shapes, shared_name=shared_name)
   _enqueue_join(queue, tensor_list_list, enqueue_many)
   logging_ops.scalar_summary(
diff --git a/tensorflow/python/training/input_test.py b/tensorflow/python/training/input_test.py
index b265c6e3c4..667ce43b7f 100644
--- a/tensorflow/python/training/input_test.py
+++ b/tensorflow/python/training/input_test.py
@@ -402,6 +402,35 @@ class BatchTest(tf.test.TestCase):
       for thread in threads:
         thread.join()
 
+  def testOneThreadDynamicPad(self):
+    with self.test_session() as sess:
+      batch_size = 10
+      num_batches = 3
+      zero64 = tf.constant(0, dtype=tf.int64)
+      examples = tf.Variable(zero64)
+      counter = examples.count_up_to(num_batches * batch_size)
+      string = tf.tile(["string"], tf.to_int32(tf.pack([counter])))
+      tf.initialize_all_variables().run()
+      batched = tf.train.batch(
+          [counter, string], batch_size=batch_size, dynamic_pad=True)
+      threads = tf.train.start_queue_runners()
+
+      for i in range(num_batches):
+        results = sess.run(batched)
+        expected_results = np.arange(i * batch_size, (i + 1) * batch_size)
+        max_len = expected_results[-1]
+        self.assertAllEqual(results[0], expected_results)
+        expected_strings = [
+            [b"string"] * rep + [b""] * (max_len - rep)
+            for rep in expected_results]
+        self.assertAllEqual(results[1], expected_strings)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(batched)
+      for thread in threads:
+        thread.join()
+
   def testOneThreadEnqueueMany(self):
     with self.test_session() as sess:
       batch_size = 10
@@ -491,6 +520,12 @@ class BatchTest(tf.test.TestCase):
           "s: 'SHARED_NAME_XYZ'",
           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
 
+  def testCannotInferRankError(self):
+    with self.test_session():
+      x = tf.placeholder(dtype=tf.int64)
+      with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
+        tf.train.batch([x], batch_size=2)
+
 
 class BatchJoinTest(tf.test.TestCase):
 
@@ -561,6 +596,70 @@ class BatchJoinTest(tf.test.TestCase):
       for thread in threads:
         thread.join()
 
+  def testTwoThreadsDynamicPad(self):
+    with self.test_session() as sess:
+      # Two threads, the first generates (0..69, ["a"] * 1..70).
+      num_a = 70
+      zero64 = tf.constant(0, dtype=tf.int64)
+      examples = tf.Variable(zero64)
+      counter = examples.count_up_to(num_a)
+
+      # The second generates (99, ["b"] * 99) 90 times and then stops.
+      num_b = 90
+      ninety_nine = tf.train.limit_epochs(
+          tf.constant(99, dtype=tf.int64), num_b)
+
+      # These get joined together and grouped into batches of 5.
+      batch_size = 5
+      a = tf.tile(["a"], tf.to_int32(tf.pack([counter + 1])))
+      b = tf.tile(["b"], tf.to_int32(tf.pack([ninety_nine])))
+      batched = tf.train.batch_join(
+          [[counter, a],
+           [ninety_nine, b]],
+          batch_size=batch_size, dynamic_pad=True)
+      tf.initialize_all_variables().run()
+      threads = tf.train.start_queue_runners()
+
+      # Should see the "a" and "b" threads mixed together.
+      all_a = []
+      count_string_a = []
+      seen_b = 0
+      saw_both = 0
+      num_batches = (num_a + num_b) // batch_size
+      for i in range(num_batches):
+        results = sess.run(batched)
+        tf.logging.info("Batch %d: %s", i, results[0])
+        self.assertEqual(len(results[0]), batch_size)
+        self.assertEqual(len(results[1]), batch_size)
+        for s in results[1]:
+          if s[0] == b"b":
+            self.assertAllEqual(s, [b"b"] * 99)
+          else:
+            count_string_a.append(sum(x == b"a" for x in s))
+        which_a = [i for i, s in enumerate(results[1]) if s[0] == b"a"]
+        which_b = [i for i, s in enumerate(results[1]) if s[0] == b"b"]
+        self.assertEqual(len(which_a) + len(which_b), batch_size)
+        if which_a and which_b: saw_both += 1
+        all_a.extend([results[0][i] for i in which_a])
+        seen_b += len(which_b)
+        self.assertAllEqual([99] * len(which_b),
+                            [results[0][i] for i in which_b])
+
+      # Some minimum level of mixing of the results of both threads.
+      self.assertGreater(saw_both, 1)
+
+      # Verify the order of results from "a" were preserved.
+      self.assertAllEqual(  # tiled "a" with counter + 1
+          count_string_a, np.arange(num_a) + 1)
+      self.assertAllEqual(all_a, np.arange(num_a))
+      self.assertEqual(seen_b, num_b)
+
+      # Reached the limit.
+      with self.assertRaises(tf.errors.OutOfRangeError):
+        sess.run(batched)
+      for thread in threads:
+        thread.join()
+
   def testSharedName(self):
     with self.test_session():
       batch_size = 10
@@ -576,6 +675,12 @@ class BatchJoinTest(tf.test.TestCase):
           "s: 'SHARED_NAME_XYZ'",
           batched[0].op.inputs[0].op.node_def.attr["shared_name"])
 
+  def testCannotInferRankError(self):
+    with self.test_session():
+      x = tf.placeholder(dtype=tf.int64)
+      with self.assertRaisesRegexp(ValueError, "Cannot infer Tensor's rank"):
+        tf.train.batch_join([[x]], batch_size=2)
+
 
 class ShuffleBatchTest(tf.test.TestCase):
 
@@ -732,7 +837,7 @@ class ShuffleBatchJoinTest(tf.test.TestCase):
       which_a = [i for i, s in enumerate(results[2]) if s == b"a"]
       which_b = [i for i, s in enumerate(results[2]) if s == b"b"]
       self.assertEqual(len(which_a) + len(which_b), batch_size)
-      if len(which_a) > 0 and len(which_b) > 0: saw_both += 1
+      if which_a and which_b: saw_both += 1
       all_a.extend([results[0][i] for i in which_a])
       seen_b += len(which_b)
       self.assertAllEqual([99] * len(which_b),
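
For reference, the padding behavior these tests rely on comes from
`PaddingFIFOQueue` itself, which `_which_queue` selects when
`dynamic_pad=True`; the C++ change above guards exactly the zero-length case
(e.g. an empty element padded into a batch). A minimal sketch of the queue in
isolation; the values are illustrative, and it assumes the class is exported
as `tf.PaddingFIFOQueue` as in contemporary releases:

  import tensorflow as tf

  # One string component with a variable (None) dimension.
  q = tf.PaddingFIFOQueue(capacity=32, dtypes=[tf.string], shapes=[[None]])
  enq_a = q.enqueue([["a"]])            # length-1 element
  enq_b = q.enqueue([["b", "b", "b"]])  # length-3 element

  with tf.Session() as sess:
    sess.run([enq_a, enq_b])
    # dequeue_many pads each element on the right to the longest element
    # in the minibatch; string padding is b"", numeric padding is 0.
    print(sess.run(q.dequeue_many(2)))
    # => [[b'a' b'' b''] [b'b' b'b' b'b']]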