aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
authorGravatar Brennan Saeta <saeta@google.com>2018-04-06 09:26:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-06 09:28:43 -0700
commit218647db25d1e754ad85fd1fa8a0960b82ae83bf (patch)
tree2e43dd2461814c84cc0db55662188a803dada27b /tensorflow/contrib/cluster_resolver
parentc5a16fa1c91a0d1cf3d5b432d70b4e8fe47b88cd (diff)
[TPUClusterResolver] Start a TFServer when running in GKE
This change allows advanced input pipelines (e.g. StreamingFilesDataset, or split-pipelines that use py_func's) to run in GKE- and GKE-like environments. PiperOrigin-RevId: 191897639
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py75
-rw-r--r--tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py8
2 files changed, 51 insertions, 32 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 300b19733e..a520a06bd7 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -61,11 +61,13 @@ class TPUClusterResolver(ClusterResolver):
return False
return True
- def _inGke(self):
+ @staticmethod
+ def _inGke():
"""When running in GKE, the environment variable will be set."""
return _GKE_ENV_VARIABLE in os.environ
- def _gkeMaster(self):
+ @staticmethod
+ def _gkeMaster():
return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
def __init__(self,
@@ -119,8 +121,9 @@ class TPUClusterResolver(ClusterResolver):
'Using multiple TPUs in a single session is not yet implemented')
tpu = tpu[0]
+ in_gke = self._inGke()
# When using GKE with Cloud TPUs, the env variable will be set.
- if tpu is None and self._inGke():
+ if tpu is None and in_gke:
tpu = self._gkeMaster()
self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes
@@ -158,7 +161,8 @@ class TPUClusterResolver(ClusterResolver):
self._service = service
self._coordinator_name = coordinator_name
- if coordinator_name and not coordinator_address and should_resolve:
+ if coordinator_name and not coordinator_address and (should_resolve or
+ in_gke):
self._start_local_server()
else:
self._coordinator_address = coordinator_address
@@ -204,31 +208,50 @@ class TPUClusterResolver(ClusterResolver):
Raises:
RuntimeError: If the provided TPU is not healthy.
"""
- if not self._shouldResolve():
- return server_lib.ClusterSpec({})
-
- full_name = 'projects/%s/locations/%s/nodes/%s' % (
- self._project, self._zone, compat.as_text(self._tpu))
- request = self._service.projects().locations().nodes().get(name=full_name)
- response = request.execute()
-
- if 'health' in response and response['health'] != 'HEALTHY':
- raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
- response['health']))
-
- if 'networkEndpoints' in response:
- worker_list = [
- '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
- for endpoint in response['networkEndpoints']
- ]
+ ############################################################################
+ # There are 5 potential cases this code must handle:
+ # 1. [Normal case.] We should resolve the TPU name to a set of tasks, and
+ # a. Create a ClusterSpec that includes the coordinator job
+ # b. Create a ClusterSpec without the coordinator job.
+ # 2. [GKE / No API Access.] We should not resolve the TPU name to a set of
+ # tasks and
+ # a. Create a ClusterSpec with the coordinator
+ # b. Create a ClusterSpec without the coordinator
+ # 3. [Other (legacy non-gRPC).] We should return an empty ClusterSpec.
+ ############################################################################
+
+ if self._shouldResolve():
+ # Case 1.
+ full_name = 'projects/%s/locations/%s/nodes/%s' % (
+ self._project, self._zone, compat.as_text(self._tpu))
+ request = self._service.projects().locations().nodes().get(name=full_name)
+ response = request.execute()
+
+ if 'health' in response and response['health'] != 'HEALTHY':
+ raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu,
+ response['health']))
+
+ if 'networkEndpoints' in response:
+ worker_list = [
+ '%s:%s' % (endpoint['ipAddress'], endpoint['port'])
+ for endpoint in response['networkEndpoints']
+ ]
+ else:
+ # Fall back to the deprecated response format
+ instance_url = '%s:%s' % (response['ipAddress'], response['port'])
+ worker_list = [instance_url]
+
+ cluster_spec = {self._job_name: worker_list}
else:
- # Fall back to the deprecated response format
- instance_url = '%s:%s' % (response['ipAddress'], response['port'])
- worker_list = [instance_url]
-
- cluster_spec = {self._job_name: worker_list}
+ if not self._tpu.startswith(compat.as_bytes('grpc://')):
+ # Case 3.
+ return server_lib.ClusterSpec({})
+ # Case 2.
+ cluster_spec = {self._job_name: [self._tpu[len(
+ compat.as_bytes('grpc://')):]]}
if self._coordinator_address:
+ # {1, 2}.a
cluster_spec[self._coordinator_name] = [self._coordinator_address]
return server_lib.ClusterSpec(cluster_spec)
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 48c3f6bb4f..cfddca1063 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -358,14 +358,10 @@ class TPUClusterResolverTest(test.TestCase):
def testGkeEnvironment(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
- tpu_cluster_resolver = TPUClusterResolver()
- self.assertTrue(tpu_cluster_resolver._inGke())
+ self.assertTrue(TPUClusterResolver._inGke())
self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(tpu_cluster_resolver._gkeMaster()))
- self.assertEqual(
- compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(tpu_cluster_resolver.get_master()))
+ compat.as_bytes(TPUClusterResolver._gkeMaster()))
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']