path: root/tensorflow/contrib/distribute/python/values.py
author     Yuefeng Zhou <yuefengz@google.com>               2018-04-30 16:12:33 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-04-30 16:16:44 -0700
commit  7141ed55dd0f36f698143812b44aeffc6129257b (patch)
tree    0eab99cb82f50a78096162209feb16deed666563 /tensorflow/contrib/distribute/python/values.py
parent  18343616da47a9c3eab79b5028ac3d8bf786f2ff (diff)
Add MultiNodeDataset and MultiNodeIterator, which are intended to work with the multi-node distribution strategy.
PiperOrigin-RevId: 194862215
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py | 95
1 file changed, 95 insertions(+), 0 deletions(-)
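A minimal usage sketch of the classes this patch adds (graph mode only, since `MultiWorkerDataset` does not yet work in eager mode). The worker/device strings and the master address are hypothetical placeholders; only `MultiWorkerDataset`, its iterator methods, and `select_device` come from `values.py`:

    import tensorflow as tf
    from tensorflow.contrib.distribute.python import values

    # Hypothetical two-worker topology with two GPUs per worker.
    worker_device_map = {
        "/job:worker/task:0": ["/job:worker/task:0/device:GPU:0",
                               "/job:worker/task:0/device:GPU:1"],
        "/job:worker/task:1": ["/job:worker/task:1/device:GPU:0",
                               "/job:worker/task:1/device:GPU:1"],
    }

    def dataset_fn():
      # Each worker receives one shard of this dataset.
      return tf.data.Dataset.range(100).batch(4)

    multi_worker_dataset = values.MultiWorkerDataset(
        dataset_fn, worker_device_map, prefetch_on_device=False)
    iterator = multi_worker_dataset.make_initializable_iterator()
    next_element = iterator.get_next(name="batch")

    master = "grpc://localhost:2222"  # hypothetical cluster master address
    with tf.Session(master) as sess:
      sess.run(iterator.initializer)
      # `get_next` returns one regrouped structure spanning all devices;
      # `select_device` extracts the component placed on a given device.
      gpu0_batch = sess.run(
          values.select_device("/job:worker/task:0/device:GPU:0",
                               next_element))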
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 466678ef2e..18afdaa7b0 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -29,6 +29,7 @@ import six
from tensorflow.contrib.data.python.ops import batching
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -576,6 +577,100 @@ class PerDeviceDataset(object):
dataset_iterator, self._devices, self._prefetch_on_device)
+class MultiWorkerDataIterator(object):
+  """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""
+
+  def __init__(self, iterators, worker_device_map):
+    """Initialize the MultiWorkerDataIterator object.
+
+    Args:
+      iterators: a dict mapping from each worker to an iterator for
+        that worker.
+      worker_device_map: a dict mapping from each worker to a list of
+        devices that belong to this worker.
+
+    Raises:
+      ValueError: if `iterators` and `worker_device_map` do not cover the
+        same set of workers.
+    """
+    self._iterators = iterators
+    self._worker_device_map = worker_device_map
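+    # `set()` of a dict yields its key set, so the check below verifies that
+    # both arguments cover exactly the same workers.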
+    if set(self._iterators) != set(self._worker_device_map):
+      raise ValueError(
+          "iterators and worker_device_map must cover the same workers.")
+
+  @property
+  def initializer(self):
+    return control_flow_ops.group(
+        [iterator.initializer for iterator in self._iterators.values()])
+
+  def get_next(self, name=None):
+    """Scatter the input across hosts and devices."""
+    index = {}
+    for worker, iterator in six.iteritems(self._iterators):
+      if name is not None:
+        d = tf_device.DeviceSpec.from_string(worker)
+        new_name = "%s_%s_%d" % (name, d.job, d.task)
+      else:
+        new_name = None
+      with ops.device(worker):
+        data_per_worker = iterator.get_next(name=new_name)
+
+      worker_devices = self._worker_device_map[worker]
+      # Ungroup these per-device values so as to get a flat map from devices
+      # to values.
+      for d in worker_devices:
+        v = select_device(d, data_per_worker)
+        if d in index:
+          raise ValueError("Duplicated device in worker_device_map: %r" % d)
+        index[d] = v
+    return regroup(index)
+
+
+class MultiWorkerDataset(object):
+  """Like a `tf.data.Dataset` that distributes data to different workers.
+
+  Each worker gets one shard of the input dataset. This dataset does not
+  currently work in eager mode.
+  """
+
+  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+    """Initialize the MultiWorkerDataset object.
+
+    Args:
+      dataset_fn: a function that returns a `tf.data.Dataset`.
+      worker_device_map: a dict mapping from each worker to a list of devices
+        that belong to this worker.
+      prefetch_on_device: whether to prefetch to devices.
+    """
+    self._worker_device_map = worker_device_map
+    self._datasets = {}
+    # TODO(yuefengz, priyag): support a different set of jobs for input
+    # processing.
+    for i, (worker, worker_devices) in enumerate(
+        six.iteritems(worker_device_map)):
+      with ops.device(worker):
+        worker_input = dataset_fn()
+        # TODO(yuefengz, priyag): support efficient sharding.
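+        # `shard(n, i)` keeps every n-th element starting at offset i, so
+        # worker i sees elements i, i + n, i + 2n, ... of the input dataset.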
+        worker_input = worker_input.shard(len(worker_device_map), i)
+        self._datasets[worker] = PerDeviceDataset(
+            worker_input, worker_devices,
+            prefetch_on_device=prefetch_on_device)
+
+  def make_one_shot_iterator(self):
+    iterators = {}
+    for worker, dataset in six.iteritems(self._datasets):
+      with ops.device(worker):
+        iterators[worker] = dataset.make_one_shot_iterator()
+    return MultiWorkerDataIterator(iterators, self._worker_device_map)
+
+  def make_initializable_iterator(self):
+    iterators = {}
+    for worker, dataset in six.iteritems(self._datasets):
+      with ops.device(worker):
+        iterators[worker] = dataset.make_initializable_iterator()
+    return MultiWorkerDataIterator(iterators, self._worker_device_map)
+
+
class PerIteration(object):
"""Holds input for multiple iterations at once."""