path: root/tensorflow/contrib/distribute/python/values.py
author     Igor Saprykin <isaprykin@google.com>    2018-04-19 19:11:37 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-04-19 19:14:33 -0700
commit     b001827146ff95c9e0ce5668c85d8cc2daf6b78d (patch)
tree       244e55a5a99226cd8b44b48631d3368ebdef8952 /tensorflow/contrib/distribute/python/values.py
parent     6e2df5e471295cd32f9887d76e6ddbf1b4e2a11a (diff)
Support variable parameter structure in TPU distribution strategy.
TPUStrategy is added to a few more tests. There appears to be an issue with the batch norm test in minimize_loss_test where the moving averages stay at 0; I'm trying to resolve that separately in a follow-up CL.

PiperOrigin-RevId: 193610264
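The diff below is easier to read with the intended nesting in mind: MultiIterator.get_next now returns a PerIteration object holding iterations groups of batches_per_iteration consecutive batches, and PerIteration.get(i) gathers the batches for step i. The following is a minimal sketch of that structure, not code from the commit; the dataset, the constant values, and the Session usage are illustrative assumptions written against the TF 1.x graph-mode API used by this file.

# Minimal sketch (not from the commit) of the structure built by
# MultiIterator.get_next: an outer list over `iterations`, an inner list of
# `batches_per_iteration` consecutive batches, indexed by PerIteration.get
# via a gather. The dataset and constants below are illustrative only.
import tensorflow as tf

iterations = 2
batches_per_iteration = 3   # illustrative, e.g. one batch per device per step

dataset = tf.data.Dataset.range(12).batch(2)
iterator = dataset.make_one_shot_iterator()   # TF 1.x graph-mode iterator

# Same nesting as PerIteration._index in the diff below.
index = [[iterator.get_next() for _ in range(batches_per_iteration)]
         for _ in range(iterations)]

# PerIteration.get(i) does array_ops.gather(index, i): all batches of step i.
step_0_batches = tf.gather(index, 0)

with tf.Session() as sess:
  print(sess.run(step_0_batches))   # 3 batches of 2 elements each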
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--  tensorflow/contrib/distribute/python/values.py  |  34
1 file changed, 28 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 62016c3a78..8cb5276579 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -570,18 +570,36 @@ class PerDeviceDataset(object):
         dataset_iterator, self._devices, self._prefetch_on_device)
+class PerIteration(object):
+  """Holds input for multiple iterations at once."""
+
+  def __init__(self, index):
+    self._index = index
+
+  def get(self, iteration):
+    return array_ops.gather(self._index, iteration)
+
+  def get_shape(self):
+    return self._index[-1][-1].get_shape()
+
+  def get_dtype(self):
+    return self._index[-1][-1].dtype
+
+
 class MultiIterator(object):
   """Iterator that returns results of multiple get_next()s."""
-  def __init__(self, dataset_iterator, iterations):
+  def __init__(self, dataset_iterator, iterations, batches_per_iteration):
     self._dataset_iterator = dataset_iterator
     self._iterations = iterations
+    self._batches_per_iteration = batches_per_iteration
   def get_next(self, name=None):
-    return [
+    return PerIteration([[
         self._dataset_iterator.get_next(name=name)
-        for _ in range(self._iterations)
+        for _ in range(self._batches_per_iteration)
     ]
+                        for _ in range(self._iterations)])
   @property
   def initializer(self):
@@ -589,18 +607,22 @@ class MultiIterator(object):
 class PerIterationDataset(object):
+  """A dataset that returns MultiIterators."""
-  def __init__(self, dataset, iterations):
+  def __init__(self, dataset, iterations, batches_per_iteration):
     self._dataset = dataset
     self._iterations = iterations
+    self._batches_per_iteration = batches_per_iteration
   def make_one_shot_iterator(self):
     iterator = self._dataset.make_one_shot_iterator()
-    return MultiIterator(iterator, self._iterations)
+    return MultiIterator(iterator, self._iterations,
+                         self._batches_per_iteration)
   def make_initializable_iterator(self):
     iterator = self._dataset.make_initializable_iterator()
-    return MultiIterator(iterator, self._iterations)
+    return MultiIterator(iterator, self._iterations,
+                         self._batches_per_iteration)
class MapOutput(object):
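For context, here is a hypothetical end-to-end use of the classes touched above. This is a sketch rather than code from the commit; it assumes a TF 1.x build in which the contrib module imports as shown and graph-mode Sessions are available, and the constants are placeholders.

# Hypothetical usage sketch (not part of the commit). Assumes TF 1.x graph
# mode and that the contrib values module is importable as below.
import tensorflow as tf
from tensorflow.contrib.distribute.python import values

dataset = tf.data.Dataset.range(8).batch(2)

# New in this change: batches_per_iteration is passed alongside iterations.
per_iter_dataset = values.PerIterationDataset(
    dataset, iterations=2, batches_per_iteration=2)

multi_iterator = per_iter_dataset.make_one_shot_iterator()
per_iteration = multi_iterator.get_next()   # a PerIteration instance

# A training loop body would select its step's batches with a step counter.
step = tf.constant(0)
batches_for_step = per_iteration.get(step)

with tf.Session() as sess:
  print(sess.run(batches_for_step))   # batches_per_iteration batches, stacked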