aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2018-05-21 18:14:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-21 18:17:11 -0700
commit1d5c44cd876377eb296cee22567228ea6f72a7ac (patch)
tree41f963255f2cc8a7b5f1bceacce45317855809ec /tensorflow/contrib/cluster_resolver
parent7c3cd0842a41aac47069dcf14567b88c32ea7b28 (diff)
Adds support for specifying a discovery_service_url (via either a parameter or an environment variable) within TPUClusterResolver
PiperOrigin-RevId: 197494335
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py26
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py4
2 files changed, 26 insertions, 4 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 8ede28602f..880fca4ea6 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -37,6 +37,7 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
+_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
class TPUClusterResolver(ClusterResolver):
@@ -77,6 +78,10 @@ class TPUClusterResolver(ClusterResolver):
return os.environ[_DEFAULT_ENV_VARIABLE]
return None
+ @staticmethod
+ def _discoveryUrl():
+ return os.environ.get(_DISCOVERY_SERVICE_URL_ENV_VARIABLE)
+
def __init__(self,
tpu=None,
zone=None,
@@ -85,7 +90,8 @@ class TPUClusterResolver(ClusterResolver):
coordinator_name=None,
coordinator_address=None,
credentials='default',
- service=None):
+ service=None,
+ discovery_url=None):
"""Creates a new TPUClusterResolver object.
The ClusterResolver will then use the parameters to query the Cloud TPU APIs
@@ -115,6 +121,11 @@ class TPUClusterResolver(ClusterResolver):
service: The GCE API object returned by the googleapiclient.discovery
function. If you specify a custom service object, then the credentials
parameter will be ignored.
+ discovery_url: A URL template that points to the location of
+ the discovery service. It should have two parameters {api} and
+ {apiVersion} that when filled in produce an absolute URL to the
+ discovery document for that service. The environment variable
+ 'TPU_API_DISCOVERY_URL' will override this.
Raises:
ImportError: If the googleapiclient is not installed.
@@ -164,9 +175,16 @@ class TPUClusterResolver(ClusterResolver):
'--upgrade google-api-python-client` to install with '
'pip.')
- self._service = discovery.build(
- 'tpu', 'v1alpha1',
- credentials=self._credentials)
+ final_discovery_url = self._discoveryUrl() or discovery_url
+ if final_discovery_url:
+ self._service = discovery.build(
+ 'tpu', 'v1alpha1',
+ credentials=self._credentials,
+ discoveryServiceUrl=final_discovery_url)
+ else:
+ self._service = discovery.build(
+ 'tpu', 'v1alpha1',
+ credentials=self._credentials)
else:
self._service = service
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 5b3f9be5a1..5fac55fd02 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -367,6 +367,10 @@ class TPUClusterResolverTest(test.TestCase):
compat.as_bytes(TPUClusterResolver._gkeMaster()))
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
+ def testDiscoveryUrl(self):
+ os.environ['TPU_API_DISCOVERY_URL'] = 'https://{api}.internal/{apiVersion}'
+ self.assertEqual('https://{api}.internal/{apiVersion}',
+ TPUClusterResolver._discoveryUrl())
if __name__ == '__main__':
test.main()