aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-05-10 17:32:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 17:35:29 -0700
commit03d770b78d4cb799ce7945adcbc8ac10fe6f4d38 (patch)
treed7ed6ab06136b76909aea92fed129888f2197057 /tensorflow/contrib/cluster_resolver
parentd774abfe3850b41b3883dd26e4f9c945c0ababb9 (diff)
[TPU]: If the $TPU_NAME env var is set, fallback to that.
PiperOrigin-RevId: 196196939
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 1403483d28..8ede28602f 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -36,6 +36,7 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
class TPUClusterResolver(ClusterResolver):
@@ -70,6 +71,12 @@ class TPUClusterResolver(ClusterResolver):
def _gkeMaster():
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ @staticmethod
+ def _envVarFallback():
+ if _DEFAULT_ENV_VARIABLE in os.environ:
+ return os.environ[_DEFAULT_ENV_VARIABLE]
+ return None
+
def __init__(self,
tpu=None,
zone=None,
@@ -123,8 +130,11 @@ class TPUClusterResolver(ClusterResolver):
in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
- if tpu is None and in_gke:
- tpu = self._gkeMaster()
+ if tpu is None:
+ if in_gke:
+ tpu = self._gkeMaster()
+ else:
+ tpu = self._envVarFallback()
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name