aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/data/python/ops/dataset_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/data/python/ops/dataset_ops.py')
-rw-r--r--tensorflow/contrib/data/python/ops/dataset_ops.py38
1 files changed, 21 insertions, 17 deletions
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 949453bb73..6ef960037f 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -258,11 +258,12 @@ class Iterator(object):
# initializers that simply reset their state to the beginning.
raise ValueError("Iterator does not have an initializer.")
- def make_initializer(self, dataset):
+ def make_initializer(self, dataset, name=None):
"""Returns a `tf.Operation` that initializes this iterator on `dataset`.
Args:
dataset: A `Dataset` with compatible structure to this iterator.
+ name: (Optional.) A name for the created operation.
Returns:
A `tf.Operation` that can be run to initialize this iterator on the given
@@ -272,22 +273,25 @@ class Iterator(object):
TypeError: If `dataset` and this iterator do not have a compatible
element structure.
"""
- nest.assert_same_structure(self._output_types, dataset.output_types)
- nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
- for iterator_dtype, dataset_dtype in zip(
- nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
- if iterator_dtype != dataset_dtype:
- raise TypeError(
- "Expected output types %r but got dataset with output types %r." %
- (self._output_types, dataset.output_types))
- for iterator_shape, dataset_shape in zip(
- nest.flatten(self._output_shapes), nest.flatten(dataset.output_shapes)):
- if not iterator_shape.is_compatible_with(dataset_shape):
- raise TypeError("Expected output shapes compatible with %r but got "
- "dataset with output shapes %r." %
- (self._output_shapes, dataset.output_shapes))
- return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
- self._iterator_resource)
+ with ops.name_scope(name, "make_initializer") as name:
+ nest.assert_same_structure(self._output_types, dataset.output_types)
+ nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
+ for iterator_dtype, dataset_dtype in zip(
+ nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
+ if iterator_dtype != dataset_dtype:
+ raise TypeError(
+ "Expected output types %r but got dataset with output types %r." %
+ (self._output_types, dataset.output_types))
+ for iterator_shape, dataset_shape in zip(
+ nest.flatten(self._output_shapes),
+ nest.flatten(dataset.output_shapes)):
+ if not iterator_shape.is_compatible_with(dataset_shape):
+ raise TypeError("Expected output shapes compatible with %r but got "
+ "dataset with output shapes %r." %
+ (self._output_shapes, dataset.output_shapes))
+ return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
+ self._iterator_resource,
+ name=name)
def get_next(self, name=None):
"""Returns a nested structure of `tf.Tensor`s containing the next element.