diff options
Diffstat (limited to 'tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py')
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 70 |
1 files changed, 53 insertions, 17 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8ede28602f..f4a8e16c99 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -36,7 +36,9 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' +_ENDPOINTS_SEPARATOR = ',' _DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -58,6 +60,7 @@ class TPUClusterResolver(ClusterResolver): if (self._tpu == compat.as_bytes('') or self._tpu == compat.as_bytes('local') or self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('localhost:')) or self._tpu.startswith(compat.as_bytes('grpc://'))): return False return True @@ -68,8 +71,8 @@ class TPUClusterResolver(ClusterResolver): return _GKE_ENV_VARIABLE in os.environ @staticmethod - def _gkeMaster(): - return os.environ[_GKE_ENV_VARIABLE].split(',')[0] + def _gkeEndpoints(): + return os.environ[_GKE_ENV_VARIABLE] @staticmethod def _envVarFallback(): @@ -77,6 +80,10 @@ class TPUClusterResolver(ClusterResolver): return os.environ[_DEFAULT_ENV_VARIABLE] return None + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) + def __init__(self, tpu=None, zone=None, @@ -85,7 +92,8 @@ class TPUClusterResolver(ClusterResolver): coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -115,6 +123,11 @@ class TPUClusterResolver(ClusterResolver): service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -132,10 +145,13 @@ class TPUClusterResolver(ClusterResolver): # When using GKE with Cloud TPUs, the env variable will be set. if tpu is None: if in_gke: - tpu = self._gkeMaster() + tpu = self._gkeEndpoints() else: tpu = self._envVarFallback() + if tpu is None: + raise ValueError('Please provide a TPU Name to connect to.') + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes self._job_name = job_name self._credentials = credentials @@ -159,14 +175,22 @@ class TPUClusterResolver(ClusterResolver): if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: - raise ImportError('googleapiclient must be installed before using the ' - 'TPU cluster resolver. Execute: `pip install ' - '--upgrade google-api-python-client` to install with ' - 'pip.') - - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + raise ImportError('googleapiclient and oauth2client must be installed ' + 'before using the TPU cluster resolver. Execute: ' + '`pip install --upgrade google-api-python-client` ' + 'and `pip install --upgrade oauth2client` to ' + 'install with pip.') + + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service @@ -195,7 +219,7 @@ class TPUClusterResolver(ClusterResolver): ValueError: If none of the TPUs specified exists. """ if not self._shouldResolve(): - return self._tpu + return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0] job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: @@ -206,6 +230,10 @@ class TPUClusterResolver(ClusterResolver): def get_master(self): return self.master() + def get_job_name(self): + if self._shouldResolve(): + return self._job_name + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -237,9 +265,13 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (compat.as_text(self._tpu), response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': - raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, - response['health'])) + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % + (compat.as_text(self._tpu), response['health'])) if 'networkEndpoints' in response: worker_list = [ @@ -257,8 +289,12 @@ class TPUClusterResolver(ClusterResolver): # Case 3. return None # Case 2. - cluster_spec = {self._job_name: [self._tpu[len( - compat.as_bytes('grpc://')):]]} + cluster_spec = { + self._job_name: [ + x[len(compat.as_bytes('grpc://')):] + for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR)) + ] + } if self._coordinator_address: # {1, 2}.a |