aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-10-03 14:51:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 14:55:09 -0700
commit312e37cee391b0d207293d59d8882db3c8030f9d (patch)
tree9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/python/keras
parentc1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff)
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/python/keras')
-rw-r--r--tensorflow/python/keras/engine/training.py11
1 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 85233de9b1..d81bd83f7f 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -814,6 +814,9 @@ class Model(Network):
x_shape = first_x_value.shape
if batch_size is None:
batch_size = x_shape[0] // steps
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ drop_remainder = self._distribution_strategy.require_static_shapes
if y is not None:
var_x = distributed_training_utils.get_var_for_numpy(
self._distribution_strategy, x)
@@ -824,9 +827,7 @@ class Model(Network):
# TODO(anjalisridhar): What should the buffer size be?
x = x.shuffle(10000)
x = x.repeat()
- # We need to use the drop_remainder argument to allow for a static
- # input shape which is required for TPUs.
- x = x.batch(batch_size, drop_remainder=True)
+ x = x.batch(batch_size, drop_remainder=drop_remainder)
y = None
else:
# This case is for the predict call where the dataset only contains
@@ -838,9 +839,7 @@ class Model(Network):
self._distribution_strategy, x)
x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.repeat()
- # We need to use the drop_remainder argument to allow for a static
- # input shape which is required for TPUs.
- x = x.batch(batch_size, drop_remainder=True)
+ x = x.batch(batch_size, drop_remainder=drop_remainder)
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across