aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Frank Chen <frankchn@google.com>2017-11-02 16:32:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 16:36:31 -0700
commit3a8eaaf6a238e238a7adac9886b1569d7e43ae23 (patch)
tree59d49898ae364804d6c44b97c06c61e792264de2 /tensorflow/contrib/cluster_resolver
parentbae9ee3da5117d980677451b174115f750220408 (diff)
Add a new method `get_master` to `TPUClusterResolver` so that users can easily specify the grpc connection string using ClusterResolvers rather than specifying the IP address manually.
Also fixes a bug in `TPUClusterResolverTest` that prevented the tests from running at all.

PiperOrigin-RevId: 174398488
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py19
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py67
2 files changed, 83 insertions, 3 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index d76ddf8c65..f0144e9faa 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -90,6 +90,25 @@ class TPUClusterResolver(ClusterResolver):
else:
self._service = service
+ def get_master(self):
+ """Get the ClusterSpec grpc master path.
+
+ This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the
+ ClusterSpec returned by the cluster_spec function. This is suitable for use
+ for the `master` argument in tf.Session() when you are using one TPU.
+
+ Returns:
+ string, the grpc path of the first instance in the ClusterSpec.
+
+ Raises:
+ ValueError: If none of the TPUs specified exists.
+ """
+ job_tasks = self.cluster_spec().job_tasks(self._job_name)
+ if not job_tasks:
+ raise ValueError('No TPUs exists with the specified names exist.')
+
+ return 'grpc://' + job_tasks[0]
+
def cluster_spec(self):
"""Returns a ClusterSpec object based on the latest TPU information.
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 5bd5cd1a87..db7419be06 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -26,6 +26,28 @@ from tensorflow.python.training import server_lib
mock = test.mock
+class MockRequestClass(object):
+
+ def __init__(self, name, tpu_map):
+ self._name = name
+ self._tpu_map = tpu_map
+
+ def execute(self):
+ if self._name in self._tpu_map:
+ return self._tpu_map[self._name]
+ else:
+ raise KeyError('Resource %s was not found' % self._name)
+
+
+class MockNodeClass(object):
+
+ def __init__(self, tpu_map):
+ self._tpu_map = tpu_map
+
+ def get(self, name):
+ return MockRequestClass(name, self._tpu_map)
+
+
class TPUClusterResolverTest(test.TestCase):
def _verifyClusterSpecEquality(self, cluster_spec, expected_proto):
@@ -56,11 +78,15 @@ class TPUClusterResolverTest(test.TestCase):
if tpu_map is None:
tpu_map = {}
- def get_side_effect(name):
- return tpu_map[name]
+ mock_locations = mock.MagicMock()
+ mock_locations.nodes.return_value = MockNodeClass(tpu_map)
+
+ mock_project = mock.MagicMock()
+ mock_project.locations.return_value = mock_locations
mock_client = mock.MagicMock()
- mock_client.projects.locations.nodes.get.side_effect = get_side_effect
+ mock_client.projects.return_value = mock_project
+
return mock_client
def testSimpleSuccessfulRetrieval(self):
@@ -109,3 +135,38 @@ class TPUClusterResolverTest(test.TestCase):
tasks { key: 1 value: '10.1.2.3:8470' } }
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+ def testGetMasterMultipleEntries(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470'
+ },
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
+ 'ipAddress': '10.4.5.6',
+ 'port': '8470'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project='test-project',
+ zone='us-central1-c',
+ tpu_names=['test-tpu-2', 'test-tpu-1'],
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+ self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master())
+
+ def testGetMasterNoEntries(self):
+ tpu_map = {}
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project='test-project',
+ zone='us-central1-c',
+ tpu_names=[],
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+ with self.assertRaises(ValueError):
+ tpu_cluster_resolver.get_master()
+
+if __name__ == '__main__':
+ test.main()