diff options
author | Brennan Saeta <saeta@google.com> | 2018-06-11 12:40:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-11 12:43:37 -0700 |
commit | 32c8013f0ab3feb139648ae759e2d0168fb5dc95 (patch) | |
tree | f5b2eb10f394a3118a0dbf5a8c3bd8a827a8811d /tensorflow/contrib/cluster_resolver | |
parent | 308fe20c728538112cb6ee3c051187977b88773b (diff) |
Check to ensure the Cloud TPU is ready before resolving.
PiperOrigin-RevId: 200095692
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py | 44 |
2 files changed, 48 insertions, 0 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a5a9630a4a..3a1d90e77d 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -256,6 +256,10 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() + if 'state' in response and response['state'] != 'READY': + raise RuntimeError('TPU "%s" is not yet ready; state: "%s"' % + (self._tpu, response['state'])) + if 'health' in response and response['health'] != 'HEALTHY': raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, response['health'])) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5fac55fd02..86e9d9ddad 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -158,6 +158,50 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testUnhealthyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testNotReadyCloudTpu(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'state': 'CREATING' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu='test-tpu-1', + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + with self.assertRaises(RuntimeError): + tpu_cluster_resolver.cluster_spec() + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { |