aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Jiri Simsa <jsimsa@google.com>2018-09-25 13:42:46 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 13:46:54 -0700
commit348478f642216cf3cbe1eb67b875252d8e6a6418 (patch)
treec4c7afd4283506b2c413e429cab02fc547e13481 /tensorflow/contrib/distribute
parent976fb3105312bb17accebcbca2ebae906bcf99fb (diff)
[tf.data] Adding a private method for (recursively) tracking dataset inputs.
PiperOrigin-RevId: 214495925
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
index 1ff60c0762..492d82f6a1 100644
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -155,10 +155,11 @@ class _PrefetchToDeviceIterator(object):
# pylint: enable=protected-access
-class _PrefetchToDeviceDataset(dataset_ops.Dataset):
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
"""A `Dataset` whose iterator prefetches elements to other device(s)."""
def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
self._input_dataset = input_dataset
self._devices = devices
self._buffer_size = buffer_size if buffer_size is not None else 1