diff options
Diffstat (limited to 'tensorflow/contrib/data/python/ops/dataset_ops.py')
-rw-r--r-- | tensorflow/contrib/data/python/ops/dataset_ops.py | 38 |
1 file changed, 21 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 949453bb73..6ef960037f 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -258,11 +258,12 @@ class Iterator(object):
     # initializers that simply reset their state to the beginning.
     raise ValueError("Iterator does not have an initializer.")
 
-  def make_initializer(self, dataset):
+  def make_initializer(self, dataset, name=None):
     """Returns a `tf.Operation` that initializes this iterator on `dataset`.
 
     Args:
       dataset: A `Dataset` with compatible structure to this iterator.
+      name: (Optional.) A name for the created operation.
 
     Returns:
       A `tf.Operation` that can be run to initialize this iterator on the given
@@ -272,22 +273,25 @@ class Iterator(object):
       TypeError: If `dataset` and this iterator do not have a compatible
         element structure.
     """
-    nest.assert_same_structure(self._output_types, dataset.output_types)
-    nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
-    for iterator_dtype, dataset_dtype in zip(
-        nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
-      if iterator_dtype != dataset_dtype:
-        raise TypeError(
-            "Expected output types %r but got dataset with output types %r." %
-            (self._output_types, dataset.output_types))
-    for iterator_shape, dataset_shape in zip(
-        nest.flatten(self._output_shapes), nest.flatten(dataset.output_shapes)):
-      if not iterator_shape.is_compatible_with(dataset_shape):
-        raise TypeError("Expected output shapes compatible with %r but got "
-                        "dataset with output shapes %r." %
-                        (self._output_shapes, dataset.output_shapes))
-    return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
-                                         self._iterator_resource)
+    with ops.name_scope(name, "make_initializer") as name:
+      nest.assert_same_structure(self._output_types, dataset.output_types)
+      nest.assert_same_structure(self._output_shapes, dataset.output_shapes)
+      for iterator_dtype, dataset_dtype in zip(
+          nest.flatten(self._output_types), nest.flatten(dataset.output_types)):
+        if iterator_dtype != dataset_dtype:
+          raise TypeError(
+              "Expected output types %r but got dataset with output types %r." %
+              (self._output_types, dataset.output_types))
+      for iterator_shape, dataset_shape in zip(
+          nest.flatten(self._output_shapes),
+          nest.flatten(dataset.output_shapes)):
+        if not iterator_shape.is_compatible_with(dataset_shape):
+          raise TypeError("Expected output shapes compatible with %r but got "
+                          "dataset with output shapes %r." %
+                          (self._output_shapes, dataset.output_shapes))
+      return gen_dataset_ops.make_iterator(dataset.make_dataset_resource(),
+                                           self._iterator_resource,
+                                           name=name)
 
   def get_next(self, name=None):
     """Returns a nested structure of `tf.Tensor`s containing the next element.