about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-02-09 18:20:53 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-02-09 18:27:10 -0800
commit: 9a88bff7589801727a1649a330fbb70321b04998 (patch)
tree: a5aea6fa514720e4e846e86a1cd199391fbc0f23 /tensorflow/tools/dist_test
parent: 1fbdd3ff8705566cbb203b7e0c926777289c36b3 (diff)
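A few changes for k8s_tensorflow.py: add a --name_prefix flag that prefixes the replication controller and service names (and is exposed to pods as the POD_NAME_PREFIX environment variable and a name-prefix label), add a --use_shared_volume flag that makes the /shared host mount optional, and parse boolean flags with a registered 'bool' type instead of the built-in bool.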
Change: 147108270
Diffstat (limited to 'tensorflow/tools/dist_test')
-rwxr-xr-x tensorflow/tools/dist_test/scripts/k8s_tensorflow.py 80
1 files changed, 53 insertions, 27 deletions
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
index 854c6b832a..fc04a535dc 100755
--- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
@@ -41,13 +41,14 @@ WORKER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
- name: tf-worker{worker_id}
+ name: {name_prefix}-worker{worker_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-worker: "{worker_id}"
+ name-prefix: "{name_prefix}"
spec:
containers:
- name: tf-worker{worker_id}
@@ -58,19 +59,17 @@ spec:
- --task_id={worker_id}
ports:
- containerPort: {port}
- volumeMounts:
- - name: shared
- mountPath: /shared
- volumes:
- - name: shared
- hostPath:
- path: /shared
+ env:
+ - name: POD_NAME_PREFIX
+ value: {name_prefix}
+ volumeMounts: [{volume_mounts}]
+ volumes: [{volumes}]
""")
WORKER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
- name: tf-worker{worker_id}
+ name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
@@ -84,7 +83,7 @@ WORKER_LB_SVC = (
"""apiVersion: v1
kind: Service
metadata:
- name: tf-worker{worker_id}
+ name: {name_prefix}-worker{worker_id}
labels:
tf-worker: "{worker_id}"
spec:
@@ -98,13 +97,14 @@ PARAM_SERVER_RC = (
"""apiVersion: v1
kind: ReplicationController
metadata:
- name: tf-ps{param_server_id}
+ name: {name_prefix}-ps{param_server_id}
spec:
replicas: 1
template:
metadata:
labels:
tf-ps: "{param_server_id}"
+ name-prefix: "{name_prefix}"
spec:
containers:
- name: tf-ps{param_server_id}
@@ -115,19 +115,17 @@ spec:
- --task_id={param_server_id}
ports:
- containerPort: {port}
- volumeMounts:
- - name: shared
- mountPath: /shared
- volumes:
- - name: shared
- hostPath:
- path: /shared
+ env:
+ - name: POD_NAME_PREFIX
+ value: {name_prefix}
+ volumeMounts: [{volume_mounts}]
+ volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
"""apiVersion: v1
kind: Service
metadata:
- name: tf-ps{param_server_id}
+ name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
@@ -139,7 +137,7 @@ spec:
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
- name: tf-ps{param_server_id}
+ name: {name_prefix}-ps{param_server_id}
labels:
tf-ps: "{param_server_id}"
spec:
@@ -149,11 +147,15 @@ spec:
selector:
tf-ps: "{param_server_id}"
""")
+VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
+VOLUMES = '{name: shared, hostPath: {path: /shared}}'
def main():
"""Do arg parsing."""
parser = argparse.ArgumentParser()
+ parser.register(
+ 'type', 'bool', lambda v: v.lower() in ('true', 't', 'y', 'yes'))
parser.add_argument('--num_workers',
type=int,
default=2,
@@ -167,7 +169,7 @@ def main():
default=DEFAULT_PORT,
help='GRPC server port (Default: %d)' % DEFAULT_PORT)
parser.add_argument('--request_load_balancer',
- type=bool,
+ type='bool',
default=False,
help='To request worker0 to be exposed on a public IP '
'address via an external load balancer, enabling you to '
@@ -177,6 +179,16 @@ def main():
default=DEFAULT_DOCKER_IMAGE,
help='Override default docker image for the TensorFlow '
'GRPC server')
+ parser.add_argument('--name_prefix',
+ type=str,
+ default='tf',
+ help='Prefix for job names. Jobs will be named as '
+ '<name_prefix>_worker|ps<task_id>')
+ parser.add_argument('--use_shared_volume',
+ type='bool',
+ default=True,
+ help='Whether to mount /shared directory from host to '
+ 'the pod')
args = parser.parse_args()
if args.num_workers <= 0:
@@ -194,7 +206,9 @@ def main():
args.num_parameter_servers,
args.grpc_port,
args.request_load_balancer,
- args.docker_image)
+ args.docker_image,
+ args.name_prefix,
+ args.use_shared_volume)
print(yaml_config) # pylint: disable=superfluous-parens
@@ -202,7 +216,9 @@ def GenerateConfig(num_workers,
num_param_servers,
port,
request_load_balancer,
- docker_image):
+ docker_image,
+ name_prefix,
+ use_shared_volume):
"""Generate configuration strings."""
config = ''
for worker in range(num_workers):
@@ -210,16 +226,21 @@ def GenerateConfig(num_workers,
port=port,
worker_id=worker,
docker_image=docker_image,
+ name_prefix=name_prefix,
+ volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
+ volumes=VOLUMES if use_shared_volume else '',
cluster_spec=WorkerClusterSpecString(num_workers,
num_param_servers,
port))
config += '---\n'
if request_load_balancer:
config += WORKER_LB_SVC.format(port=port,
- worker_id=worker)
+ worker_id=worker,
+ name_prefix=name_prefix)
else:
config += WORKER_SVC.format(port=port,
- worker_id=worker)
+ worker_id=worker,
+ name_prefix=name_prefix)
config += '---\n'
for param_server in range(num_param_servers):
@@ -227,14 +248,19 @@ def GenerateConfig(num_workers,
port=port,
param_server_id=param_server,
docker_image=docker_image,
+ name_prefix=name_prefix,
+ volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
+ volumes=VOLUMES if use_shared_volume else '',
cluster_spec=ParamServerClusterSpecString(num_workers,
num_param_servers,
port))
config += '---\n'
if request_load_balancer:
- config += PARAM_LB_SVC.format(port=port, param_server_id=param_server)
+ config += PARAM_LB_SVC.format(
+ port=port, param_server_id=param_server, name_prefix=name_prefix)
else:
- config += PARAM_SERVER_SVC.format(port=port, param_server_id=param_server)
+ config += PARAM_SERVER_SVC.format(
+ port=port, param_server_id=param_server, name_prefix=name_prefix)
config += '---\n'
return config