about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/cluster_resolver
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-06-12 11:34:51 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-06-12 11:37:16 -0700
commit: c5436b90adff058500e88b497fc4f7a0b0379d28 (patch)
tree: 58beb59a11492b3f01051c383357d92160125d1e /tensorflow/contrib/cluster_resolver
parent: dc7821ccf42ada3f85ca1c6e8228f0a42e61b93c (diff)
Support Cloud TPU Pod in GKE environment.
PiperOrigin-RevId: 200251004
Diffstat (limited to 'tensorflow/contrib/cluster_resolver')
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py 17
-rw-r--r-- tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py 54
2 files changed, 62 insertions, 9 deletions
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
index 3a1d90e77d..8f521ffee4 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py
@@ -36,6 +36,7 @@ except ImportError:
_GKE_ENV_VARIABLE = 'KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'
+_ENDPOINTS_SEPARATOR = ','
_DEFAULT_ENV_VARIABLE = 'TPU_NAME'
_DISCOVERY_SERVICE_URL_ENV_VARIABLE = 'TPU_API_DISCOVERY_URL'
@@ -69,8 +70,8 @@ class TPUClusterResolver(ClusterResolver):
return _GKE_ENV_VARIABLE in os.environ
@staticmethod
- def _gkeMaster():
- return os.environ[_GKE_ENV_VARIABLE].split(',')[0]
+ def _gkeEndpoints():
+ return os.environ[_GKE_ENV_VARIABLE]
@staticmethod
def _envVarFallback():
@@ -143,7 +144,7 @@ class TPUClusterResolver(ClusterResolver):
# When using GKE with Cloud TPUs, the env variable will be set.
if tpu is None:
if in_gke:
- tpu = self._gkeMaster()
+ tpu = self._gkeEndpoints()
else:
tpu = self._envVarFallback()
@@ -214,7 +215,7 @@ class TPUClusterResolver(ClusterResolver):
ValueError: If none of the TPUs specified exists.
"""
if not self._shouldResolve():
- return self._tpu
+ return self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))[0]
job_tasks = self.cluster_spec().job_tasks(self._job_name)
if not job_tasks:
@@ -280,8 +281,12 @@ class TPUClusterResolver(ClusterResolver):
# Case 3.
return None
# Case 2.
- cluster_spec = {self._job_name: [self._tpu[len(
- compat.as_bytes('grpc://')):]]}
+ cluster_spec = {
+ self._job_name: [
+ x[len(compat.as_bytes('grpc://')):]
+ for x in self._tpu.split(compat.as_bytes(_ENDPOINTS_SEPARATOR))
+ ]
+ }
if self._coordinator_address:
# {1, 2}.a
diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
index 86e9d9ddad..ad4f643263 100644
--- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
+++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py
@@ -402,13 +402,61 @@ class TPUClusterResolverTest(test.TestCase):
compat.as_bytes('/bns/foo/bar'), tpu_cluster_resolver.master())
self.assertEqual(None, tpu_cluster_resolver.cluster_spec())
- def testGkeEnvironment(self):
+ def testGkeEnvironmentForDonut(self):
os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = 'grpc://10.120.27.5:8470'
- self.assertTrue('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS' in os.environ)
+
+ self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
+ self.assertTrue(TPUClusterResolver._inGke())
+ self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470'),
+ compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
+
+ tpu_cluster_resolver = TPUClusterResolver()
+ self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470'),
+ compat.as_bytes(tpu_cluster_resolver.master()))
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.120.27.5:8470' }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
+ del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
+
+ def testGkeEnvironmentForPod(self):
+ os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS'] = ('grpc://10.120.27.5:8470,'
+ 'grpc://10.120.27.6:8470,'
+ 'grpc://10.120.27.7:8470,'
+ 'grpc://10.120.27.8:8470')
+
+ self.assertIn('KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS', os.environ)
self.assertTrue(TPUClusterResolver._inGke())
self.assertEqual(
+ compat.as_bytes('grpc://10.120.27.5:8470,'
+ 'grpc://10.120.27.6:8470,'
+ 'grpc://10.120.27.7:8470,'
+ 'grpc://10.120.27.8:8470'),
+ compat.as_bytes(TPUClusterResolver._gkeEndpoints()))
+
+ tpu_cluster_resolver = TPUClusterResolver()
+ self.assertEqual(
compat.as_bytes('grpc://10.120.27.5:8470'),
- compat.as_bytes(TPUClusterResolver._gkeMaster()))
+ compat.as_bytes(tpu_cluster_resolver.master()))
+ actual_cluster_spec = tpu_cluster_resolver.cluster_spec()
+ expected_proto = """
+ job {
+ name: 'worker'
+ tasks { key: 0 value: '10.120.27.5:8470' }
+ tasks { key: 1 value: '10.120.27.6:8470' }
+ tasks { key: 2 value: '10.120.27.7:8470' }
+ tasks { key: 3 value: '10.120.27.8:8470' }
+ }
+ """
+ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto)
+
del os.environ['KUBE_GOOGLE_CLOUD_TPU_ENDPOINTS']
def testDiscoveryUrl(self):