diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-10-03 14:51:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 14:55:09 -0700 |
commit | 312e37cee391b0d207293d59d8882db3c8030f9d (patch) | |
tree | 9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/python/keras | |
parent | c1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff) |
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r-- | tensorflow/python/keras/engine/training.py | 11 |
1 file changed, 5 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 85233de9b1..d81bd83f7f 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -814,6 +814,9 @@ class Model(Network):
       x_shape = first_x_value.shape
       if batch_size is None:
         batch_size = x_shape[0] // steps
+      # We need to use the drop_remainder argument to allow for a static
+      # input shape which is required for TPUs.
+      drop_remainder = self._distribution_strategy.require_static_shapes
       if y is not None:
         var_x = distributed_training_utils.get_var_for_numpy(
             self._distribution_strategy, x)
@@ -824,9 +827,7 @@ class Model(Network):
         # TODO(anjalisridhar): What should the buffer size be?
         x = x.shuffle(10000)
         x = x.repeat()
-        # We need to use the drop_remainder argument to allow for a static
-        # input shape which is required for TPUs.
-        x = x.batch(batch_size, drop_remainder=True)
+        x = x.batch(batch_size, drop_remainder=drop_remainder)
         y = None
       else:
         # This case is for the predict call where the dataset only contains
@@ -838,9 +839,7 @@ class Model(Network):
             self._distribution_strategy, x)
         x = dataset_ops.Dataset.from_tensor_slices(var_x)
         x = x.repeat()
-        # We need to use the drop_remainder argument to allow for a static
-        # input shape which is required for TPUs.
-        x = x.batch(batch_size, drop_remainder=True)
+        x = x.batch(batch_size, drop_remainder=drop_remainder)
         # TODO(anjalisridhar): Can we use the iterator and getnext op cache?
         # We require users to pass Datasets since we distribute the dataset across