Diffstat (limited to 'tensorflow/contrib/distribute/python/tpu_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/tpu_strategy.py | 21
1 file changed, 16 insertions(+), 5 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 77fc56de36..6202a0750a 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -51,7 +51,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
tpu_system_metadata_lib._query_tpu_system_metadata(
master,
cluster_def=cluster_def,
- query_topology=True))
+ query_topology=False))
return tpu_system_metadata
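
For context, a minimal sketch of how this helper might be called; the resolver construction is illustrative and assumes the TF 1.x contrib TPUClusterResolver API, and num_cores is the metadata attribute referenced by num_towers further down.

from tensorflow.contrib.cluster_resolver import TPUClusterResolver

resolver = TPUClusterResolver(tpu='my-tpu')  # illustrative TPU name
metadata = get_tpu_system_metadata(resolver)
print(metadata.num_cores)  # total core count across all hosts
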
@@ -59,7 +59,7 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
- def __init__(self, tpu_cluster_resolver, steps_per_run):
+ def __init__(self, tpu_cluster_resolver, steps_per_run, num_cores=None):
"""Initializes the TPUStrategy object.
Args:
@@ -70,6 +70,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
metrics, summaries etc.
This parameter is only used when Distribution Strategy is used with
estimator or keras.
+ num_cores: Number of cores to use on the TPU. If None is specified, the
+ cores and topology of the TPU system are auto-detected.
"""
# TODO(isaprykin): Generalize the defaults. They are currently tailored for
# the unit test.
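
A minimal construction sketch under the new signature, reusing the resolver from the sketch above; the steps_per_run and num_cores values are illustrative.

strategy = TPUStrategy(resolver, steps_per_run=100, num_cores=8)
# Omitting num_cores (or passing None) falls back to the auto-detected count.
strategy_auto = TPUStrategy(resolver, steps_per_run=100)
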
@@ -77,13 +79,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
+ self._num_cores_override = num_cores
- # TODO(priyag): This should not be hardcoded here.
- self._host = '/device:CPU:0'
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
+ # TODO(frankchn): This should not be hardcoded here for pod purposes.
+ self._host = self.tpu_host_cpu_device(0)
+
def distribute_dataset(self, dataset_fn):
# TODO(priyag): Perhaps distribute across cores here.
return self._call_dataset_fn(dataset_fn)
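
A sketch of the zero-argument dataset_fn that distribute_dataset expects, reusing the strategy from the sketch above; the dataset contents are illustrative.

import tensorflow as tf

def dataset_fn():
  return tf.data.Dataset.from_tensor_slices([0., 1., 2., 3.]).batch(2)

dist_dataset = strategy.distribute_dataset(dataset_fn)
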
@@ -106,6 +110,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Enqueue ops for one iteration."""
control_deps = []
sharded_inputs = []
+ # TODO(sourabhbajaj): Add support for TPU pods
with ops.device(self._host):
for _ in range(self.num_towers):
# Use control dependencies to ensure a deterministic ordering.
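
The loop above chains each tower's enqueue op to the previously built ones through control dependencies; a generic standalone sketch of that ordering pattern (not the strategy's actual infeed code):

import tensorflow as tf

control_deps = []
for i in range(2):  # two illustrative towers
  with tf.control_dependencies(control_deps):
    op = tf.no_op(name='enqueue_tower_%d' % i)  # stands in for the real enqueue op
  control_deps.append(op)
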
@@ -258,4 +263,10 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
@property
def num_towers(self):
- return self._tpu_metadata.num_of_cores_per_host
+ return self._num_cores_override or self._tpu_metadata.num_cores
+
+ def tpu_host_cpu_device(self, host_id):
+ if self._tpu_cluster_resolver.get_master() in ('', 'local'):
+ return '/replica:0/task:0/device:CPU:0'
+ return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)
+
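
A standalone sketch mirroring the device-string logic added in tpu_host_cpu_device; the 'tpu_worker' job name comes from the change above, and the master addresses are illustrative.

def tpu_host_cpu_device(master, host_id):
  if master in ('', 'local'):
    return '/replica:0/task:0/device:CPU:0'
  return '/job:%s/task:%d/device:CPU:0' % ('tpu_worker', host_id)

print(tpu_host_cpu_device('', 0))                        # /replica:0/task:0/device:CPU:0
print(tpu_host_cpu_device('grpc://10.240.1.2:8470', 1))  # /job:tpu_worker/task:1/device:CPU:0
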