diff options
author | Igor Saprykin <isaprykin@google.com> | 2018-04-19 19:11:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-19 19:14:33 -0700 |
commit | b001827146ff95c9e0ce5668c85d8cc2daf6b78d (patch) | |
tree | 244e55a5a99226cd8b44b48631d3368ebdef8952 /tensorflow/contrib/distribute/python/values.py | |
parent | 6e2df5e471295cd32f9887d76e6ddbf1b4e2a11a (diff) |
Support variable parameter structure in TPU distribution strategy.
TPUStrategy is added to a few more tests.
There appears to be an issue with the batch norm test in minimize_loss_test where the moving averages stay at 0. I'm trying to resolve that separately as the next CL.
PiperOrigin-RevId: 193610264
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 34 |
1 files changed, 28 insertions, 6 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index 62016c3a78..8cb5276579 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -570,18 +570,36 @@ class PerDeviceDataset(object): dataset_iterator, self._devices, self._prefetch_on_device) +class PerIteration(object): + """Holds input for multiple iterations at once.""" + + def __init__(self, index): + self._index = index + + def get(self, iteration): + return array_ops.gather(self._index, iteration) + + def get_shape(self): + return self._index[-1][-1].get_shape() + + def get_dtype(self): + return self._index[-1][-1].dtype + + class MultiIterator(object): """Iterator that returns results of multiple get_next()s.""" - def __init__(self, dataset_iterator, iterations): + def __init__(self, dataset_iterator, iterations, batches_per_iteration): self._dataset_iterator = dataset_iterator self._iterations = iterations + self._batches_per_iteration = batches_per_iteration def get_next(self, name=None): - return [ + return PerIteration([[ self._dataset_iterator.get_next(name=name) - for _ in range(self._iterations) + for _ in range(self._batches_per_iteration) ] + for _ in range(self._iterations)]) @property def initializer(self): @@ -589,18 +607,22 @@ class MultiIterator(object): class PerIterationDataset(object): + """A dataset that returns MultiIterators.""" - def __init__(self, dataset, iterations): + def __init__(self, dataset, iterations, batches_per_iteration): self._dataset = dataset self._iterations = iterations + self._batches_per_iteration = batches_per_iteration def make_one_shot_iterator(self): iterator = self._dataset.make_one_shot_iterator() - return MultiIterator(iterator, self._iterations) + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) def make_initializable_iterator(self): iterator = self._dataset.make_initializable_iterator() - return MultiIterator(iterator, self._iterations) + return MultiIterator(iterator, self._iterations, + self._batches_per_iteration) class MapOutput(object): |