diff options
author | Jiri Simsa <jsimsa@google.com> | 2018-09-25 13:42:46 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 13:46:54 -0700 |
commit | 348478f642216cf3cbe1eb67b875252d8e6a6418 (patch) | |
tree | c4c7afd4283506b2c413e429cab02fc547e13481 /tensorflow/contrib/distribute | |
parent | 976fb3105312bb17accebcbca2ebae906bcf99fb (diff) |
[tf.data] Adding a private method for (recursively) tracking dataset inputs.
PiperOrigin-RevId: 214495925
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/prefetching_ops_v2.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py index 1ff60c0762..492d82f6a1 100644 --- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py +++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py @@ -155,10 +155,11 @@ class _PrefetchToDeviceIterator(object): # pylint: enable=protected-access -class _PrefetchToDeviceDataset(dataset_ops.Dataset): +class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): """A `Dataset` whose iterator prefetches elements to other device(s).""" def __init__(self, input_dataset, devices, buffer_size): + super(_PrefetchToDeviceDataset, self).__init__(input_dataset) self._input_dataset = input_dataset self._devices = devices self._buffer_size = buffer_size if buffer_size is not None else 1 |