aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Sourabh Bajaj <sourabhbajaj@google.com>2018-09-06 10:50:35 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 10:55:17 -0700
commit9638524520d582e93a8038a89cd5cc62d719a3b6 (patch)
treee16b7b2c5032b15ac30ea07cbc36c199d6727272 /tensorflow
parentb9310932ce2120c8c36eb69bc135748fd3caf897 (diff)
Job name should be picked based on the cluster_spec
PiperOrigin-RevId: 211833041
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py4
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py3
2 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1ab150d74a..1056894f18 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -229,6 +229,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 4fb70ec685..6ba83976fc 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -310,7 +310,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
- return '/job:tpu_worker/task:%d/device:CPU:0' % (host_id,)
+ job_name = self._tpu_cluster_resolver.get_job_name() or 'tpu_worker'
+ return '/job:%s/task:%d/device:CPU:0' % (job_name, host_id)
def configure(self,
session_config=None,