diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-10-03 14:51:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 14:55:09 -0700 |
commit | 312e37cee391b0d207293d59d8882db3c8030f9d (patch) | |
tree | 9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/python/training | |
parent | c1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff) |
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/distribute.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index a92a1bdee7..b3f3c29b2f 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -436,6 +436,9 @@ class DistributionStrategy(object): def __init__(self): self._default_device = None + # This property is used to determine if we should set drop_remainder=True + # when creating Datasets from numpy array inputs. + self._require_static_shapes = False def scope(self): """Returns a context manager selecting this DistributionStrategy as current. @@ -899,6 +902,10 @@ class DistributionStrategy(object): raise NotImplementedError("must be implemented in descendants") @property + def require_static_shapes(self): + return self._require_static_shapes + + @property def num_towers(self): """Returns number of towers, for purposes of averaging across towers.""" raise NotImplementedError("must be implemented in descendants") |