author    Yifei Feng <yifeif@google.com>  2018-02-01 17:58:54 -0800
committer Yifei Feng <yifeif@google.com>  2018-02-01 17:58:54 -0800
commit    7ef914f41f1b376eacf41ba99a78491190c3a949 (patch)
tree      186d6b07e8827e682a278e97694a4d7100509b0e /tensorflow/contrib/training/python
parent    73019bc43d81c781b591407f97f409b8570c6115 (diff)
parent    ff81ca3d1303ec3ad178113a3398f8f1cac0304d (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/training/python')
-rw-r--r--  tensorflow/contrib/training/python/training/tensor_queue_dataset.py       200
-rw-r--r--  tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py  355
2 files changed, 555 insertions, 0 deletions
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
new file mode 100644
index 0000000000..409aba817c
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset.py
@@ -0,0 +1,200 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import nest as tf_nest
+
+
+class _PrependFromQueueAndPaddedBatchDataset(dataset_ops.Dataset):
+ """A `Dataset` that prepends a queue to another `Dataset`.
+
+ A vector of handles to the queue is returned as the first component of
+ the associated iterator. This vector can be passed to
+ `enqueue_in_queue_dataset` to add new elements to the queue.
+ """
+
+ def __init__(self, input_dataset, batch_size, padded_shapes, padding_values):
+ """Initialize `PrependFromQueueAndPaddedBatchDataset`."""
+ super(_PrependFromQueueAndPaddedBatchDataset, self).__init__()
+ if sparse.any_sparse(input_dataset.output_classes):
+ raise TypeError(
+ "Batching of padded sparse tensors is not currently supported")
+ self._input_dataset = input_dataset
+ self._batch_size = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ # pylint: disable=protected-access
+ if padded_shapes is None:
+ self._padded_shapes = nest.map_structure(
+ dataset_ops._partial_shape_to_tensor, input_dataset.output_shapes)
+ else:
+ self._padded_shapes = nest.map_structure_up_to(
+ input_dataset.output_shapes, dataset_ops._partial_shape_to_tensor,
+ padded_shapes)
+ padding_values = (
+ padding_values if padding_values is not None else
+ dataset_ops._default_padding(input_dataset))
+ self._padding_values = nest.map_structure_up_to(
+ input_dataset.output_shapes, dataset_ops._padding_value_to_tensor,
+ padding_values, input_dataset.output_types)
+ # pylint: enable=protected-access
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return gen_dataset_ops.prepend_from_queue_and_padded_batch_dataset(
+ self._input_dataset._as_variant_tensor(),
+ batch_size=self._batch_size,
+ padded_shapes=[
+ ops.convert_to_tensor(s, dtype=dtypes.int64)
+ for s in nest.flatten(self._padded_shapes)
+ ],
+ padding_values=nest.flatten(self._padding_values),
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return (ops.Tensor, self._input_dataset.output_classes)
+
+ def _as_batch_shape(self, shape_like):
+ return tensor_shape.vector(None).concatenate(
+ tensor_util.constant_value_as_shape(shape_like))
+
+ @property
+ def output_shapes(self):
+ # First output is a variant representing the Queue
+ return (tensor_shape.vector(None),
+ nest.map_structure(self._as_batch_shape, self._padded_shapes))
+
+ @property
+ def output_types(self):
+ # First output is a variant representing the Queue
+ return (dtypes.variant, self._input_dataset.output_types)
+
+
+def prepend_from_queue_and_padded_batch_dataset(batch_size,
+ padding_values=None,
+ padded_shapes=None):
+ """A transformation that prepends a queue to a `Dataset` and batches results.
+
+ A vector of handles to the queue is returned as the first component of the
+ associated iterator. This vector can be passed to `enqueue_in_queue_dataset`
+ to add new elements to the queue.
+
+ Below is an example of how this dataset might be used to split incoming
+ variable-length sequences into "head" and "rest" parts, where "rest" parts
+ are re-enqueued back into the dataset. A more realistic example would
+ perform some calculation on the "head" and modify some components of "rest"
+ with the result (before re-enqueueing).
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([2*x for x in range(10)])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(lambda count: (count, tf.ones((count,))))
+ # Emit a queue we can prepend to, and counts/values as padded batch.
+ dataset = dataset.apply(
+ tf.contrib.training.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=10))
+ dataset = dataset.prefetch(1)
+
+ iterator = dataset.make_one_shot_iterator()
+ queue, (count, padded_value) = iterator.get_next()
+
+ # Split the padded_value into two pieces: head and rest
+ rest_indices = tf.squeeze(tf.where(count > 3), axis=1)
+ bound = tf.minimum(3, tf.reduce_max(count))
+ value_head = padded_value[:, :bound]
+ count_rest = tf.gather(count - 3, rest_indices)
+ value_rest = tf.gather(padded_value[:, bound:], rest_indices)
+ queue_rest = tf.gather(queue, rest_indices)
+ enqueue_rest_op = tf.contrib.training.enqueue_in_queue_dataset(
+ queue_rest, (count_rest, value_rest))
+ with tf.control_dependencies([enqueue_rest_op]):
+ calculation = fn(value_head)
+
+ while True: # Will raise OutOfRange when finished with all pieces.
+ session.run(calculation)
+ ```
+
+ Args:
+ batch_size: `int64` scalar tensor. The batch size to use when performing
+ padded batching.
+ padding_values: (optional) Nested tuple of scalar tensors. If provided,
+ the structure and dtypes of padding_values should match that of
+ incoming dataset's `output_types`.
+ padded_shapes: (optional) Nested tuple of `int64` vector tensors.
+ If provided, the structure must match that of the incoming dataset's
+ `output_types`. If not provided, the incoming dataset's `output_shapes`
+    is used. Any unknown (`None` or `-1`) dimensions in the shapes are
+    treated as being unique per batch: for each batch, an unknown
+    dimension is replaced with the maximum value of that dimension
+    across all tensors for the given component in the batch.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.data.Dataset.apply}.
+ """
+
+ def _apply_fn(dataset):
+ return _PrependFromQueueAndPaddedBatchDataset(
+ dataset,
+ batch_size=batch_size,
+ padding_values=padding_values,
+ padded_shapes=padded_shapes)
+
+ return _apply_fn
+
+
+def enqueue_in_queue_dataset(queue, components):
+ """Enqueue components into queue from `PrependFromQueueAndPaddedBatchDataset`.
+
+  The components' dtypes and shapes must be compatible with the
+  `output_types` and `output_shapes` attributes of the `dataset` created by
+  `prepend_from_queue_and_padded_batch_dataset`. This operation supports both
+  non-batched and batched modes.
+
+ For more details, see the example in the docstring for
+ `prepend_from_queue_and_padded_batch_dataset`.
+
+ Args:
+ queue: `variant` scalar or vector tensor.
+ The tensor emitted by the first component of the iterator associated with
+ `prepend_from_queue_and_padded_batch_dataset`. If this is a scalar,
+ then the `components` input tensors should not have a prepended batch
+ dimension.
+ components: Nested tuple of tensors, each with a leading batch dimension
+ if `queue` is a vector. The structure, dtypes, and shapes
+ (excluding batch dimension) must match the nested tuples
+ `dataset.output_types[1]` and `dataset.output_shapes[1]` (the non-queue
+ output types and shapes) of the `dataset` emitted by
+ the original `prepend_from_queue_and_padded_batch_dataset` call.
+
+ Returns:
+ An `Operation` that enqueues `components` into the dataset(s) associated
+ with entries of `queue`.
+ """
+ return gen_dataset_ops.enqueue_in_queue_dataset(
+ queue=queue, components=tf_nest.flatten(components))
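Taken together, the two functions above form the public API added by this
change. Below is a minimal sketch of how they interact, mirroring
`testOneEnqueue` in the test file that follows; it assumes a TF 1.x runtime in
which this contrib module is importable, and imports the module directly (as
the tests do) rather than relying on the `tf.contrib.training` re-exports used
in the docstring example:

```python
import tensorflow as tf
from tensorflow.contrib.training.python.training import (
    tensor_queue_dataset as tqd)

# A dataset of scalars; with batch_size=1 each element is returned as a
# length-1 vector, and the iterator's first component is a queue handle.
dataset = tf.data.Dataset.from_tensor_slices([0, 1, 2])
dataset = dataset.apply(
    tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))

iterator = dataset.make_one_shot_iterator()
queue_handle, value = iterator.get_next()

# Re-enqueue the negated value; queued elements are served ahead of the
# remaining input on subsequent get_next() calls.
enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)

with tf.Session() as sess:
  print(sess.run(value))                      # [0]
  v, _ = sess.run([value, enqueue_negative])  # sees [1], enqueues [-1]
  print(v)                                    # [1]
  print(sess.run(value))                      # [-1], the re-enqueued element
  print(sess.run(value))                      # [2]
```

Because the queue is prepended to the input, the re-enqueued `[-1]` comes back
before the untouched `[2]`.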
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
new file mode 100644
index 0000000000..0338f409a2
--- /dev/null
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -0,0 +1,355 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for TensorQueueDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
+from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class PrependFromQueueAndPaddedBatchDatasetTest(test.TestCase):
+
+ def testNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ self.assertEqual((dtypes.variant, dtypes.int32), dataset.output_types)
+ self.assertAllEqual(([None],) * 2,
+ [x.as_list() for x in dataset.output_shapes])
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertEqual([0], self.evaluate(value))
+ self.assertEqual([1], self.evaluate(value))
+ self.assertEqual([2], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertAllEqual([0, 1], self.evaluate(value))
+ self.assertAllEqual([2], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedWithBiggerPaddingNoEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=2, padded_shapes=[3]))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ self.assertAllEqual([[0, 0, 0], [1, 0, 0]], self.evaluate(value))
+ self.assertAllEqual([[2, 0, 0]], self.evaluate(value))
+ with self.assertRaisesOpError("End of sequence"):
+ self.evaluate(value)
+
+ def testBatchedWithBiggerPaddingOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0], [1], [2]])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=1, padded_shapes=[3]))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ with self.test_session() as sess:
+ self.assertAllEqual([[0, 0, 0]], sess.run(value))
+ value_1, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[1, 0, 0]], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[-1, 0, 0]], value_2)
+ value_3 = sess.run(value)
+ self.assertAllEqual([[1, 0, 0]], value_3)
+ value_4, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([[2, 0, 0]], value_4)
+ value_5 = sess.run(value)
+ self.assertAllEqual([[-2, 0, 0]], value_5)
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ with self.test_session() as sess:
+ self.assertEqual([0], sess.run(value))
+ value_1, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([1], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([-1], value_2)
+ value_3 = sess.run(value)
+ self.assertEqual([1], value_3)
+ value_4, _ = sess.run([value, enqueue_negative])
+ self.assertEqual([2], value_4)
+ value_5 = sess.run(value)
+ self.assertEqual([-2], value_5)
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testBatchedOneEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=2))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_negative = tqd.enqueue_in_queue_dataset(queue_handle, -value)
+ enqueue_zeroth = tqd.enqueue_in_queue_dataset([queue_handle[0]],
+ array_ops.expand_dims(
+ value[0], axis=0))
+ with self.test_session() as sess:
+ value_0, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([0, 1], value_0)
+ value_1, _ = sess.run([value, enqueue_zeroth])
+ self.assertAllEqual([0, -1], value_1)
+ value_2, _ = sess.run([value, enqueue_negative])
+ self.assertAllEqual([0, 2], value_2)
+ self.assertAllEqual([0, -2], sess.run(value))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testManyEnqueue(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue_many_more = [
+ tqd.enqueue_in_queue_dataset(queue_handle, value + 100 + i)
+ for i in range(1000)
+ ]
+ with self.test_session() as sess:
+ value_0, _ = sess.run((value, enqueue_many_more))
+ self.assertEqual([0], value_0)
+ rest = []
+ for _ in range(1000):
+ rest.append(sess.run(value))
+        self.assertEqual([[100 + i] for i in range(1000)], sorted(rest))
+ # Going back to the original input.
+ value_1, _ = sess.run((value, enqueue_many_more))
+      self.assertEqual([1], value_1)
+ rest = []
+ for _ in range(1000):
+ rest.append(sess.run(value))
+        self.assertEqual([[100 + i + 1] for i in range(1000)], sorted(rest))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(value)
+
+ def testEnqueueWithPrefetch(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ # Prefetching will request additional values before they are
+ # available to the queue.
+ dataset = dataset.prefetch(buffer_size=3)
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+ enqueue = tqd.enqueue_in_queue_dataset(queue_handle, value + 1)
+ with self.test_session() as sess:
+ i = 0
+ while i < 4:
+ received, _ = sess.run((value, enqueue))
+ if received.size > 0:
+ self.assertAllEqual([i], received)
+ i += 1
+ received_last = False
+ while True:
+ try:
+ received = sess.run(value)
+ if received.size > 0:
+ self.assertAllEqual([4], received)
+ received_last = True
+ except errors.OutOfRangeError:
+ break
+ self.assertTrue(received_last)
+
+ def testDatasetWithPaddedShapeSmallerThanInputFails(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([[0, 0, 0]]).repeat(None)
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=1, padded_shapes=[2]))
+ iterator = dataset.make_one_shot_iterator()
+ _, value = iterator.get_next()
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(
+ r"Incompatible input shapes at component 0 between "
+ r"input dataset this dataset: \[3\] vs. \[2\]"):
+ sess.run(value)
+
+ def testEnqueueWithIncompatibleInputsFailsWithInformativeError(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0]).repeat(None)
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ iterator = dataset.make_one_shot_iterator()
+ queue_handle, value = iterator.get_next()
+
+ enqueue_bad_structure = tqd.enqueue_in_queue_dataset(
+ queue_handle, (value, value))
+ enqueue_bad_dtype = tqd.enqueue_in_queue_dataset(queue_handle,
+ np.array(
+ [1.0],
+ dtype=np.float32))
+ enqueue_bad_shape_no_batch_dim = tqd.enqueue_in_queue_dataset(
+ queue_handle, ([1],))
+ enqueue_bad_shape = tqd.enqueue_in_queue_dataset(queue_handle,
+ np.array(
+ [[1]], dtype=np.int32))
+
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(
+ "mismatched number of tensors. Queue expects 1 tensors but "
+ "tried to insert 2"):
+ sess.run(enqueue_bad_structure)
+ with self.assertRaisesOpError(r"Expected component 0 to have batched "
+ r"shape \[1,...\], but saw shape: \[\]"):
+ sess.run(enqueue_bad_shape_no_batch_dim)
+ with self.assertRaisesOpError(
+ r"mismatched shapes at component 0. Attempted to insert tensor "
+ r"with shape \[1\] but queue expected shape: \[\]"):
+ sess.run(enqueue_bad_shape)
+ with self.assertRaisesOpError(
+ r"mismatched dtypes at component 0. Attempted to insert tensor "
+ r"of type float but queue expected type: int32"):
+ sess.run(enqueue_bad_dtype)
+
+ def testEnqueueWithPaddedBatchFailsWithInformativeError(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 1, 2])
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=1))
+ with self.assertRaisesRegexp(
+ TypeError, r"Unable to create padding for field of type 'variant'"):
+ dataset.padded_batch(batch_size=10, padded_shapes=[1])
+
+ def testOneEnqueueWithPadding(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(
+ lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype)))
+ # Emit a queue we can prepend to, and counts/values as padded
+ # batch.
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=3))
+
+ iterator = dataset.make_one_shot_iterator()
+ queue, (count, padded_value) = iterator.get_next()
+
+ # Split the padded_value into two pieces: head and rest
+ rest_indices = array_ops.squeeze(array_ops.where(count > 2), axis=1)
+ bound = math_ops.minimum(2, math_ops.reduce_max(count))
+ value_head = padded_value[:, :bound]
+ count_rest = array_ops.gather(count - 2, rest_indices)
+ value_rest = array_ops.gather(padded_value, rest_indices)[:, bound:]
+ queue_rest = array_ops.gather(queue, rest_indices)
+ enqueue_rest_op = tqd.enqueue_in_queue_dataset(queue_rest,
+ (count_rest, value_rest))
+ with ops.control_dependencies([enqueue_rest_op]):
+ calc = array_ops.identity(value_head)
+
+ with self.test_session() as sess:
+ self.assertAllEqual([[0, 0], [2, 2], [4, 4]], sess.run(calc))
+ self.assertAllEqual([[4, 4], [6, 6]], sess.run(calc))
+ self.assertAllEqual([[6, 6]], sess.run(calc))
+ self.assertAllEqual([[6, 6]], sess.run(calc))
+ # Get some final batches due to prefetching.
+ for _ in range(3):
+ try:
+ self.assertAllEqual(
+ np.empty(shape=(0, 0), dtype=np.int32), sess.run(calc))
+ except errors.OutOfRangeError as e:
+ self.assertTrue(str(e).startswith("End of sequence"))
+
+ def testNonstandardPadding(self):
+ dataset = dataset_ops.Dataset.from_tensor_slices([0, 2, 4, 6])
+ # Make a dataset of variable-length vectors and their lengths.
+ dataset = dataset.map(
+ lambda c: (c, c * array_ops.ones((c,), dtype=c.dtype)))
+ # Emit a queue we can prepend to, and counts/values as padded
+ # batch.
+ dataset = dataset.apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=3, padding_values=(
+ 0,
+ -1,
+ )))
+
+ iterator = dataset.make_one_shot_iterator()
+ _, (unused_count, padded_value) = iterator.get_next()
+
+ with self.test_session() as sess:
+ self.assertAllEqual([[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]],
+ sess.run(padded_value))
+ self.assertAllEqual([[6] * 6], sess.run(padded_value))
+ with self.assertRaisesOpError("End of sequence"):
+ sess.run(padded_value)
+
+
+# TODO(ebrevdo): Figure out how to use run_core_tests to test state
+# saving of an iterator that's had some tensors enqueued into its queue.
+class PrependFromQueueAndPaddedBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testPrependFromQueueAndPaddedBatch(self):
+
+ def build_dataset(seq_lens):
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(batch_size=4))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+ def testPrependFromQueueAndPaddedBatchNonDefaultPadding(self):
+
+ def build_dataset(seq_lens):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ padded_shape = [-1]
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ fill_tuple).apply(
+ tqd.prepend_from_queue_and_padded_batch_dataset(
+ batch_size=4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>")))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+
+if __name__ == "__main__":
+ test.main()
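As a companion to `testNonstandardPadding` above, here is a short sketch of
the non-default padding behaviour, under the same TF 1.x assumptions as the
earlier example; the expected outputs in the comments are taken directly from
the test's assertions:

```python
import tensorflow as tf
from tensorflow.contrib.training.python.training import (
    tensor_queue_dataset as tqd)

# Variable-length vectors of lengths 0, 2, 4 and 6. The values component is
# padded with -1 instead of the default 0; counts keep the default 0.
dataset = tf.data.Dataset.from_tensor_slices([0, 2, 4, 6])
dataset = dataset.map(lambda c: (c, c * tf.ones((c,), dtype=c.dtype)))
dataset = dataset.apply(
    tqd.prepend_from_queue_and_padded_batch_dataset(
        batch_size=3, padding_values=(0, -1)))

_, (unused_count, padded_value) = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
  # First batch: each vector is padded to the longest length present (4).
  # [[-1, -1, -1, -1], [2, 2, -1, -1], [4, 4, 4, 4]]
  print(sess.run(padded_value))
  # Final, partial batch holds the single length-6 vector: [[6, 6, 6, 6, 6, 6]]
  print(sess.run(padded_value))
```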