aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Priya Gupta <priyag@google.com>2018-05-02 08:04:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 08:06:51 -0700
commit72fd2b8e97f301039ac0eb60435cbbddf36212f6 (patch)
treebd3f7a19cf804f3b190a4d798d06fbfeeb45f9a7 /tensorflow/contrib/distribute/python/values.py
parentba1c33faeb6df1ae363888e2e7330e219f0679ea (diff)
Use experimental auto_sharding in multi worker dataset.
PiperOrigin-RevId: 195092992
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py5
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 18afdaa7b0..aaf177d07e 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -27,6 +27,7 @@ import weakref
import six
from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.distribute.python import input_ops
from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
from tensorflow.python.framework import device as tf_device
@@ -651,8 +652,8 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- # TODO(yuefengz, priyag): support efficient sharding.
- worker_input = worker_input.shard(len(worker_device_map), i)
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
worker_input, worker_devices, prefetch_on_device=prefetch_on_device)