about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py')
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py 70
1 files changed, 53 insertions, 17 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 8ede28602f..f4a8e16c99 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -36,7 +36,9 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
+_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
class TPUClusterResolver(ClusterResolver):
@@ -58,6 +60,7 @@ class TPUClusterResolver(ClusterResolver):
if (self._tpu == compat.as_bytes('') or
self._tpu == compat.as_bytes('local') or
self._tpu.startswith(compat.as_bytes('/bns')) or
+ self._tpu.startswith(compat.as_bytes('localhost:')) or
self._tpu.startswith(compat.as_bytes('grpc://'))):
return False
return True
@@ -68,8 +71,8 @@ class TPUClusterResolver(ClusterResolver):
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
- def _gkeMaster():
- return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ def _gkeEndpoints():
+ return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _envVarFallback():
@@ -77,6 +80,10 @@ class TPUClusterResolver(ClusterResolver):
return os.environ[_DEFAULT_ENV_VARIABLE]
return None
+ @staticmethod
+ def _discoveryUrl():
+ return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
+
def __init__(self,
tpu=None,
zone=None,
@@ -85,7 +92,8 @@ class TPUClusterResolver(ClusterResolver):
coordinator_name=None,
coordinator_address=None,
credentials='default',
- service=None):
+ service=None,
+ discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
@@ -115,6 +123,11 @@ class TPUClusterResolver(ClusterResolver):
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
+ discovery_url: A URL template that points to the location of
+ the discovery service. It should have two parameters {api} and
+ {apiVersion} that when filled in produce an absolute URL to the
+ discovery document for that service. The environment variable
+ 'TPU_API_DISCOVERY_URL' will override this.
Raises:
ImportError: If the googleapiclient is not installed.
@@ -132,10 +145,13 @@ class TPUClusterResolver(ClusterResolver):
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
- tpu = self._gkeMaster()
+ tpu = self._gkeEndpoints()
else:
tpu = self._envVarFallback()
+ if tpu is None:
+ raise ValueError('Please provide a TPU Name to connect to.')
+
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
self._job_name = job_name
self._credentials = credentials
@@ -159,14 +175,22 @@ class TPUClusterResolver(ClusterResolver):
if service is None and should_resolve:
if not _GOOGLE_API_CLIENT_INSTALLED:
- raise ImportError('googleapiclient must be installed before using the '
- 'TPU cluster resolver. Execute: `pip install '
- '--upgrade google-api-python-client` to install with '
- 'pip.')
-
- self._service = discovery.build(
- 'tpu', 'v1alpha1',
- credentials=self._credentials)
+ raise ImportError('googleapiclient and oauth2client must be installed '
+ 'before using the TPU cluster resolver. Execute: '
+ '`pip install --upgrade google-api-python-client` '
+ 'and `pip install --upgrade oauth2client` to '
+ 'install with pip.')
+
+ final_discovery_url = self._discoveryUrl() or discovery_url
+ if final_discovery_url:
+ self._service = discovery.build(
+ 'tpu', 'v1alpha1',
+ credentials=self._credentials,
+ discoveryServiceUrl=final_discovery_url)
+ else:
+ self._service = discovery.build(
+ 'tpu', 'v1alpha1',
+ credentials=self._credentials)
else:
self._service = service
@@ -195,7 +219,7 @@ class TPUClusterResolver(ClusterResolver):
ValueError: If none of the TPUs specified exists.
"""
if not self._shouldResolve():
- return self._tpu
+ return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
job_tasks = self.cluster_spec().job_tasks(self._job_name)
if not job_tasks:
@@ -206,6 +230,10 @@ class TPUClusterResolver(ClusterResolver):
def get_master(self):
return self.master()
+ def get_job_name(self):
+ if self._shouldResolve():
+ return self._job_name
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
@@ -237,9 +265,13 @@ class TPUClusterResolver(ClusterResolver):
request = self._service.projects().locations().nodes().get(name=full_name)
response = request.execute()
+ if 'state' in response and response['state'] != 'READY':
+ raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
+ (compat.as_text(self._tpu), response['state']))
+
if 'health' in response and response['health'] != 'HEALTHY':
- raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
- response['health']))
+ raise RuntimeError('TPU "%s" is unhealthy: "%s"' %
+ (compat.as_text(self._tpu), response['health']))
if 'networkEndpoints' in response:
worker_list = [
@@ -257,8 +289,12 @@ class TPUClusterResolver(ClusterResolver):
# Case 3.
return None
# Case 2.
- cluster_spec = {self._job_name: [self._tpu[len(
- compat.as_bytes('grpc://')):]]}
+ cluster_spec = {
+ self._job_name: [
+ x[len(compat.as_bytes('grpc://')):]
+ for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))
+ ]
+ }
if self._coordinator_address:
# {1, 2}.a