diff options
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 945f450387..4d7516063c 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input data to devices. + auto_shard_dataset: whether to auto-shard the dataset when there are + multiple workers. """ def __init__(self, @@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_tower_ops=None, - prefetch_on_device=None): + prefetch_on_device=None, + auto_shard_dataset=False): super(MirroredStrategy, self).__init__() self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + self._auto_shard_dataset = auto_shard_dataset # Rememeber num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( @@ -477,13 +481,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if self._cluster_spec: return values.MultiWorkerDataset( partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) + self._prefetch_on_device, self._auto_shard_dataset) else: return values.PerDeviceDataset( - self._call_dataset_fn(dataset_fn), - self._devices, - self._prefetch_on_device, - source_device=device_util.resolve("/device:CPU:0")) + self._call_dataset_fn(dataset_fn), self._devices, + self._prefetch_on_device) # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. def _run_steps_on_dataset(self, fn, iterator, iterations, |