Diffstat (limited to 'tensorflow/contrib/data/python/ops/prefetching_ops.py')
 tensorflow/contrib/data/python/ops/prefetching_ops.py | 350 ++++++++++++++++++
 1 file changed, 350 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 21fc17102e..0edd7c9fe9 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -26,10 +26,15 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
+from tensorflow.python.ops import resource_variable_ops
def function_buffering_resource(string_arg,
@@ -345,3 +350,348 @@ def prefetch_to_device(device, buffer_size=None):
     return _PrefetchToDeviceDataset(dataset, device, buffer_size)
   return _apply_fn
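
For context, `prefetch_to_device` shown above is applied through `Dataset.apply`; a minimal usage sketch (assuming a TF 1.x graph session and that "/gpu:0" exists):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100)
    # Keep up to 4 prefetched elements resident on the GPU.
    dataset = dataset.apply(
        tf.contrib.data.prefetch_to_device("/gpu:0", buffer_size=4))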
+
+
+def copy_to_device(target_device, source_device="/cpu:0"):
+  """A transformation that copies dataset elements to the given `target_device`.
+
+  Args:
+    target_device: The name of a device to which elements will be copied.
+    source_device: The original device on which `input_dataset` will be placed.
+
+  Returns:
+    A `Dataset` transformation function, which can be passed to
+    @{tf.data.Dataset.apply}.
+  """
+
+  def _apply_fn(dataset):
+    return _CopyToDeviceDataset(
+        dataset, target_device=target_device, source_device=source_device)
+
+  return _apply_fn
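
A minimal usage sketch for the new transformation (assuming the symbol is exported as `tf.contrib.data.copy_to_device` and a GPU is available):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(100).batch(10)
    # Produce elements on the CPU, then copy each one to the GPU.
    dataset = dataset.apply(tf.contrib.data.copy_to_device("/gpu:0"))
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    with tf.Session() as sess:
      sess.run(iterator.initializer)
      print(sess.run(next_element))  # [0 1 ... 9], materialized on the GPU

Note the initializable iterator: as the class below explains, a one-shot iterator is not supported for GPU targets.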
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.Dataset):
+  """A `Dataset` that copies elements to another device."""
+
+  def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+    """Constructs a _CopyToDeviceDataset.
+
+    Args:
+      input_dataset: The `Dataset` to be copied.
+      target_device: The name of the device to which elements will be copied.
+      source_device: The device on which `input_dataset` will be placed.
+    """
+    self._input_dataset = input_dataset
+    self._target_device = target_device
+    spec = framework_device.DeviceSpec().from_string(self._target_device)
+    self._is_gpu_target = (spec.device_type == "GPU")
+    self._source_device_string = source_device
+    self._source_device = ops.convert_to_tensor(source_device)
+
+    self._flat_output_shapes = nest.flatten(
+        sparse.as_dense_shapes(self._input_dataset.output_shapes,
+                               self._input_dataset.output_classes))
+    self._flat_output_types = nest.flatten(
+        sparse.as_dense_types(self._input_dataset.output_types,
+                              self._input_dataset.output_classes))
+
+    @function.Defun()
+    def _init_func():
+      """Creates an iterator for the input dataset.
+
+      Returns:
+        A `string` tensor that encapsulates the iterator created.
+      """
+      # pylint: disable=protected-access
+      ds_variant = self._input_dataset._as_variant_tensor()
+      resource = core_gen_dataset_ops.anonymous_iterator(
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+      with ops.control_dependencies(
+          [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
+        return core_gen_dataset_ops.iterator_to_string_handle(resource)
+
+    @function.Defun()
+    def _remote_init_func():
+      return functional_ops.remote_call(
+          target=self._source_device,
+          args=_init_func.captured_inputs,
+          Tout=[dtypes.string],
+          f=_init_func)
+
+    self._init_func = _remote_init_func
+    self._init_captured_args = _remote_init_func.captured_inputs
+
+    @function.Defun(dtypes.string)
+    def _next_func(string_handle):
+      """Calls get_next for the created iterator.
+
+      Args:
+        string_handle: An iterator string handle created by _init_func.
+
+      Returns:
+        The elements generated from `input_dataset`.
+      """
+      with ops.device(self._source_device_string):
+        iterator = iterator_ops.Iterator.from_string_handle(
+            string_handle, self.output_types, self.output_shapes,
+            self.output_classes)
+        ret = iterator.get_next()
+      return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+    @function.Defun(dtypes.string)
+    def _remote_next_func(string_handle):
+      return functional_ops.remote_call(
+          target=self._source_device,
+          args=[string_handle] + _next_func.captured_inputs,
+          Tout=self._flat_output_types,
+          f=_next_func)
+
+    self._next_func = _remote_next_func
+    self._next_captured_args = _remote_next_func.captured_inputs
+
+    @function.Defun(dtypes.string)
+    def _finalize_func(string_handle):
+      """Destroys the iterator resource created.
+
+      Args:
+        string_handle: An iterator string handle created by _init_func.
+
+      Returns:
+        Tensor constant 0.
+      """
+      iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
+          string_handle,
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+      with ops.control_dependencies([
+          resource_variable_ops.destroy_resource_op(
+              iterator_resource, ignore_lookup_error=True)]):
+        return array_ops.constant(0, dtypes.int64)
+
+    @function.Defun(dtypes.string)
+    def _remote_finalize_func(string_handle):
+      return functional_ops.remote_call(
+          target=self._source_device,
+          args=[string_handle] + _finalize_func.captured_inputs,
+          Tout=[dtypes.int64],
+          f=_finalize_func)
+
+    self._finalize_func = _remote_finalize_func
+    self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+    g = ops.get_default_graph()
+    _remote_init_func.add_to_graph(g)
+    _remote_next_func.add_to_graph(g)
+    _remote_finalize_func.add_to_graph(g)
+    # pylint: enable=protected-access
+
+  # The one_shot_iterator implementation needs a 0-arg _make_dataset function
+  # that captures all of the inputs required to create the dataset. Some of
+  # those inputs are strings feeding the GeneratorDataset, which cannot be
+  # placed on a GPU, so this fails for GPU targets and is disabled there.
+  def make_one_shot_iterator(self):
+    if self._is_gpu_target:
+      raise ValueError("Cannot create a one shot iterator when using "
+                       "`tf.contrib.data.copy_to_device()` on GPU. Please use "
+                       "`Dataset.make_initializable_iterator()` instead.")
+    else:
+      return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
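
A short sketch of this restriction (assuming a GPU target and graph mode):

    dataset = tf.data.Dataset.range(4)
    dataset = dataset.apply(copy_to_device("/gpu:0"))
    # dataset.make_one_shot_iterator()   # would raise the ValueError above
    iterator = dataset.make_initializable_iterator()  # supported path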
+
+  def _as_variant_tensor(self):
+    with ops.device(self._target_device):
+      return core_gen_dataset_ops.generator_dataset(
+          self._init_captured_args,
+          self._next_captured_args,
+          self._finalize_captured_args,
+          init_func=self._init_func,
+          next_func=self._next_func,
+          finalize_func=self._finalize_func,
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+
+  @property
+  def output_types(self):
+    return self._input_dataset.output_types
+
+  @property
+  def output_shapes(self):
+    return self._input_dataset.output_shapes
+
+  @property
+  def output_classes(self):
+    return self._input_dataset.output_classes
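
The init/next/finalize functions above all funnel through `functional_ops.remote_call`, which runs a `Defun` on an explicitly named device. A standalone sketch of that pattern (the `_double` function is illustrative, not part of this change; assumes graph mode):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import function
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import functional_ops

    @function.Defun(dtypes.int64)
    def _double(x):
      return x * 2

    target = ops.convert_to_tensor("/cpu:0")  # device that runs the function
    [result] = functional_ops.remote_call(
        target=target,
        args=[ops.convert_to_tensor(21, dtype=dtypes.int64)],
        Tout=[dtypes.int64],
        f=_double)  # result evaluates to 42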
+
+
+class _PerDeviceGenerator(dataset_ops.Dataset):
+  """A `dummy` generator dataset."""
+
+  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
+               source_device, target_device, output_shapes, output_types,
+               output_classes):
+    self._target_device = target_device
+    self._output_types = output_types
+    self._output_shapes = output_shapes
+    self._output_classes = output_classes
+    self._flat_output_shapes = nest.flatten(
+        sparse.as_dense_shapes(self._output_shapes, self._output_classes))
+    self._flat_output_types = nest.flatten(
+        sparse.as_dense_types(self._output_types, self._output_classes))
+
+    multi_device_iterator_string_handle = (
+        gen_dataset_ops.multi_device_iterator_to_string_handle(
+            multi_device_iterator_resource))
+
+    @function.Defun()
+    def _init_func():
+      return multi_device_iterator_string_handle
+
+    @function.Defun()
+    def _remote_init_func():
+      return functional_ops.remote_call(
+          target=source_device,
+          args=_init_func.captured_inputs,
+          Tout=[dtypes.string],
+          f=_init_func)
+
+    self._init_func = _remote_init_func
+    self._init_captured_args = _remote_init_func.captured_inputs
+
+    @function.Defun(dtypes.string)
+    def _next_func(string_handle):
+      multi_device_iterator = (
+          gen_dataset_ops.multi_device_iterator_from_string_handle(
+              string_handle=string_handle,
+              output_types=self._flat_output_types,
+              output_shapes=self._flat_output_shapes))
+      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
+          multi_device_iterator=multi_device_iterator,
+          shard_num=shard_num,
+          incarnation_id=incarnation_id,
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+
+    @function.Defun(dtypes.string)
+    def _remote_next_func(string_handle):
+      return functional_ops.remote_call(
+          target=source_device,
+          args=[string_handle] + _next_func.captured_inputs,
+          Tout=self._flat_output_types,
+          f=_next_func)
+
+    self._next_func = _remote_next_func
+    self._next_captured_args = _remote_next_func.captured_inputs
+
+    @function.Defun(dtypes.string)
+    def _finalize_func(unused_string_handle):
+      return array_ops.constant(0, dtypes.int64)
+
+    @function.Defun(dtypes.string)
+    def _remote_finalize_func(string_handle):
+      return functional_ops.remote_call(
+          target=source_device,
+          args=[string_handle] + _finalize_func.captured_inputs,
+          Tout=[dtypes.int64],
+          f=_finalize_func)
+
+    self._finalize_func = _remote_finalize_func
+    self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+  def _as_variant_tensor(self):
+    with ops.device(self._target_device):
+      return core_gen_dataset_ops.generator_dataset(
+          self._init_captured_args,
+          self._next_captured_args,
+          self._finalize_captured_args,
+          init_func=self._init_func,
+          next_func=self._next_func,
+          finalize_func=self._finalize_func,
+          output_types=self._flat_output_types,
+          output_shapes=self._flat_output_shapes)
+
+  @property
+  def output_types(self):
+    return self._output_types
+
+  @property
+  def output_shapes(self):
+    return self._output_shapes
+
+  @property
+  def output_classes(self):
+    return self._output_classes
+
+
+class MultiDeviceIterator(object):
+  """An iterator over multiple devices."""
+
+  def __init__(self,
+               dataset,
+               devices,
+               prefetch_buffer_size=1,
+               source_device="/cpu:0"):
+    self._dataset = dataset
+    self._devices = devices
+    self._source_device = source_device
+    self._source_device_tensor = ops.convert_to_tensor(source_device)
+
+    self._flat_output_shapes = nest.flatten(
+        sparse.as_dense_shapes(self._dataset.output_shapes,
+                               self._dataset.output_classes))
+    self._flat_output_types = nest.flatten(
+        sparse.as_dense_types(self._dataset.output_types,
+                              self._dataset.output_classes))
+
+    # Create the MultiDeviceIterator.
+    with ops.device(self._source_device):
+      self._multi_device_iterator_resource = (
+          gen_dataset_ops.multi_device_iterator(
+              devices=self._devices,
+              shared_name="",
+              container="",
+              output_types=self._flat_output_types,
+              output_shapes=self._flat_output_shapes))
+
+      # The incarnation ID is used to ensure consistency between the
+      # per-device iterators and the multi-device iterator.
+      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
+          self._dataset._as_variant_tensor(),  # pylint: disable=protected-access
+          self._multi_device_iterator_resource)
+
+    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
+    # initialize the device side of the pipeline. This would allow the
+    # MultiDeviceIterator to choose, for example, to move some transformations
+    # into the device side from its input. It might be useful in rewriting.
+    # Create the per device iterators.
+    self._device_iterators = []
+    for i, device in enumerate(self._devices):
+      ds = _PerDeviceGenerator(
+          i, self._multi_device_iterator_resource, self._incarnation_id,
+          self._source_device_tensor, device, self._dataset.output_shapes,
+          self._dataset.output_types, self._dataset.output_classes)
+      ds = ds.prefetch(prefetch_buffer_size)
+      with ops.device(device):
+        self._device_iterators.append(ds.make_initializable_iterator())
+
+    device_iterator_initializers = [
+        iterator.initializer for iterator in self._device_iterators
+    ]
+    self._initializer = control_flow_ops.group(*device_iterator_initializers)
+
+  def get_next(self):
+    result = []
+    for i, device in enumerate(self._devices):
+      with ops.device(device):
+        result.append(self._device_iterators[i].get_next())
+    return result
+
+  @property
+  def initializer(self):
+    return self._initializer
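
A minimal end-to-end sketch of `MultiDeviceIterator` (assuming two GPUs and a TF 1.x graph session):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(8)
    multi_device_iterator = MultiDeviceIterator(
        dataset, devices=["/gpu:0", "/gpu:1"], prefetch_buffer_size=2)
    per_device_elements = multi_device_iterator.get_next()  # one per device

    with tf.Session() as sess:
      sess.run(multi_device_iterator.initializer)
      print(sess.run(per_device_elements))  # e.g. [0, 1]: one shard element each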