| field | value | date |
|---|---|---|
| author | Yuefeng Zhou <yuefengz@google.com> | 2018-09-28 10:25:42 -0700 |
| committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 10:29:45 -0700 |
| commit | 301e3043e67493ce3777d2b36b43d0210f7b920c | |
| tree | 6fd48fe19adc5bbfc5b31568332bd392104e563f | |
| parent | b47f0b1b0ac8047d53a824f4ca82a12387a16e4d | |
Disable auto_shard for MirroredStrategy by default.
We will re-enable it when it is more robust.
PiperOrigin-RevId: 214956066
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | tensorflow/contrib/distribute/README.md | 3 |
| -rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 8 |
| -rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 9 |

3 files changed, 14 insertions, 6 deletions
```diff
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
 important to shuffle your dataset in your `input_fn`.

 `MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.

 ### Performance Tips

diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 504f45a695..93d42e09a2 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
       set, the `configure` method will try to find the best one.
     prefetch_on_device: optional boolean to specify whether to prefetch input
       data to devices.
+    auto_shard_dataset: whether to auto-shard the dataset when there are
+      multiple workers.
   """

   def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
                num_gpus=None,
                num_gpus_per_worker=None,
                cross_tower_ops=None,
-               prefetch_on_device=None):
+               prefetch_on_device=None,
+               auto_shard_dataset=False):
     super(MirroredStrategy, self).__init__()
     self._cross_tower_ops = cross_tower_ops
     self._prefetch_on_device = prefetch_on_device
+    self._auto_shard_dataset = auto_shard_dataset
     # Rememeber num GPUs which might be needed by `configure` method.
     if num_gpus is not None and num_gpus_per_worker is not None:
       raise ValueError(
@@ -477,7 +481,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     if self._cluster_spec:
       return values.MultiWorkerDataset(
           partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
-          self._prefetch_on_device)
+          self._prefetch_on_device, self._auto_shard_dataset)
     else:
       return values.PerDeviceDataset(
           self._call_dataset_fn(dataset_fn),
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index cce41e7717..327775a729 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -814,7 +814,8 @@ class MultiWorkerDataset(object):
     eager mode.
   """

-  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+  def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+               auto_shard=False):
     """Initialize the MultiWorkerDataset object.

     Args:
@@ -822,6 +823,7 @@ class MultiWorkerDataset(object):
       worker_device_map: a dict mapping from each worker to a list of devices
         that belong to this worker.
       prefetch_on_device: whether to prefetch to devices.
+      auto_shard: whether to auto-shard the dataset.
     """
     self._worker_device_map = worker_device_map
     self._datasets = {}
@@ -831,8 +833,9 @@ class MultiWorkerDataset(object):
         six.iteritems(worker_device_map)):
       with ops.device(worker):
         worker_input = dataset_fn()
-        worker_input = input_ops.auto_shard_dataset(
-            worker_input, len(worker_device_map), i)
+        if auto_shard:
+          worker_input = input_ops.auto_shard_dataset(
+              worker_input, len(worker_device_map), i)
         self._datasets[worker] = PerDeviceDataset(
             worker_input, worker_devices,
```
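To make the behavior change concrete, below is a minimal sketch of how a user could opt back in to dataset auto-sharding now that it defaults to off, by passing the new `auto_shard_dataset=True` flag to `MirroredStrategy`. This is an illustration against the TF 1.x contrib API at this revision, not part of the commit: the synthetic data, the linear `model_fn`, the GPU count, and the training step count are placeholders, and actual multi-worker sharding additionally requires a cluster configuration (e.g. via `TF_CONFIG`), which is not shown here.

```python
# Minimal sketch (TF 1.x contrib API at this revision); data, model_fn,
# and GPU count are hypothetical placeholders.
import tensorflow as tf


def input_fn():
  # The same input_fn runs on every worker. With auto_shard_dataset=False
  # (the new default) no shard op is inserted, so each worker would iterate
  # over the full dataset; shuffling here is therefore important.
  features = tf.data.Dataset.from_tensor_slices(tf.random_uniform([1000, 10]))
  labels = tf.data.Dataset.from_tensor_slices(tf.random_uniform([1000, 1]))
  return tf.data.Dataset.zip((features, labels)).shuffle(1000).repeat().batch(32)


def model_fn(features, labels, mode):
  # Trivial linear model so the sketch is self-contained.
  predictions = tf.layers.dense(features, 1)
  loss = tf.losses.mean_squared_error(labels, predictions)
  train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
      loss, global_step=tf.train.get_global_step())
  return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


strategy = tf.contrib.distribute.MirroredStrategy(
    num_gpus_per_worker=2,       # hypothetical GPU count
    auto_shard_dataset=True)     # opt back in to automatic per-worker sharding

config = tf.estimator.RunConfig(train_distribute=strategy)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)
estimator.train(input_fn=input_fn, steps=100)
```

Alternatively, with auto-sharding left disabled, an `input_fn` can shard explicitly with `dataset.shard(num_workers, worker_index)` (hypothetical variable names), which is the same `tf.data.Dataset.shard` mechanism the strategy would otherwise insert per the README change above.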