diff options
author | Priya Gupta <priyag@google.com> | 2018-05-02 08:04:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-02 08:06:51 -0700 |
commit | 72fd2b8e97f301039ac0eb60435cbbddf36212f6 (patch) | |
tree | bd3f7a19cf804f3b190a4d798d06fbfeeb45f9a7 /tensorflow/contrib/distribute/python/values.py | |
parent | ba1c33faeb6df1ae363888e2e7330e219f0679ea (diff) |
Use experimental auto_sharding in multi worker dataset.
PiperOrigin-RevId: 195092992
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 5 |
1 files changed, 3 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 18afdaa7b0..aaf177d07e 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -27,6 +27,7 @@ import weakref import six from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.distribute.python import input_ops from tensorflow.contrib.distribute.python import prefetching_ops_v2 from tensorflow.python.eager import context from tensorflow.python.framework import device as tf_device @@ -651,8 +652,8 @@ class MultiWorkerDataset(object): six.iteritems(worker_device_map)): with ops.device(worker): worker_input = dataset_fn() - # TODO(yuefengz, priyag): support efficient sharding. - worker_input = worker_input.shard(len(worker_device_map), i) + worker_input = input_ops.auto_shard_dataset( + worker_input, len(worker_device_map), i) self._datasets[worker] = PerDeviceDataset( worker_input, worker_devices, prefetch_on_device=prefetch_on_device) |