diff options
author | Philip Pham <phillypham@google.com> | 2018-09-24 10:13:17 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-24 10:21:04 -0700 |
commit | 77d56a08826826db3350968f19070434fa922995 (patch) | |
tree | 01b74bdd86575fb81b74f3456bac651403576d99 /tensorflow/contrib/distribute | |
parent | b1ca5f9d1f2def557ec2cea6c1ebccdfb5c6066a (diff) |
Implement required properties for TPU Strategy
These properties are necessary for the strategy to work with
`tf.estimator.train_and_evaluate`.
PiperOrigin-RevId: 214285957
Diffstat (limited to 'tensorflow/contrib/distribute')
-rw-r--r-- | tensorflow/contrib/distribute/python/tpu_strategy.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py index 6ba83976fc..ba2cc2e806 100644 --- a/tensorflow/contrib/distribute/python/tpu_strategy.py +++ b/tensorflow/contrib/distribute/python/tpu_strategy.py @@ -307,6 +307,22 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): def num_towers_per_host(self): return self._tpu_metadata.num_of_cores_per_host + @property + def between_graph(self): + return False + + @property + def should_init(self): + return True + + @property + def should_checkpoint(self): + return True + + @property + def should_save_summary(self): + return True + def get_host_cpu_device(self, host_id): if self._tpu_cluster_resolver.get_master() in ('', 'local'): return '/replica:0/task:0/device:CPU:0' @@ -324,4 +340,3 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy): cluster_spec = self._tpu_cluster_resolver.cluster_spec() if cluster_spec: session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def()) - |