diff options
author | 2018-04-19 13:19:27 -0700 | |
---|---|---|
committer | 2018-04-19 13:21:45 -0700 | |
commit | 55706e693ab20f6200061fb73067cbf27707cccd (patch) | |
tree | 75233528da3d4c6b3f211736542baf25190a98e3 /tensorflow/contrib/distribute/python/values.py | |
parent | b6686d2808b40ed985db2151bcf31961b53e49f5 (diff) |
Support various shapes in TPU DistributionStrategy.
PiperOrigin-RevId: 193563912
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 18fedd2775..62016c3a78 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -570,6 +570,39 @@ class PerDeviceDataset(object): dataset_iterator, self._devices, self._prefetch_on_device) +class MultiIterator(object): + """Iterator that returns results of multiple get_next()s.""" + + def __init__(self, dataset_iterator, iterations): + self._dataset_iterator = dataset_iterator + self._iterations = iterations + + def get_next(self, name=None): + return [ + self._dataset_iterator.get_next(name=name) + for _ in range(self._iterations) + ] + + @property + def initializer(self): + return self._dataset_iterator.initializer + + +class PerIterationDataset(object): + + def __init__(self, dataset, iterations): + self._dataset = dataset + self._iterations = iterations + + def make_one_shot_iterator(self): + iterator = self._dataset.make_one_shot_iterator() + return MultiIterator(iterator, self._iterations) + + def make_initializable_iterator(self): + iterator = self._dataset.make_initializable_iterator() + return MultiIterator(iterator, self._iterations) + + class MapOutput(object): """Map can result in multiple outputs per device.""" |