diff options
Diffstat (limited to 'tensorflow/python/data/ops/iterator_ops.py')
-rw-r--r-- | tensorflow/python/data/ops/iterator_ops.py | 19 |
1 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 35de2f2841..3ef22cf981 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -499,7 +499,8 @@ class EagerIterator(object): "tf.data.Dataset.make_initializable_iterator or " "tf.data.Dataset.make_one_shot_iterator for graph construction". format(type(self))) - with ops.device("/device:CPU:0"): + self._device = context.context().device_name + with ops.device("/cpu:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_classes = dataset.output_classes self._output_types = dataset.output_types @@ -508,14 +509,14 @@ class EagerIterator(object): sparse.as_dense_types(self._output_types, self._output_classes)) self._flat_output_shapes = nest.flatten( sparse.as_dense_shapes(self._output_shapes, self._output_classes)) - self._resource = gen_dataset_ops.anonymous_iterator( - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device="/device:CPU:0") - self._device = context.context().device_name + with ops.colocate_with(ds_variant): + self._resource = gen_dataset_ops.anonymous_iterator( + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=self._device) def __iter__(self): return self |