Diffstat (limited to 'tensorflow/contrib/distribute/python/tpu_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/tpu_strategy.py | 21 |
1 file changed, 16 insertions, 5 deletions
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 77fc56de36..6202a0750a 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -51,7 +51,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
       tpu_system_metadata_lib._query_tpu_system_metadata(
           master,
           cluster_def=cluster_def,
-          query_topology=True))
+          query_topology=False))
 
   return tpu_system_metadata
 
@@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
 class TPUStrategy(one_device_strategy.OneDeviceStrategy):
   """Experimental TPU distribution strategy implementation."""
 
-  def __init__(self, tpu_cluster_resolver, steps_per_run):
+  def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
     """Initializes the TPUStrategy object.
 
     Args:
@@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
         metrics, summaries etc.
         This parameter is only used when Distribution Strategy is used with
         estimator or keras.
+      num_cores: Number of cores to use on the TPU. If None specified, then
+        auto-detect the cores and topology of the TPU system.
     """
     # TODO(isaprykin): Generalize the defaults.  They are currently tailored for
     # the unit test.
@@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
 
     self._tpu_cluster_resolver = tpu_cluster_resolver
     self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+    self._num_cores_override = num_cores
 
-    # TODO(priyag): This should not be hardcoded here.
-    self._host = '/device:CPU:0'
     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
     self.steps_per_run = steps_per_run
 
+    # TODO(frankchn): This should not be hardcoded here for pod purposes.
+    self._host = self.tpu_host_cpu_device(0)
+
   def distribute_dataset(self, dataset_fn):
     # TODO(priyag): Perhaps distribute across cores here.
     return self._call_dataset_fn(dataset_fn)
@@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
     """Enqueue ops for one iteration."""
     control_deps = []
     sharded_inputs = []
+    # TODO(sourabhbajaj): Add support for TPU pods
     with ops.device(self._host):
       for _ in range(self.num_towers):
         # Use control dependencies to ensure a deterministic ordering.
@@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
 
   @property
   def num_towers(self):
-    return self._tpu_metadata.num_of_cores_per_host
+    return self._num_cores_override or self._tpu_metadata.num_cores
+
+  def tpu_host_cpu_device(self, host_id):
+    if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+      return '/replica:0/task:0/device:CPU:0'
+    return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)
+
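For context, a minimal usage sketch of the revised constructor (not part of the commit). Only the TPUStrategy(tpu_cluster_resolver, steps_per_run, num_cores=None) signature and the device strings come from the diff above; the import paths and the TPU address are assumptions based on the contrib-era layout of this branch.

# Usage sketch only -- not from the commit. Import paths and the TPU
# address below are assumptions, not taken from the diff.
from tensorflow.contrib.cluster_resolver import TPUClusterResolver
from tensorflow.contrib.distribute.python.tpu_strategy import TPUStrategy

resolver = TPUClusterResolver(tpu='grpc://10.240.1.2:8470')  # hypothetical address

# num_cores=None (the default) auto-detects the core count from the TPU
# system metadata; an explicit value overrides num_towers directly.
strategy = TPUStrategy(resolver, steps_per_run=100, num_cores=8)

# The new tpu_host_cpu_device() helper maps a host id to a CPU device string:
# '/job:tpu_worker/task:0/device:CPU:0' for a remote master, or
# '/replica:0/task:0/device:CPU:0' when get_master() returns '' or 'local'.
print(strategy.tpu_host_cpu_device(0))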