diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-09 18:20:53 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-09 18:27:10 -0800 |
commit | 9a88bff7589801727a1649a330fbb70321b04998 (patch) | |
tree | a5aea6fa514720e4e846e86a1cd199391fbc0f23 /tensorflow/tools/dist_test | |
parent | 1fbdd3ff8705566cbb203b7e0c926777289c36b3 (diff) |
A few changes for k8s_tensorflow.py.
Change: 147108270
Diffstat (limited to 'tensorflow/tools/dist_test')
-rwxr-xr-x | tensorflow/tools/dist_test/scripts/k8s_tensorflow.py | 80 |
1 files changed, 53 insertions, 27 deletions
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py index 854c6b832a..fc04a535dc 100755 --- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py @@ -41,13 +41,14 @@ WORKER_RC = ( """apiVersion: v1 kind: ReplicationController metadata: - name: tf-worker{worker_id} + name: {name_prefix}-worker{worker_id} spec: replicas: 1 template: metadata: labels: tf-worker: "{worker_id}" + name-prefix: "{name_prefix}" spec: containers: - name: tf-worker{worker_id} @@ -58,19 +59,17 @@ spec: - --task_id={worker_id} ports: - containerPort: {port} - volumeMounts: - - name: shared - mountPath: /shared - volumes: - - name: shared - hostPath: - path: /shared + env: + - name: POD_NAME_PREFIX + value: {name_prefix} + volumeMounts: [{volume_mounts}] + volumes: [{volumes}] """) WORKER_SVC = ( """apiVersion: v1 kind: Service metadata: - name: tf-worker{worker_id} + name: {name_prefix}-worker{worker_id} labels: tf-worker: "{worker_id}" spec: @@ -84,7 +83,7 @@ WORKER_LB_SVC = ( """apiVersion: v1 kind: Service metadata: - name: tf-worker{worker_id} + name: {name_prefix}-worker{worker_id} labels: tf-worker: "{worker_id}" spec: @@ -98,13 +97,14 @@ PARAM_SERVER_RC = ( """apiVersion: v1 kind: ReplicationController metadata: - name: tf-ps{param_server_id} + name: {name_prefix}-ps{param_server_id} spec: replicas: 1 template: metadata: labels: tf-ps: "{param_server_id}" + name-prefix: "{name_prefix}" spec: containers: - name: tf-ps{param_server_id} @@ -115,19 +115,17 @@ spec: - --task_id={param_server_id} ports: - containerPort: {port} - volumeMounts: - - name: shared - mountPath: /shared - volumes: - - name: shared - hostPath: - path: /shared + env: + - name: POD_NAME_PREFIX + value: {name_prefix} + volumeMounts: [{volume_mounts}] + volumes: [{volumes}] """) PARAM_SERVER_SVC = ( """apiVersion: v1 kind: Service metadata: - name: tf-ps{param_server_id} + name: {name_prefix}-ps{param_server_id} labels: tf-ps: "{param_server_id}" spec: @@ -139,7 +137,7 @@ spec: PARAM_LB_SVC = ("""apiVersion: v1 kind: Service metadata: - name: tf-ps{param_server_id} + name: {name_prefix}-ps{param_server_id} labels: tf-ps: "{param_server_id}" spec: @@ -149,11 +147,15 @@ spec: selector: tf-ps: "{param_server_id}" """) +VOLUME_MOUNTS = '{name: shared, mountPath: /shared}' +VOLUMES = '{name: shared, hostPath: {path: /shared}}' def main(): """Do arg parsing.""" parser = argparse.ArgumentParser() + parser.register( + 'type', 'bool', lambda v: v.lower() in ('true', 't', 'y', 'yes')) parser.add_argument('--num_workers', type=int, default=2, @@ -167,7 +169,7 @@ def main(): default=DEFAULT_PORT, help='GRPC server port (Default: %d)' % DEFAULT_PORT) parser.add_argument('--request_load_balancer', - type=bool, + type='bool', default=False, help='To request worker0 to be exposed on a public IP ' 'address via an external load balancer, enabling you to ' @@ -177,6 +179,16 @@ def main(): default=DEFAULT_DOCKER_IMAGE, help='Override default docker image for the TensorFlow ' 'GRPC server') + parser.add_argument('--name_prefix', + type=str, + default='tf', + help='Prefix for job names. Jobs will be named as ' + '<name_prefix>_worker|ps<task_id>') + parser.add_argument('--use_shared_volume', + type='bool', + default=True, + help='Whether to mount /shared directory from host to ' + 'the pod') args = parser.parse_args() if args.num_workers <= 0: @@ -194,7 +206,9 @@ def main(): args.num_parameter_servers, args.grpc_port, args.request_load_balancer, - args.docker_image) + args.docker_image, + args.name_prefix, + args.use_shared_volume) print(yaml_config) # pylint: disable=superfluous-parens @@ -202,7 +216,9 @@ def GenerateConfig(num_workers, num_param_servers, port, request_load_balancer, - docker_image): + docker_image, + name_prefix, + use_shared_volume): """Generate configuration strings.""" config = '' for worker in range(num_workers): @@ -210,16 +226,21 @@ def GenerateConfig(num_workers, port=port, worker_id=worker, docker_image=docker_image, + name_prefix=name_prefix, + volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', + volumes=VOLUMES if use_shared_volume else '', cluster_spec=WorkerClusterSpecString(num_workers, num_param_servers, port)) config += '---\n' if request_load_balancer: config += WORKER_LB_SVC.format(port=port, - worker_id=worker) + worker_id=worker, + name_prefix=name_prefix) else: config += WORKER_SVC.format(port=port, - worker_id=worker) + worker_id=worker, + name_prefix=name_prefix) config += '---\n' for param_server in range(num_param_servers): @@ -227,14 +248,19 @@ def GenerateConfig(num_workers, port=port, param_server_id=param_server, docker_image=docker_image, + name_prefix=name_prefix, + volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', + volumes=VOLUMES if use_shared_volume else '', cluster_spec=ParamServerClusterSpecString(num_workers, num_param_servers, port)) config += '---\n' if request_load_balancer: - config += PARAM_LB_SVC.format(port=port, param_server_id=param_server) + config += PARAM_LB_SVC.format( + port=port, param_server_id=param_server, name_prefix=name_prefix) else: - config += PARAM_SERVER_SVC.format(port=port, param_server_id=param_server) + config += PARAM_SERVER_SVC.format( + port=port, param_server_id=param_server, name_prefix=name_prefix) config += '---\n' return config |