aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2018-04-19 13:19:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-19 13:21:45 -0700
commit55706e693ab20f6200061fb73067cbf27707cccd (patch)
tree75233528da3d4c6b3f211736542baf25190a98e3 /tensorflow/contrib/distribute/python/values.py
parentb6686d2808b40ed985db2151bcf31961b53e49f5 (diff)
Support various shapes in TPU DistributionStrategy.
PiperOrigin-RevId: 193563912
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py33
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 18fedd2775..62016c3a78 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -570,6 +570,39 @@ class PerDeviceDataset(object):
dataset_iterator, self._devices, self._prefetch_on_device)
+class MultiIterator(object):
+ """Iterator that returns results of multiple get_next()s."""
+
+ def __init__(self, dataset_iterator, iterations):
+ self._dataset_iterator = dataset_iterator
+ self._iterations = iterations
+
+ def get_next(self, name=None):
+ return [
+ self._dataset_iterator.get_next(name=name)
+ for _ in range(self._iterations)
+ ]
+
+ @property
+ def initializer(self):
+ return self._dataset_iterator.initializer
+
+
+class PerIterationDataset(object):
+
+ def __init__(self, dataset, iterations):
+ self._dataset = dataset
+ self._iterations = iterations
+
+ def make_one_shot_iterator(self):
+ iterator = self._dataset.make_one_shot_iterator()
+ return MultiIterator(iterator, self._iterations)
+
+ def make_initializable_iterator(self):
+ iterator = self._dataset.make_initializable_iterator()
+ return MultiIterator(iterator, self._iterations)
+
+
class MapOutput(object):
"""Map can result in multiple outputs per device."""