diff options
author | Frank Chen <frankchn@google.com> | 2018-01-17 14:29:30 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-17 14:33:10 -0800 |
commit | 6560e5d7307e5bba5c440a1d08f7e8b7692072d3 (patch) | |
tree | a5229c5d71f414f60a515eac5dd056ad51aca9a1 /tensorflow/contrib/cluster_resolver | |
parent | 393588ec4d64e53c2e6d3a6f34dcdc14332aeaeb (diff) |
Add functionality to auto-discover project and zone when they are not supplied to the TPUClusterResolver
PiperOrigin-RevId: 182270565
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 33 |
2 files changed, 58 insertions, 4 deletions
```diff
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index c74da9cabd..2e75ac226e 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -18,6 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+
+from six.moves.urllib.request import Request
+from six.moves.urllib.request import urlopen
+
 from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
 from tensorflow.python.training.server_lib import ClusterSpec
 
@@ -38,10 +42,16 @@ class TPUClusterResolver(ClusterResolver):
   Cloud Platform project.
   """
 
+  def _requestComputeMetadata(self, path):
+    req = Request('http://metadata/computeMetadata/v1/%s' % path,
+                  headers={'Metadata-Flavor': 'Google'})
+    resp = urlopen(req)
+    return resp.read()
+
   def __init__(self,
-               project,
-               zone,
                tpu_names,
+               zone=None,
+               project=None,
                job_name='tpu_worker',
                credentials='default',
                service=None):
@@ -51,9 +61,13 @@ class TPUClusterResolver(ClusterResolver):
     for the IP addresses and ports of each Cloud TPU listed.
 
     Args:
-      project: Name of the GCP project containing Cloud TPUs
-      zone: Zone where the TPUs are located
       tpu_names: A list of names of the target Cloud TPUs.
+      zone: Zone where the TPUs are located. If omitted or empty, we will assume
+        that the zone of the TPU is the same as the zone of the GCE VM, which we
+        will try to discover from the GCE metadata service.
+      project: Name of the GCP project containing Cloud TPUs. If omitted or
+        empty, we will try to discover the project name of the GCE VM from the
+        GCE metadata service.
       job_name: Name of the TensorFlow job the TPUs belong to.
       credentials: GCE Credentials. If None, then we use default credentials
         from the oauth2client
@@ -65,6 +79,13 @@ class TPUClusterResolver(ClusterResolver):
       ImportError: If the googleapiclient is not installed.
     """
 
+    if not project:
+      project = self._requestComputeMetadata('/project/project-id')
+
+    if not zone:
+      zone_path = self._requestComputeMetadata('/instance/zone')
+      zone = zone_path.split('/')[-1]
+
     self._project = project
     self._zone = zone
     self._tpu_names = tpu_names
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index db7419be06..0c4730613a 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -48,6 +48,15 @@ class MockNodeClass(object):
     return MockRequestClass(name, self._tpu_map)
 
 
+def mock_request_compute_metadata(cls, *args, **kwargs):
+  del cls, kwargs  # Unused.
+  if args[0] == '/project/project-id':
+    return 'test-project'
+  elif args[0] == '/instance/zone':
+    return 'projects/test-project/locations/us-central1-c'
+  return ''
+
+
 class TPUClusterResolverTest(test.TestCase):
 
   def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
@@ -89,6 +98,30 @@ class TPUClusterResolverTest(test.TestCase):
 
     return mock_client
 
+  @mock.patch.object(TPUClusterResolver,
+                     '_requestComputeMetadata',
+                     mock_request_compute_metadata)
+  def testRetrieveProjectAndZoneFromMetadata(self):
+    tpu_map = {
+        'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+            'ipAddress': '10.1.2.3',
+            'port': '8470'
+        }
+    }
+
+    tpu_cluster_resolver = TPUClusterResolver(
+        project=None,
+        zone=None,
+        tpu_names=['test-tpu-1'],
+        credentials=None,
+        service=self.mock_service_client(tpu_map=tpu_map))
+
+    actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+    expected_proto = """
+    job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
+    """
+    self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
   def testSimpleSuccessfulRetrieval(self):
     tpu_map = {
         'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
```