diff options
author | 2018-09-28 10:25:42 -0700 | |
---|---|---|
committer | 2018-09-28 10:29:45 -0700 | |
commit | 301e3043e67493ce3777d2b36b43d0210f7b920c (patch) | |
tree | 6fd48fe19adc5bbfc5b31568332bd392104e563f /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | b47f0b1b0ac8047d53a824f4ca82a12387a16e4d (diff) |
Disable auto_shard for MirroredStrategy by default.
We will re-enable it when it is more robust.
PiperOrigin-RevId: 214956066
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 504f45a695..93d42e09a2 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input data to devices. + auto_shard_dataset: whether to auto-shard the dataset when there are + multiple workers. """ def __init__(self, @@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus=None, num_gpus_per_worker=None, cross_tower_ops=None, - prefetch_on_device=None): + prefetch_on_device=None, + auto_shard_dataset=False): super(MirroredStrategy, self).__init__() self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device + self._auto_shard_dataset = auto_shard_dataset # Rememeber num GPUs which might be needed by `configure` method. if num_gpus is not None and num_gpus_per_worker is not None: raise ValueError( @@ -477,7 +481,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): if self._cluster_spec: return values.MultiWorkerDataset( partial(self._call_dataset_fn, dataset_fn), self._worker_device_map, - self._prefetch_on_device) + self._prefetch_on_device, self._auto_shard_dataset) else: return values.PerDeviceDataset( self._call_dataset_fn(dataset_fn), |