From 3a8eaaf6a238e238a7adac9886b1569d7e43ae23 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 2 Nov 2017 16:32:24 -0700 Subject: Add a new method `get_master` to `TPUClusterResolver` such that users can easily specify the grpc connection string using ClusterResolvers rather than specifying the IP address manually. Also fixes a bug in the `TPUClusterResolverTest` that caused tests to not run at all. PiperOrigin-RevId: 174398488 --- .../python/training/tpu_cluster_resolver.py | 19 ++++++ .../python/training/tpu_cluster_resolver_test.py | 67 +++++++++++++++++++++- 2 files changed, 83 insertions(+), 3 deletions(-) (limited to 'tensorflow/contrib/cluster_resolver') diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index d76ddf8c65..f0144e9faa 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -90,6 +90,25 @@ class TPUClusterResolver(ClusterResolver): else: self._service = service + def get_master(self): + """Get the ClusterSpec grpc master path. + + This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the + ClusterSpec returned by the cluster_spec function. This is suitable for use + for the `master` argument in tf.Session() when you are using one TPU. + + Returns: + string, the grpc path of the first instance in the ClusterSpec. + + Raises: + ValueError: If none of the TPUs specified exists. + """ + job_tasks = self.cluster_spec().job_tasks(self._job_name) + if not job_tasks: + raise ValueError('No TPUs exists with the specified names exist.') + + return 'grpc://' + job_tasks[0] + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5bd5cd1a87..db7419be06 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -26,6 +26,28 @@ from tensorflow.python.training import server_lib mock = test.mock +class MockRequestClass(object): + + def __init__(self, name, tpu_map): + self._name = name + self._tpu_map = tpu_map + + def execute(self): + if self._name in self._tpu_map: + return self._tpu_map[self._name] + else: + raise KeyError('Resource %s was not found' % self._name) + + +class MockNodeClass(object): + + def __init__(self, tpu_map): + self._tpu_map = tpu_map + + def get(self, name): + return MockRequestClass(name, self._tpu_map) + + class TPUClusterResolverTest(test.TestCase): def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): @@ -56,11 +78,15 @@ class TPUClusterResolverTest(test.TestCase): if tpu_map is None: tpu_map = {} - def get_side_effect(name): - return tpu_map[name] + mock_locations = mock.MagicMock() + mock_locations.nodes.return_value = MockNodeClass(tpu_map) + + mock_project = mock.MagicMock() + mock_project.locations.return_value = mock_locations mock_client = mock.MagicMock() - mock_client.projects.locations.nodes.get.side_effect = get_side_effect + mock_client.projects.return_value = mock_project + return mock_client def testSimpleSuccessfulRetrieval(self): @@ -109,3 +135,38 @@ class TPUClusterResolverTest(test.TestCase): tasks { key: 1 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testGetMasterMultipleEntries(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + def testGetMasterNoEntries(self): + tpu_map = {} + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=[], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + with self.assertRaises(ValueError): + tpu_cluster_resolver.get_master() + +if __name__ == '__main__': + test.main() -- cgit v1.2.3