aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2018-01-17 14:29:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-17 14:33:10 -0800
commit6560e5d7307e5bba5c440a1d08f7e8b7692072d3 (patch)
treea5229c5d71f414f60a515eac5dd056ad51aca9a1 /tensorflow/contrib/cluster_resolver
parent393588ec4d64e53c2e6d3a6f34dcdc14332aeaeb (diff)
Add functionality to auto-discover project and zone when they are not supplied to the TPUClusterResolver
PiperOrigin-RevId: 182270565
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py29
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py33
2 files changed, 58 insertions, 4 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index c74da9cabd..2e75ac226e 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -18,6 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+from six.moves.urllib.request import Request
+from six.moves.urllib.request import urlopen
+
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
from tensorflow.python.training.server_lib import ClusterSpec
@@ -38,10 +42,16 @@ class TPUClusterResolver(ClusterResolver):
Cloud Platform project.
"""
+ def _requestComputeMetadata(self, path):
+   """Fetches a single value from the GCE metadata server as text.
+
+   Args:
+     path: Metadata key relative to `computeMetadata/v1/`, for example
+       `project/project-id`. Leading slashes are dropped so the request URL
+       does not contain an empty `//` path segment.
+
+   Returns:
+     The response body as a native string. On Python 3 the raw bytes are
+     decoded as UTF-8 so callers can use `str` operations such as `split('/')`.
+
+   Raises:
+     Whatever `urlopen` raises when the request fails (for example when the
+     code is not running on a GCE VM); errors are not swallowed here.
+   """
+   url = 'http://metadata/computeMetadata/v1/%s' % path.lstrip('/')
+   # The metadata server rejects requests without this header.
+   req = Request(url, headers={'Metadata-Flavor': 'Google'})
+   resp = urlopen(req)
+   try:
+     body = resp.read()
+   finally:
+     # Release the connection even if reading the body fails.
+     resp.close()
+   # Python 2 already yields `str`; Python 3 yields `bytes` that must be decoded.
+   if not isinstance(body, str):
+     body = body.decode('utf-8')
+   return body
+
def __init__(self,
- project,
- zone,
tpu_names,
+ zone=None,
+ project=None,
job_name='tpu_worker',
credentials='default',
service=None):
@@ -51,9 +61,13 @@ class TPUClusterResolver(ClusterResolver):
for the IP addresses and ports of each Cloud TPU listed.
Args:
- project: Name of the GCP project containing Cloud TPUs
- zone: Zone where the TPUs are located
tpu_names: A list of names of the target Cloud TPUs.
+ zone: Zone where the TPUs are located. If omitted or empty, we will assume
+ that the zone of the TPU is the same as the zone of the GCE VM, which we
+ will try to discover from the GCE metadata service.
+ project: Name of the GCP project containing Cloud TPUs. If omitted or
+ empty, we will try to discover the project name of the GCE VM from the
+ GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
credentials: GCE Credentials. If None, then we use default credentials
from the oauth2client
@@ -65,6 +79,13 @@ class TPUClusterResolver(ClusterResolver):
ImportError: If the googleapiclient is not installed.
"""
+ if not project:
+ project = self._requestComputeMetadata('/project/project-id')
+
+ if not zone:
+ zone_path = self._requestComputeMetadata('/instance/zone')
+ zone = zone_path.split('/')[-1]
+
self._project = project
self._zone = zone
self._tpu_names = tpu_names
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index db7419be06..0c4730613a 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -48,6 +48,15 @@ class MockNodeClass(object):
return MockRequestClass(name, self._tpu_map)
+def mock_request_compute_metadata(cls, *args, **kwargs):
+  """Test double for `TPUClusterResolver._requestComputeMetadata`.
+
+  Looks at the first positional argument (the metadata path) and returns a
+  canned response for the project-id and zone paths; any other path yields an
+  empty string. No network access is performed.
+  """
+  del cls, kwargs  # Unused.
+  canned_responses = (
+      ('/project/project-id', 'test-project'),
+      ('/instance/zone', 'projects/test-project/locations/us-central1-c'),
+  )
+  for known_path, response in canned_responses:
+    if args[0] == known_path:
+      return response
+  return ''
+
+
class TPUClusterResolverTest(test.TestCase):
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
@@ -89,6 +98,30 @@ class TPUClusterResolverTest(test.TestCase):
return mock_client
+ @mock.patch.object(TPUClusterResolver,
+                    '_requestComputeMetadata',
+                    mock_request_compute_metadata)
+ def testRetrieveProjectAndZoneFromMetadata(self):
+   """Resolver falls back to metadata lookups when project and zone are None.
+
+   `_requestComputeMetadata` is patched with the module-level mock, so the
+   project and zone come from its canned responses. The fake TPU service only
+   knows the node under that project/zone, so a matching cluster spec shows
+   the discovered values were used to build the node name.
+   """
+   node_name = 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1'
+   tpu_map = {
+       node_name: {
+           'ipAddress': '10.1.2.3',
+           'port': '8470'
+       }
+   }
+   fake_service = self.mock_service_client(tpu_map=tpu_map)
+
+   resolver = TPUClusterResolver(
+       tpu_names=['test-tpu-1'],
+       zone=None,
+       project=None,
+       credentials=None,
+       service=fake_service)
+
+   expected_proto = """
+   job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
+   """
+   self._verifyClusterSpecEquality(resolver.cluster_spec(), expected_proto)
+
def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {