diff options
author | 2017-11-14 13:03:31 -0800 | |
---|---|---|
committer | 2017-11-14 13:07:57 -0800 | |
commit | d045a51072eed09b3fcb990ccd3ad4872ce0ada3 (patch) | |
tree | 46a61d76431b3edeb7f2c0573e4d22b6ac2241f4 | |
parent | 6d5793853cfdd27fe806bca4fad0f4e3c3a32b73 (diff) |
Enable prefetching on the resnet50 benchmark for eager.
PiperOrigin-RevId: 175722984
-rw-r--r-- | tensorflow/contrib/eager/python/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/datasets.py | 72 |
2 files changed, 57 insertions, 22 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 6783f7beb0..92746b866a 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -50,21 +50,22 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/contrib/data/python/ops:prefetching_py", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", ], ) -py_test( +cuda_py_test( name = "datasets_test", srcs = ["datasets_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":datasets", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 98e6983658..b559cce6b1 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -20,11 +20,15 @@ from __future__ import print_function import threading +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -32,12 +36,12 @@ _uid_counter = 0 _uid_lock = threading.Lock() -def _iterator_shared_name(): +def _generate_shared_name(prefix): with _uid_lock: global _uid_counter uid = _uid_counter _uid_counter += 1 - return "eager_iterator_{}".format(uid) + return "{}_{}".format(prefix, uid) class Iterator(object): @@ -72,11 +76,12 @@ class Iterator(object): with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access self._output_types = dataset.output_types + self._output_shapes = dataset.output_shapes self._flat_output_types = nest.flatten(dataset.output_types) self._flat_output_shapes = nest.flatten(dataset.output_shapes) self._resource = gen_dataset_ops.iterator( container="", - shared_name=_iterator_shared_name(), + shared_name=_generate_shared_name("eager_iterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -84,6 +89,35 @@ class Iterator(object): self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="/device:CPU:0") self._device = context.context().device_name + self._buffer_resource_handle = None + if not context.context().device_spec.device_type: + is_remote_device = False + else: + is_remote_device = context.context().device_spec.device_type != "CPU" + if is_remote_device: + with ops.device("/device:CPU:0"): + iter_string_handle = gen_dataset_ops.iterator_to_string_handle( + self._resource) + + @function.Defun(dtypes.string) + def remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, self._output_types, self._output_shapes) + return remote_iterator.get_next() + + remote_fn.add_to_graph(None) + target = constant_op.constant("/device:CPU:0") + with ops.device(self._device): + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + string_arg=iter_string_handle, + f=remote_fn, + target_device=target, + buffer_size=10, + thread_pool_size=1, + container="", + shared_name=_generate_shared_name("function_buffer_resource")) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._buffer_resource_handle, handle_device=self._device) def __iter__(self): return self @@ -93,20 +127,20 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" - try: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. - with ops.device("/device:CPU:0"): - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - # Copies tensors from CPU to the current device if necessary. - # TODO(rohanj): This should be replaced by the mechanism to have the - # runtime's threads copy tensors to the destination device. with ops.device(self._device): - ret = [array_ops.identity(x) for x in ret] + try: + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + except errors.OutOfRangeError: + raise StopIteration return nest.pack_sequence_as(self._output_types, ret) |