diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-04-30 16:12:33 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-30 16:16:44 -0700 |
commit | 7141ed55dd0f36f698143812b44aeffc6129257b (patch) | |
tree | 0eab99cb82f50a78096162209feb16deed666563 /tensorflow/contrib/distribute/python/values.py | |
parent | 18343616da47a9c3eab79b5028ac3d8bf786f2ff (diff) |
Add MultiNodeDataset and MultiNodeIterator which are intended to work for multi-node distribution strategy.
PiperOrigin-RevId: 194862215
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 95 |
1 files changed, 95 insertions, 0 deletions
class MultiWorkerDataIterator(object):
  """An iterator (like `tf.data.Iterator`) into a `MultiWorkerDataset`."""

  def __init__(self, iterators, worker_device_map):
    """Initialize the MultiWorkerDataIterator object.

    Args:
      iterators: a dict mapping from each worker to an iterator for
        that worker.
      worker_device_map: a dict mapping from each worker to a list of
        devices that belong to this worker.

    Raises:
      ValueError: if iterators and worker_device_map are not compatible.
    """
    self._iterators = iterators
    self._worker_device_map = worker_device_map
    # Both dicts must be keyed by the same set of workers; otherwise some
    # worker would have an iterator but no devices, or vice versa.
    if set(self._iterators) != set(self._worker_device_map):
      raise ValueError("iterators and worker_device_map are not compatible.")

  @property
  def initializer(self):
    """An op that, when run, initializes all per-worker iterators."""
    return control_flow_ops.group(
        [iterator.initializer for iterator in self._iterators.values()])

  def get_next(self, name=None):
    """Scatter the input across hosts and devices."""
    index = {}
    for worker, iterator in six.iteritems(self._iterators):
      if name is not None:
        # Suffix the op name with the worker's job and task so per-worker
        # `get_next` ops receive distinct, identifiable names.
        d = tf_device.DeviceSpec.from_string(worker)
        new_name = "%s_%s_%d" % (name, d.job, d.task)
      else:
        new_name = None
      with ops.device(worker):
        data_per_worker = iterator.get_next(name=new_name)

      worker_devices = self._worker_device_map[worker]
      # Ungroup these per-device value so as to get a flat map from devices to
      # values.
      for d in worker_devices:
        v = select_device(d, data_per_worker)
        if d in index:
          # Report the offending device (not the fetched value): the error is
          # about a device appearing under more than one worker.
          raise ValueError("Duplicated devices in worker_device_map: %r" % d)
        index[d] = v

    return regroup(index)


class MultiWorkerDataset(object):
  """Like a `tf.data.Dataset` that distributes data to different workers.

  Each worker gets one shard of the input dataset. It is currently not working
  in eager mode.
  """

  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
    """Initialize the MultiWorkerDataset object.

    Args:
      dataset_fn: a function that returns a `tf.data.Dataset`.
      worker_device_map: a dict mapping from each worker to a list of devices
        that belong to this worker.
      prefetch_on_device: whether to prefetch to devices.
    """
    self._worker_device_map = worker_device_map
    self._datasets = {}
    # TODO(yuefengz, priyag): support different set of jobs for input
    # processing.
    for i, (worker, worker_devices) in enumerate(
        six.iteritems(worker_device_map)):
      with ops.device(worker):
        worker_input = dataset_fn()
        # Worker i reads only shard i of the input, so each worker sees a
        # disjoint slice of the data.
        # TODO(yuefengz, priyag): support efficient sharding.
        worker_input = worker_input.shard(len(worker_device_map), i)
        self._datasets[worker] = PerDeviceDataset(
            worker_input, worker_devices, prefetch_on_device=prefetch_on_device)

  def make_one_shot_iterator(self):
    """Create a one-shot iterator per worker, pinned to that worker."""
    iterators = {}
    for worker, dataset in six.iteritems(self._datasets):
      with ops.device(worker):
        iterators[worker] = dataset.make_one_shot_iterator()
    return MultiWorkerDataIterator(iterators, self._worker_device_map)

  def make_initializable_iterator(self):
    """Create an initializable iterator per worker, pinned to that worker."""
    iterators = {}
    for worker, dataset in six.iteritems(self._datasets):
      with ops.device(worker):
        iterators[worker] = dataset.make_initializable_iterator()
    return MultiWorkerDataIterator(iterators, self._worker_device_map)


class PerIteration(object):
  """Holds input for multiple iterations at once."""