about | summary | refs | log | tree | commit | diff | homepage
diff options
context:
space:
mode:
author: Rohan Jain <rohanj@google.com> 2017-11-14 13:03:31 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-11-14 13:07:57 -0800
commit d045a51072eed09b3fcb990ccd3ad4872ce0ada3 (patch)
tree 46a61d76431b3edeb7f2c0573e4d22b6ac2241f4
parent 6d5793853cfdd27fe806bca4fad0f4e3c3a32b73 (diff)
Enable prefetching on the resnet50 benchmark for eager.
PiperOrigin-RevId: 175722984
-rw-r--r-- tensorflow/contrib/eager/python/BUILD 7
-rw-r--r-- tensorflow/contrib/eager/python/datasets.py 72
2 files changed, 57 insertions, 22 deletions
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 6783f7beb0..92746b866a 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -50,21 +50,22 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/eager:context",
],
)
-py_test(
+cuda_py_test(
name = "datasets_test",
srcs = ["datasets_test.py"],
- srcs_version = "PY2AND3",
- deps = [
+ additional_deps = [
":datasets",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 98e6983658..b559cce6b1 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -20,11 +20,15 @@ from __future__ import print_function
import threading
+from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops
@@ -32,12 +36,12 @@ _uid_counter = 0
_uid_lock = threading.Lock()
-def _iterator_shared_name():
+def _generate_shared_name(prefix):
with _uid_lock:
global _uid_counter
uid = _uid_counter
_uid_counter += 1
- return "eager_iterator_{}".format(uid)
+ return "{}_{}".format(prefix, uid)
class Iterator(object):
@@ -72,11 +76,12 @@ class Iterator(object):
with ops.device("/device:CPU:0"):
ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access
self._output_types = dataset.output_types
+ self._output_shapes = dataset.output_shapes
self._flat_output_types = nest.flatten(dataset.output_types)
self._flat_output_shapes = nest.flatten(dataset.output_shapes)
self._resource = gen_dataset_ops.iterator(
container="",
- shared_name=_iterator_shared_name(),
+ shared_name=_generate_shared_name("eager_iterator"),
output_types=self._flat_output_types,
output_shapes=self._flat_output_shapes)
gen_dataset_ops.make_iterator(ds_variant, self._resource)
@@ -84,6 +89,35 @@ class Iterator(object):
self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
handle=self._resource, handle_device="/device:CPU:0")
self._device = context.context().device_name
+ self._buffer_resource_handle = None
+ if not context.context().device_spec.device_type:
+ is_remote_device = False
+ else:
+ is_remote_device = context.context().device_spec.device_type != "CPU"
+ if is_remote_device:
+ with ops.device("/device:CPU:0"):
+ iter_string_handle = gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ @function.Defun(dtypes.string)
+ def remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, self._output_types, self._output_shapes)
+ return remote_iterator.get_next()
+
+ remote_fn.add_to_graph(None)
+ target = constant_op.constant("/device:CPU:0")
+ with ops.device(self._device):
+ self._buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ string_arg=iter_string_handle,
+ f=remote_fn,
+ target_device=target,
+ buffer_size=10,
+ thread_pool_size=1,
+ container="",
+ shared_name=_generate_shared_name("function_buffer_resource"))
+ self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._buffer_resource_handle, handle_device=self._device)
def __iter__(self):
return self
@@ -93,20 +127,20 @@ class Iterator(object):
def next(self):
"""Return the next tf.Tensor from the dataset."""
- try:
- # TODO(ashankar): Consider removing this ops.device() contextmanager
- # and instead mimic ops placement in graphs: Operations on resource
- # handles execute on the same device as where the resource is placed.
- with ops.device("/device:CPU:0"):
- ret = gen_dataset_ops.iterator_get_next(
- self._resource,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- except errors.OutOfRangeError:
- raise StopIteration
- # Copies tensors from CPU to the current device if necessary.
- # TODO(rohanj): This should be replaced by the mechanism to have the
- # runtime's threads copy tensors to the destination device.
with ops.device(self._device):
- ret = [array_ops.identity(x) for x in ret]
+ try:
+ if self._buffer_resource_handle is not None:
+ ret = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=self._buffer_resource_handle,
+ output_types=self._flat_output_types)
+ else:
+ # TODO(ashankar): Consider removing this ops.device() contextmanager
+ # and instead mimic ops placement in graphs: Operations on resource
+ # handles execute on the same device as where the resource is placed.
+ ret = gen_dataset_ops.iterator_get_next(
+ self._resource,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ except errors.OutOfRangeError:
+ raise StopIteration
return nest.pack_sequence_as(self._output_types, ret)