aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-10-03 14:51:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 14:55:09 -0700
commit312e37cee391b0d207293d59d8882db3c8030f9d (patch)
tree9fc7534f3e7c8c527de00e609185575d2851844f /tensorflow/contrib/distribute
parentc1b3b0b9e041d82e80c2cdcc623a387753daf0b4 (diff)
Add a require_static_shapes argument to DistributionStrategy class. This allows us to identify if we need to set the drop_remainder option when creating Dataset objects.
PiperOrigin-RevId: 215633097
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index c3c7df3cd8..1d9e299b38 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -132,7 +132,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""
# TODO(sourabhbajaj): OneDeviceStrategy should be initialized with the
# master node fetched from the cluster resolver.
- super(TPUStrategy, self).__init__('/device:CPU:0')
+ super(TPUStrategy, self).__init__("/device:CPU:0")
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
@@ -152,6 +152,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
+ self._require_static_shapes = True
+
def _get_enqueue_op_per_host(self, host_id, iterator, input_shapes,
iterations):
"""Create an enqueue op for a single host identified using host_id.