about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-02-02 16:18:59 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-02-02 16:23:11 -0800
commit 23f0529f52f2ae9615e465f804e5622cad4aeb8f (patch)
tree 51602d2b2df031ee62580d3edb514294ac4fce02 /tensorflow/contrib/cluster_resolver
parent 048b19f11e1ba1fa76c0fd508fed9007c852e8df (diff)
Tpu cluster resolver only returns TF server addresses for 'HEALTHY' tpu nodes.
PiperOrigin-RevId: 184350480
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 5
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 55
2 files changed, 52 insertions, 8 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 2e75ac226e..a6a6e642e4 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -143,7 +143,8 @@ class TPUClusterResolver(ClusterResolver):
request = self._service.projects().locations().nodes().get(name=full_name)
response = request.execute()
- instance_url = '%s:%s' % (response['ipAddress'], response['port'])
- worker_list.append(instance_url)
+ if 'health' in response and response['health'] == 'HEALTHY':
+ instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+ worker_list.append(instance_url)
return ClusterSpec({self._job_name: worker_list})
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 0c4730613a..4fd34629cf 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -105,7 +105,8 @@ class TPUClusterResolverTest(test.TestCase):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
}
}
@@ -126,7 +127,8 @@ class TPUClusterResolverTest(test.TestCase):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
}
}
@@ -147,11 +149,13 @@ class TPUClusterResolverTest(test.TestCase):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
},
'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
'ipAddress': '10.4.5.6',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
}
}
@@ -169,15 +173,54 @@ class TPUClusterResolverTest(test.TestCase):
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+ def testHealthyTpuNodeRetrieval(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470',
+ 'health': 'HEALTHY'
+ },
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
+ 'ipAddress': '10.4.5.6',
+ 'port': '8470',
+ },
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': {
+ 'ipAddress': '10.7.8.9',
+ 'port': '8470',
+ 'health': 'UNHEALTHY'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project='test-project',
+ zone='us-central1-c',
+ tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'],
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'tpu_worker'
+ tasks {
+ key: 0
+ value: '10.1.2.3:8470'
+ }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
def testGetMasterMultipleEntries(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
'ipAddress': '10.1.2.3',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
},
'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': {
'ipAddress': '10.4.5.6',
- 'port': '8470'
+ 'port': '8470',
+ 'health': 'HEALTHY'
}
}