diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-10-03 14:51:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 14:55:09 -0700 |
commit | 312e37cee391b0d207293d59d8882db3c8030f9d (patch) | |
tree | 9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/contrib/distribute | |
parent | c1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff) |
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/tpu_strategy.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index c3c7df3cd8..1d9e299b38 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -132,7 +132,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): """ # TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the # master node fetched from the cluster resolver. - super(TPUStrategy, self).__init__('/device:CPU:0') + super(TPUStrategy, self).__init__("/device:CPU:0") self._tpu_cluster_resolver = tpu_cluster_resolver self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver) @@ -152,6 +152,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): # at a time is comparable to multiple steps. self.steps_per_run = steps_per_run + self._require_static_shapes = True + def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes, iterations): """Create an enqueue op for a single host identified using host_id. |