aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/data/ops/iterator_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/data/ops/iterator_ops.py')
-rw-r--r--tensorflow/python/data/ops/iterator_ops.py19
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py
index 35de2f2841..3ef22cf981 100644
--- a/tensorflow/python/data/ops/iterator_ops.py
+++ b/tensorflow/python/data/ops/iterator_ops.py
@@ -499,7 +499,8 @@ class EagerIterator(object):
"tf.data.Dataset.make_initializable_iterator or "
"tf.data.Dataset.make_one_shot_iterator for graph construction".
format(type(self)))
- with ops.device("/device:CPU:0"):
+ self._device = context.context().device_name
+ with ops.device("/cpu:0"):
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
self._output_classes = dataset.output_classes
self._output_types = dataset.output_types
@@ -508,14 +509,14 @@ class EagerIterator(object):
sparse.as_dense_types(self._output_types, self._output_classes))
self._flat_output_shapes = nest.flatten(
sparse.as_dense_shapes(self._output_shapes, self._output_classes))
- self._resource = gen_dataset_ops.anonymous_iterator(
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- gen_dataset_ops.make_iterator(ds_variant, self._resource)
- # Delete the resource when this object is deleted
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device="/device:CPU:0")
- self._device = context.context().device_name
+ with ops.colocate_with(ds_variant):
+ self._resource = gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ gen_dataset_ops.make_iterator(ds_variant, self._resource)
+ # Delete the resource when this object is deleted
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=self._device)
def __iter__(self):
return self