diff options
Diffstat (limited to 'tensorflow/python/data/ops/iterator_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/iterator_ops.py | 82 |
1 files changed, 60 insertions, 22 deletions
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index b6dba4e3ca..3ef22cf981 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function import threading import warnings +from tensorflow.python.compat import compat from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.eager import context @@ -172,13 +173,32 @@ class Iterator(object): nest.assert_same_structure(output_types, output_shapes) if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator( - container="", - shared_name=shared_name, - output_types=nest.flatten( - sparse.as_dense_types(output_types, output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(output_shapes, output_classes))) + if compat.forward_compatible(2018, 8, 3): + if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access + with ops.device("/cpu:0"): + iterator_resource = gen_dataset_ops.iterator_v2( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_v2( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator( + container="", + shared_name=shared_name, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -242,12 +262,29 @@ class Iterator(object): output_classes = nest.map_structure(lambda _: ops.Tensor, output_types) nest.assert_same_structure(output_types, output_shapes) string_handle = ops.convert_to_tensor(string_handle, dtype=dtypes.string) - iterator_resource = gen_dataset_ops.iterator_from_string_handle( - string_handle, - output_types=nest.flatten( - sparse.as_dense_types(output_types, output_classes)), - output_shapes=nest.flatten( - sparse.as_dense_shapes(output_shapes, output_classes))) + if compat.forward_compatible(2018, 8, 3): + if not ops.get_default_graph()._graph_device_function_stack: # pylint: disable=protected-access + with ops.device("/cpu:0"): + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) + else: + iterator_resource = gen_dataset_ops.iterator_from_string_handle( + string_handle, + output_types=nest.flatten( + sparse.as_dense_types(output_types, output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(output_shapes, output_classes))) return Iterator(iterator_resource, None, output_types, output_shapes, output_classes) @@ -462,7 +499,8 @@ class EagerIterator(object): "tf.data.Dataset.make_initializable_iterator or " "tf.data.Dataset.make_one_shot_iterator for graph construction". format(type(self))) - with ops.device("/device:CPU:0"): + self._device = context.context().device_name + with ops.device("/cpu:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_classes = dataset.output_classes self._output_types = dataset.output_types @@ -471,14 +509,14 @@ class EagerIterator(object): sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.anonymous_iterator( - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name + with ops.colocate_with(ds_variant): + self._resource = gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=self._device) def __iter__(self): return self |