diff options
author | 2018-05-21 18:14:30 -0700 | |
---|---|---|
committer | 2018-05-21 18:17:11 -0700 | |
commit | 1d5c44cd876377eb296cee22567228ea6f72a7ac (patch) | |
tree | 41f963255f2cc8a7b5f1bceacce45317855809ec /tensorflow/contrib/cluster_resolver | |
parent | 7c3cd0842a41aac47069dcf14567b88c32ea7b28 (diff) |
Adds support for specifying a discovery_service_url (via either a parameter or an environment variable) within TPUClusterResolver
PiperOrigin-RevId: 197494335
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 26 | ||||
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 4 |
2 files changed, 26 insertions, 4 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index 8ede28602f..880fca4ea6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -37,6 +37,7 @@ except ImportError: _GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' _DEFAULT_ENV_VARIABLE = 'TPU_NAME' +_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL' class TPUClusterResolver(ClusterResolver): @@ -77,6 +78,10 @@ class TPUClusterResolver(ClusterResolver): return os.environ[_DEFAULT_ENV_VARIABLE] return None + @staticmethod + def _discoveryUrl(): + return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE) + def __init__(self, tpu=None, zone=None, @@ -85,7 +90,8 @@ class TPUClusterResolver(ClusterResolver): coordinator_name=None, coordinator_address=None, credentials='default', - service=None): + service=None, + discovery_url=None): """Creates a new TPUClusterResolver object. The ClusterResolver will then use the parameters to query the Cloud TPU APIs @@ -115,6 +121,11 @@ class TPUClusterResolver(ClusterResolver): service: The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored. + discovery_url: A URL template that points to the location of + the discovery service. It should have two parameters {api} and + {apiVersion} that when filled in produce an absolute URL to the + discovery document for that service. The environment variable + 'TPU_API_DISCOVERY_URL' will override this. Raises: ImportError: If the googleapiclient is not installed. @@ -164,9 +175,16 @@ class TPUClusterResolver(ClusterResolver): '--upgrade google-api-python-client` to install with ' 'pip.') - self._service = discovery.build( - 'tpu', 'v1alpha1', - credentials=self._credentials) + final_discovery_url = self._discoveryUrl() or discovery_url + if final_discovery_url: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials, + discoveryServiceUrl=final_discovery_url) + else: + self._service = discovery.build( + 'tpu', 'v1alpha1', + credentials=self._credentials) else: self._service = service diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5b3f9be5a1..5fac55fd02 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -367,6 +367,10 @@ class TPUClusterResolverTest(test.TestCase): compat.as_bytes(TPUClusterResolver._gkeMaster())) del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] + def testDiscoveryUrl(self): + os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}' + self.assertEqual('https://{api}.internal/{apiVersion}', + TPUClusterResolver._discoveryUrl()) if __name__ == '__main__': test.main() |