aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-02-26 10:54:31 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-26 11:00:19 -0800
commit5a657b47f724b96730f764d3fb21c89e342e9c35 (patch)
treed2588033f82402f6d044c2a77cc68afc81219889 /tensorflow/contrib/cluster_resolver
parent7735b2db761fba6e76c170066b2e5c3b7f10688b (diff)
Integrate ClusterResolvers with TPUEstimator.
PiperOrigin-RevId: 187047094
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/BUILD1
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py23
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py2
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py3
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py150
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py226
6 files changed, 314 insertions, 91 deletions
diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD
index 6b03df2b8e..1a124eca36 100644
--- a/tensorflow/contrib/cluster_resolver/BUILD
+++ b/tensorflow/contrib/cluster_resolver/BUILD
@@ -110,5 +110,6 @@ tf_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
+ grpc_enabled = True,
main = "python/training/tpu_cluster_resolver_test.py",
)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
index b04822fa9d..1c480b2513 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py
@@ -53,11 +53,16 @@ class ClusterResolver(object):
raise NotImplementedError(
'cluster_spec is not implemented for {}.'.format(self))
+ @abc.abstractmethod
+ def master(self):
+ """..."""
+ raise NotImplementedError('master is not implemented for {}.'.format(self))
+
class SimpleClusterResolver(ClusterResolver):
"""Simple implementation of ClusterResolver that accepts a ClusterSpec."""
- def __init__(self, cluster_spec):
+ def __init__(self, cluster_spec, master=''):
"""Creates a SimpleClusterResolver from a ClusterSpec."""
super(SimpleClusterResolver, self).__init__()
@@ -65,10 +70,18 @@ class SimpleClusterResolver(ClusterResolver):
raise TypeError('cluster_spec must be a ClusterSpec.')
self._cluster_spec = cluster_spec
+ if not isinstance(master, str):
+ raise TypeError('master must be a string.')
+ self._master = master
+
def cluster_spec(self):
"""Returns the ClusterSpec passed into the constructor."""
return self._cluster_spec
+ def master(self):
+ """Returns the master address to use when creating a session."""
+ return self._master
+
class UnionClusterResolver(ClusterResolver):
"""Performs a union on underlying ClusterResolvers.
@@ -87,9 +100,13 @@ class UnionClusterResolver(ClusterResolver):
Raises:
TypeError: If any argument is not a subclass of `ClusterResolvers`.
+ ValueError: If there are no arguments passed.
"""
super(UnionClusterResolver, self).__init__()
+ if not args:
+ raise ValueError('At least one ClusterResolver is required.')
+
for cluster_resolver in args:
if not isinstance(cluster_resolver, ClusterResolver):
raise TypeError('All arguments must be a sub-class of '
@@ -169,3 +186,7 @@ class UnionClusterResolver(ClusterResolver):
merged_cluster[job_name].update(task_dict)
return ClusterSpec(merged_cluster)
+
+ def master(self):
+ """master returns the master address from the first cluster resolver."""
+ return self._cluster_resolvers[0].master()
diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
index dbfb77723c..d9c97d53eb 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py
@@ -234,5 +234,7 @@ class UnionClusterResolverTest(test.TestCase):
self._verifyClusterSpecEquality(cluster_spec, expected_proto)
+# TODO(saeta): Include tests for master resolution
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py
index d6f2eced93..3f58241289 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py
@@ -134,3 +134,6 @@ class GceClusterResolver(ClusterResolver):
worker_list.sort()
return ClusterSpec({self._job_name: worker_list})
+
+ def master(self):
+ return ''
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index a6a6e642e4..aeccf4c06b 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -23,7 +23,8 @@ from six.moves.urllib.request import Request
from six.moves.urllib.request import urlopen
from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver
-from tensorflow.python.training.server_lib import ClusterSpec
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import compat
_GOOGLE_API_CLIENT_INSTALLED = True
try:
@@ -46,13 +47,23 @@ class TPUClusterResolver(ClusterResolver):
req = Request('http://metadata/computeMetadata/v1/%s' % path,
headers={'Metadata-Flavor': 'Google'})
resp = urlopen(req)
- return resp.read()
+ return compat.as_bytes(resp.read())
+
+ def _shouldResolve(self):
+ if (self._tpu == compat.as_bytes('') or
+ self._tpu == compat.as_bytes('local') or
+ self._tpu.startswith(compat.as_bytes('/bns')) or
+ self._tpu.startswith(compat.as_bytes('grpc://'))):
+ return False
+ return True
def __init__(self,
- tpu_names,
+ tpu,
zone=None,
project=None,
- job_name='tpu_worker',
+ job_name='worker',
+ coordinator_name='coordinator',
+ coordinator_address=None,
credentials='default',
service=None):
"""Creates a new TPUClusterResolver object.
@@ -61,7 +72,11 @@ class TPUClusterResolver(ClusterResolver):
for the IP addresses and ports of each Cloud TPU listed.
Args:
- tpu_names: A list of names of the target Cloud TPUs.
+ tpu: Either a string, or a list of strings corresponding to the TPUs to
+ use. If the single string is the empty string, the string 'local', or a
+ string that begins with 'grpc://' or '/bns', then it is assumed to not
+ correspond with a Cloud TPU and will instead be passed as the session
+ master and no ClusterSpec propagation will be done.
zone: Zone where the TPUs are located. If omitted or empty, we will assume
that the zone of the TPU is the same as the zone of the GCE VM, which we
will try to discover from the GCE metadata service.
@@ -69,6 +84,12 @@ class TPUClusterResolver(ClusterResolver):
empty, we will try to discover the project name of the GCE VM from the
GCE metadata service.
job_name: Name of the TensorFlow job the TPUs belong to.
+ coordinator_name: The name to use for the coordinator. Set to None if the
+ coordinator should not be included in the computed ClusterSpec.
+ coordinator_address: The address of the coordinator (typically an ip:port
+ pair). If set to None, a TF server will be started. If coordinator_name
+ is None, a TF server will not be started even if coordinator_address is
+ None.
credentials: GCE Credentials. If None, then we use default credentials
from the oauth2client
service: The GCE API object returned by the googleapiclient.discovery
@@ -77,26 +98,36 @@ class TPUClusterResolver(ClusterResolver):
Raises:
ImportError: If the googleapiclient is not installed.
+ ValueError: If no TPUs are specified.
"""
+ if isinstance(tpu, list):
+ if not tpu:
+ raise ValueError('At least one TPU must be specified.')
+ if len(tpu) != 1:
+ raise NotImplementedError(
+ 'Using multiple TPUs in a single session is not yet implemented')
+ tpu = tpu[0]
+ self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
+ self._job_name = job_name
+ self._credentials = credentials
- if not project:
- project = self._requestComputeMetadata('/project/project-id')
+ should_resolve = self._shouldResolve()
- if not zone:
- zone_path = self._requestComputeMetadata('/instance/zone')
+ if not project and should_resolve:
+ project = self._requestComputeMetadata('project/project-id')
+
+ if not zone and should_resolve:
+ zone_path = self._requestComputeMetadata('instance/zone')
zone = zone_path.split('/')[-1]
self._project = project
self._zone = zone
- self._tpu_names = tpu_names
- self._job_name = job_name
- self._credentials = credentials
- if credentials == 'default':
+ if credentials == 'default' and should_resolve:
if _GOOGLE_API_CLIENT_INSTALLED:
self._credentials = GoogleCredentials.get_application_default()
- if service is None:
+ if service is None and should_resolve:
if not _GOOGLE_API_CLIENT_INSTALLED:
raise ImportError('googleapiclient must be installed before using the '
'TPU cluster resolver')
@@ -107,25 +138,41 @@ class TPUClusterResolver(ClusterResolver):
else:
self._service = service
- def get_master(self):
- """Get the ClusterSpec grpc master path.
+ self._coordinator_name = coordinator_name
+ if coordinator_name and not coordinator_address and should_resolve:
+ self._start_local_server()
+ else:
+ self._coordinator_address = coordinator_address
+
+ def master(self):
+ """Get the Master string to be used for the session.
+
+ In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of
+ first instance in the ClusterSpec returned by the cluster_spec function.
- This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the
- ClusterSpec returned by the cluster_spec function. This is suitable for use
- for the `master` argument in tf.Session() when you are using one TPU.
+ If a non-TPU name is used when constructing a TPUClusterResolver, that will
+ be returned instead (e.g. If the tpus argument's value when constructing
+ this TPUClusterResolver was 'grpc://10.240.1.2:8470',
+ 'grpc://10.240.1.2:8470' will be returned).
Returns:
- string, the grpc path of the first instance in the ClusterSpec.
+ string, the connection string to use when creating a session.
Raises:
ValueError: If none of the TPUs specified exists.
"""
+ if not self._shouldResolve():
+ return self._tpu
+
job_tasks = self.cluster_spec().job_tasks(self._job_name)
if not job_tasks:
raise ValueError('No TPUs exists with the specified names exist.')
return 'grpc://' + job_tasks[0]
+ def get_master(self):
+ return self.master()
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
@@ -134,17 +181,54 @@ class TPUClusterResolver(ClusterResolver):
Returns:
A ClusterSpec containing host information returned from Cloud TPUs.
- """
- worker_list = []
-
- for tpu_name in self._tpu_names:
- full_name = 'projects/%s/locations/%s/nodes/%s' % (
- self._project, self._zone, tpu_name)
- request = self._service.projects().locations().nodes().get(name=full_name)
- response = request.execute()
- if 'health' in response and response['health'] == 'HEALTHY':
- instance_url = '%s:%s' % (response['ipAddress'], response['port'])
- worker_list.append(instance_url)
-
- return ClusterSpec({self._job_name: worker_list})
+ Raises:
+ RuntimeError: If the provided TPU is not healthy.
+ """
+ if not self._shouldResolve():
+ return server_lib.ClusterSpec({})
+
+ full_name = 'projects/%s/locations/%s/nodes/%s' % (
+ self._project, self._zone, compat.as_text(self._tpu))
+ request = self._service.projects().locations().nodes().get(name=full_name)
+ response = request.execute()
+
+ if 'health' in response and response['health'] != 'HEALTHY':
+ raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
+ response['health']))
+
+ if 'networkEndpoints' in response:
+ worker_list = [
+ '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
+ for endpoint in response['networkEndpoints']
+ ]
+ else:
+ # Fall back to the deprecated response format
+ instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+ worker_list = [instance_url]
+
+ cluster_spec = {self._job_name: worker_list}
+
+ if self._coordinator_address:
+ cluster_spec[self._coordinator_name] = [self._coordinator_address]
+
+ return server_lib.ClusterSpec(cluster_spec)
+
+ def _start_local_server(self):
+ address = self._requestComputeMetadata('instance/network-interfaces/0/ip')
+ self._server = server_lib.Server(
+ {
+ 'local': ['0.0.0.0:0']
+ }, protocol='grpc', config=None, start=True)
+ # self._server.target is of the form: grpc://ipaddress:port
+ target = compat.as_bytes(self._server.target)
+ splits = target.split(compat.as_bytes(':'))
+ assert len(splits) == 3, self._server.target
+ assert splits[0] == compat.as_bytes('grpc'), self._server.target
+ self._coordinator_port = compat.as_text(splits[2])
+ self._coordinator_address = '%s:%s' % (
+ address, compat.as_text(self._coordinator_port))
+
+ def __deepcopy__(self, memo):
+ # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy.
+ return self
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 4fd34629cf..6b4a155152 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -21,7 +21,7 @@ from __future__ import print_function
from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-
+from tensorflow.python.util import compat
mock = test.mock
@@ -50,10 +50,12 @@ class MockNodeClass(object):
def mock_request_compute_metadata(cls, *args, **kwargs):
del cls, kwargs # Unused.
- if args[0] == '/project/project-id':
+ if args[0] == 'project/project-id':
return 'test-project'
- elif args[0] == '/instance/zone':
+ elif args[0] == 'instance/zone':
return 'projects/test-project/locations/us-central1-c'
+ elif args[0] == 'instance/network-interfaces/0/ip':
+ return '10.128.1.2'
return ''
@@ -113,17 +115,26 @@ class TPUClusterResolverTest(test.TestCase):
tpu_cluster_resolver = TPUClusterResolver(
project=None,
zone=None,
- tpu_names=['test-tpu-1'],
+ tpu=['test-tpu-1'],
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
- job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
- """
- self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+ job {
+ name: 'coordinator'
+ tasks { key: 0 value: '10.128.1.2:%s' }
+ }
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.1.2.3:8470' }
+ }
+ """ % tpu_cluster_resolver._coordinator_port
+ self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
- def testSimpleSuccessfulRetrieval(self):
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
@@ -133,116 +144,217 @@ class TPUClusterResolverTest(test.TestCase):
}
tpu_cluster_resolver = TPUClusterResolver(
- project='test-project',
- zone='us-central1-c',
- tpu_names=['test-tpu-1'],
+ project=None,
+ zone=None,
+ tpu=['test-tpu-1'],
+ coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
- job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } }
+ job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
- def testMultipleSuccessfulRetrieval(self):
+ def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
'port': '8470',
'health': 'HEALTHY'
- },
- 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
- 'ipAddress': '10.4.5.6',
- 'port': '8470',
- 'health': 'HEALTHY'
}
}
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
- tpu_names=['test-tpu-2', 'test-tpu-1'],
+ tpu=['test-tpu-1'],
+ coordinator_address='10.128.1.5:10203',
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
- job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' }
- tasks { key: 1 value: '10.1.2.3:8470' } }
+ job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
+ job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
- def testHealthyTpuNodeRetrieval(self):
+ def testNewNetworkEndpointFormat(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
- 'ipAddress': '10.1.2.3',
- 'port': '8470',
- 'health': 'HEALTHY'
- },
- 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
- 'ipAddress': '10.4.5.6',
- 'port': '8470',
- },
- 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': {
- 'ipAddress': '10.7.8.9',
- 'port': '8470',
- 'health': 'UNHEALTHY'
+ 'health': 'HEALTHY',
+ 'networkEndpoints': [{
+ 'ipAddress': '10.2.3.4',
+ 'port': 8470,
+ }]
}
}
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
- tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'],
+ tpu='test-tpu-1',
+ coordinator_address='10.128.1.5:10203',
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
expected_proto = """
- job {
- name: 'tpu_worker'
- tasks {
- key: 0
- value: '10.1.2.3:8470'
- }
- }
+ job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } }
+ job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+ self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master())
- def testGetMasterMultipleEntries(self):
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testPodResolution(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
- 'ipAddress': '10.1.2.3',
- 'port': '8470',
- 'health': 'HEALTHY'
- },
- 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
- 'ipAddress': '10.4.5.6',
- 'port': '8470',
- 'health': 'HEALTHY'
+ 'health':
+ 'HEALTHY',
+ 'networkEndpoints': [
+ {
+ 'ipAddress': '10.2.3.4',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.5',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.6',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.7',
+ 'port': 8470,
+ },
+ ]
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ tpu='test-tpu-1',
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'coordinator',
+ tasks { key: 0 value: '10.128.1.2:%s'}
+ }
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.2.3.4:8470' }
+ tasks { key: 1 value: '10.2.3.5:8470' }
+ tasks { key: 2 value: '10.2.3.6:8470' }
+ tasks { key: 3 value: '10.2.3.7:8470' }
+ }
+ """ % tpu_cluster_resolver._coordinator_port
+ self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto))
+
+ def testPodResolutionNoCoordinator(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'health':
+ 'HEALTHY',
+ 'networkEndpoints': [
+ {
+ 'ipAddress': '10.2.3.4',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.5',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.6',
+ 'port': 8470,
+ },
+ {
+ 'ipAddress': '10.2.3.7',
+ 'port': 8470,
+ },
+ ]
}
}
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
- tpu_names=['test-tpu-2', 'test-tpu-1'],
+ tpu='test-tpu-1',
+ coordinator_name=None,
credentials=None,
service=self.mock_service_client(tpu_map=tpu_map))
- self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master())
+
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.2.3.4:8470' }
+ tasks { key: 1 value: '10.2.3.5:8470' }
+ tasks { key: 2 value: '10.2.3.6:8470' }
+ tasks { key: 3 value: '10.2.3.7:8470' }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
def testGetMasterNoEntries(self):
tpu_map = {}
+ with self.assertRaises(ValueError):
+ TPUClusterResolver(
+ project='test-project',
+ zone='us-central1-c',
+ tpu=[],
+ coordinator_name=None,
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ # TODO(saeta): Convert to parameterized test when included in OSS TF.
+ def verifyShouldResolve(self, tpu, should_resolve):
tpu_cluster_resolver = TPUClusterResolver(
project='test-project',
zone='us-central1-c',
- tpu_names=[],
+ tpu=tpu,
+ coordinator_name=None,
credentials=None,
- service=self.mock_service_client(tpu_map=tpu_map))
- with self.assertRaises(ValueError):
- tpu_cluster_resolver.get_master()
+ service=self.mock_service_client(tpu_map={}))
+ self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(),
+ "TPU: '%s'" % tpu)
+
+ def testShouldResolveNoName(self):
+ self.verifyShouldResolve('', False)
+
+ def testShouldResolveLocal(self):
+ self.verifyShouldResolve('local', False)
+
+ def testShouldResolveGrpc(self):
+ self.verifyShouldResolve('grpc://10.1.2.3:8470', False)
+
+ def testShouldResolveBns(self):
+ self.verifyShouldResolve('/bns/foo/bar', False)
+
+ def testShouldResolveName(self):
+ self.verifyShouldResolve('mytpu', True)
+
+ def testShouldResolveList(self):
+ self.verifyShouldResolve(['myothertpu'], True)
+
+ def testShouldResolveGrpcPrefix(self):
+ self.verifyShouldResolve('grpctpu', True)
+
+ def testNoCallComputeMetadata(self):
+ tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar')
+ self.assertEqual(compat.as_bytes('/bns/foo/bar'),
+ tpu_cluster_resolver.master())
+ self.assertEqual(
+ server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec())
+
if __name__ == '__main__':
test.main()