aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Yuefeng Zhou <yuefengz@google.com>2018-09-28 10:25:42 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-28 10:29:45 -0700
commit301e3043e67493ce3777d2b36b43d0210f7b920c (patch)
tree6fd48fe19adc5bbfc5b31568332bd392104e563f /tensorflow/contrib/distribute/python/values.py
parentb47f0b1b0ac8047d53a824f4ca82a12387a16e4d (diff)
Disable auto_shard for MirroredStrategy by default.
We will re-enable it when it is more robust. PiperOrigin-RevId: 214956066
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index cce41e7717..327775a729 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -814,7 +814,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -822,6 +823,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -831,8 +833,9 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
worker_input,
worker_devices,