about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
author: Brennan Saeta <saeta@google.com> 2018-06-11 12:40:54 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-11 12:43:37 -0700
commit: 32c8013f0ab3feb139648ae759e2d0168fb5dc95 (patch)
tree: f5b2eb10f394a3118a0dbf5a8c3bd8a827a8811d /tensorflow/contrib/cluster_resolver
parent: 308fe20c728538112cb6ee3c051187977b88773b (diff)
Check to ensure the Cloud TPU is ready before resolving.
PiperOrigin-RevId: 200095692
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 4
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 44
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index a5a9630a4a..3a1d90e77d 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -256,6 +256,10 @@ class TPUClusterResolver(ClusterResolver):
request = self._service.projects().locations().nodes().get(name=full_name)
response = request.execute()
+ if 'state' in response and response['state'] != 'READY':
+ raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' %
+ (self._tpu, response['state']))
+
if 'health' in response and response['health'] != 'HEALTHY':
raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
response['health']))
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 5fac55fd02..86e9d9ddad 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase):
"""
self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testUnhealthyCloudTpu(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470',
+ 'health': 'UNHEALTHY'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project=None,
+ zone=None,
+ tpu='test-tpu-1',
+ coordinator_name=None,
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ with self.assertRaises(RuntimeError):
+ tpu_cluster_resolver.cluster_spec()
+
+ @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata',
+ mock_request_compute_metadata)
+ def testNotReadyCloudTpu(self):
+ tpu_map = {
+ 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {
+ 'ipAddress': '10.1.2.3',
+ 'port': '8470',
+ 'state': 'CREATING'
+ }
+ }
+
+ tpu_cluster_resolver = TPUClusterResolver(
+ project=None,
+ zone=None,
+ tpu='test-tpu-1',
+ coordinator_name=None,
+ credentials=None,
+ service=self.mock_service_client(tpu_map=tpu_map))
+
+ with self.assertRaises(RuntimeError):
+ tpu_cluster_resolver.cluster_spec()
+
def testSimpleSuccessfulRetrieval(self):
tpu_map = {
'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': {