aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute
diff options
context:
space:
mode:
authorGravatar Philip Pham <phillypham@google.com>2018-09-24 10:13:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-24 10:21:04 -0700
commit77d56a08826826db3350968f19070434fa922995 (patch)
tree01b74bdd86575fb81b74f3456bac651403576d99 /tensorflow/contrib/distribute
parentb1ca5f9d1f2def557ec2cea6c1ebccdfb5c6066a (diff)
Implement required properties for TPU Strategy
These properties are necessary for the strategy to work with `tf.estimator.train_and_evaluate`. PiperOrigin-RevId: 214285957
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 6ba83976fc..ba2cc2e806 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -307,6 +307,22 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def num_towers_per_host(self):
return self._tpu_metadata.num_of_cores_per_host
+ @property
+ def between_graph(self):
+ return False
+
+ @property
+ def should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return True
+
+ @property
+ def should_save_summary(self):
+ return True
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
@@ -324,4 +340,3 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
cluster_spec = self._tpu_cluster_resolver.cluster_spec()
if cluster_spec:
session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
-