aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-10-03 14:51:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 14:55:09 -0700
commit312e37cee391b0d207293d59d8882db3c8030f9d (patch)
tree9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/python/training
parentc1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff)
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/distribute.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index a92a1bdee7..b3f3c29b2f 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -436,6 +436,9 @@ class DistributionStrategy(object):
def __init__(self):
self._default_device = None
+ # This property is used to determine if we should set drop_remainder=True
+ # when creating Datasets from numpy array inputs.
+ self._require_static_shapes = False
def scope(self):
"""Returns a context manager selecting this DistributionStrategy as current.
@@ -899,6 +902,10 @@ class DistributionStrategy(object):
raise NotImplementedError("must be implemented in descendants")
@property
+ def require_static_shapes(self):
+ return self._require_static_shapes
+
+ @property
def num_towers(self):
"""Returns number of towers, for purposes of averaging across towers."""
raise NotImplementedError("must be implemented in descendants")