aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-09-28 10:25:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 10:29:45 -0700
commit301e3043e67493ce3777d2b36b43d0210f7b920c (patch)
tree6fd48fe19adc5bbfc5b31568332bd392104e563f /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentb47f0b1b0ac8047d53a824f4ca82a12387a16e4d (diff)
Disable auto_shard for MirroredStrategy by default.
We will re-enable it when it is more robust. PiperOrigin-RevId: 214956066
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 504f45a695..93d42e09a2 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
# Rememeber num GPUs which might be needed by `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,7 +481,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
self._call_dataset_fn(dataset_fn),