author    A. Unique TensorFlower <gardener@tensorflow.org>  2017-02-17 15:58:01 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2017-02-17 16:10:08 -0800
commit    b06281ba47595d07f58766d6477f896e39eb8a5e (patch)
tree      5dd68f1a1d35cfadcf8b792d217974c83b4ee5c3 /tensorflow/tools/dist_test
parent    a6421c4dda1a83ea975bae545df1de16d38726b0 (diff)
Split out most of k8s_tensorflow into a library and add a way to pass arbitrary
environment variables to the generated pods. Add a benchmark_util library that
uses an environment variable to decide on a storage location.
Change: 147890534
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--  tensorflow/tools/dist_test/python/BUILD                    |  23
-rw-r--r--  tensorflow/tools/dist_test/python/benchmark_util.py        |  77
-rw-r--r--  tensorflow/tools/dist_test/python/benchmark_util_test.py   |  58
-rw-r--r--  tensorflow/tools/dist_test/scripts/BUILD                   |  21
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/k8s_tensorflow.py       | 230
-rw-r--r--  tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py   | 309
-rw-r--r--  tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py  | 132
7 files changed, 631 insertions, 219 deletions
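
The heart of the change: k8s_tensorflow_lib.GenerateConfig now accepts an env_vars dictionary and injects it into every worker and ps pod, while the new benchmark_util reads TF_DIST_BENCHMARK_RESULTS_FILE (and TF_DIST_BENCHMARK_NAME) inside the pod to decide where results go. A minimal sketch of the generator side, not part of the diff; the image, prefix, and paths below are illustrative:

    from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib

    # Generate Kubernetes YAML for 2 workers and 1 ps, injecting the benchmark
    # environment variables into every pod (all values here are illustrative).
    yaml_config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=2,
        num_param_servers=1,
        port=2222,
        request_load_balancer=False,
        docker_image='tensorflow/tf_grpc_test_server',
        name_prefix='benchmark',
        env_vars={'TF_DIST_BENCHMARK_RESULTS_FILE': '/shared/results.json',
                  'TF_DIST_BENCHMARK_NAME': 'my_benchmark'},
        use_shared_volume=True)
    print(yaml_config)  # pipe into `kubectl create -f -`
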
diff --git a/tensorflow/tools/dist_test/python/BUILD b/tensorflow/tools/dist_test/python/BUILD
new file mode 100644
index 0000000000..a2927c5bfc
--- /dev/null
+++ b/tensorflow/tools/dist_test/python/BUILD
@@ -0,0 +1,23 @@
+# Python tools for running distributed benchmarks.
+
+licenses(["notice"]) # Apache 2.0
+
+py_library(
+ name = "benchmark_util_lib",
+ srcs = ["benchmark_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "benchmark_util_test",
+ srcs = ["benchmark_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":benchmark_util_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/tools/dist_test/python/benchmark_util.py b/tensorflow/tools/dist_test/python/benchmark_util.py
new file mode 100644
index 0000000000..5727908d6f
--- /dev/null
+++ b/tensorflow/tools/dist_test/python/benchmark_util.py
@@ -0,0 +1,77 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Provides helper functions for distributed benchmarks running on Jenkins."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import calendar
+from collections import namedtuple
+import os
+
+from google.protobuf import json_format
+
+from tensorflow.core.util import test_log_pb2
+from tensorflow.python.platform import gfile
+
+
+_OUTPUT_FILE_ENV_VAR = 'TF_DIST_BENCHMARK_RESULTS_FILE'
+_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME'
+
+
+# Represents a single timing entry where
+# - name is a string
+# - timing is the latency to track (e.g. mean time per iteration)
+# - iters is the number of iterations
+TimingEntry = namedtuple(
+ 'TimingEntry', ['name', 'timing', 'iters'])
+
+
+def store_data_in_json(timing_entries, start_time, output_file=None):
+ """Stores benchmark results in JSON format.
+
+ Args:
+ timing_entries: list of TimingEntry objects.
+ start_time: (datetime) start time of the test run.
+ output_file: if specified, writes benchmark results to output_file.
+ Otherwise, writes results to the file specified by the
+ TF_DIST_BENCHMARK_RESULTS_FILE environment variable.
+
+ Raises:
+ ValueError: when output_file is not passed in and the
+ TF_DIST_BENCHMARK_RESULTS_FILE environment variable is not set.
+ """
+ test_result = test_log_pb2.TestResults(
+ start_time=calendar.timegm(start_time.timetuple()))
+ if not output_file:
+ if _OUTPUT_FILE_ENV_VAR not in os.environ:
+ raise ValueError('Could not determine location to store results at.')
+ output_file = os.environ[_OUTPUT_FILE_ENV_VAR]
+
+ with gfile.Open(output_file, 'wb') as jsonfile:
+ if _TEST_NAME_ENV_VAR in os.environ:
+ test_result.name = os.environ[_TEST_NAME_ENV_VAR]
+ else:
+ test_result.name = 'TestBenchmark'
+
+ for timing_entry in timing_entries:
+ test_result.entries.entry.add(
+ name=timing_entry.name,
+ iters=timing_entry.iters,
+ wall_time=timing_entry.timing
+ )
+ json_test_results = json_format.MessageToJson(test_result)
+ jsonfile.write(json_test_results)
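
A short usage sketch of the helper above, with made-up timing numbers and an explicit output path (in a Jenkins/Kubernetes run the path would instead come from TF_DIST_BENCHMARK_RESULTS_FILE):

    import datetime

    from tensorflow.tools.dist_test.python import benchmark_util

    # Made-up timing entries: (name, wall time per iteration in seconds, iters).
    timing_entries = [
        benchmark_util.TimingEntry('train_step', 0.042, 1000),
        benchmark_util.TimingEntry('eval_step', 0.015, 100),
    ]
    benchmark_util.store_data_in_json(
        timing_entries,
        datetime.datetime.utcnow(),  # start time of the run
        output_file='/tmp/benchmark_results.json')
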
diff --git a/tensorflow/tools/dist_test/python/benchmark_util_test.py b/tensorflow/tools/dist_test/python/benchmark_util_test.py
new file mode 100644
index 0000000000..e8a71f3153
--- /dev/null
+++ b/tensorflow/tools/dist_test/python/benchmark_util_test.py
@@ -0,0 +1,58 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.tools.dist_test.python.benchmark_util."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import datetime
+import json
+import os
+
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+from tensorflow.tools.dist_test.python import benchmark_util
+
+
+class BenchmarkUtilTest(googletest.TestCase):
+
+ def testStoreDataWithNoEntries(self):
+ output_file = os.path.join(test.get_temp_dir(), 'test_output1.json')
+ timing_entries = []
+ benchmark_util.store_data_in_json(
+ timing_entries, datetime.date(2017, 1, 1), output_file)
+ json_output = json.loads(open(output_file, 'r').read())
+ self.assertEquals('TestBenchmark', json_output['name'])
+ self.assertEquals(u'1483228800', json_output['startTime'])
+
+ def testStoreDataWithEntries(self):
+ output_file = os.path.join(test.get_temp_dir(), 'test_output2.json')
+ timing_entries = [
+ benchmark_util.TimingEntry('test', 0.1, 1)]
+ benchmark_util.store_data_in_json(
+ timing_entries, datetime.date(2017, 1, 1), output_file)
+ json_output = json.loads(open(output_file, 'r').read())
+
+ self.assertEquals(1, len(json_output['entries']['entry']))
+ self.assertEquals('test', json_output['entries']['entry'][0]['name'])
+ self.assertEquals(0.1, json_output['entries']['entry'][0]['wallTime'])
+ self.assertEquals(u'1', json_output['entries']['entry'][0]['iters'])
+ self.assertEquals(u'1483228800', json_output['startTime'])
+ self.assertEquals('TestBenchmark', json_output['name'])
+
+
+if __name__ == '__main__':
+ googletest.main()
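
For reference, the dict the second test gets back from json.loads looks roughly like the following; proto3 JSON serialization renders int64 fields such as startTime and iters as strings, which is why the assertions above compare against u'1483228800' and u'1':

    json_output = {
        'name': 'TestBenchmark',
        'startTime': '1483228800',
        'entries': {
            'entry': [{'name': 'test', 'wallTime': 0.1, 'iters': '1'}],
        },
    }
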
diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD
new file mode 100644
index 0000000000..abcd521a26
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/BUILD
@@ -0,0 +1,21 @@
+# Tools for running distributed benchmarks.
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["k8s_tensorflow.py"])
+
+py_library(
+ name = "k8s_tensorflow_lib",
+ srcs = ["k8s_tensorflow_lib.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_test(
+ name = "k8s_tensorflow_test",
+ srcs = ["k8s_tensorflow_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":k8s_tensorflow_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
index 3693c177e2..b325f030e3 100755
--- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
@@ -25,6 +25,8 @@ from __future__ import print_function
import argparse
import sys
+import k8s_tensorflow_lib
+
# Note: It is intentional that we do not import tensorflow in this script. The
# machine that launches a TensorFlow k8s cluster does not have to have the
# Python package of TensorFlow installed on it.
@@ -33,125 +35,6 @@ import sys
DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
DEFAULT_PORT = 2222
-# TODO(cais): Consider adding resource requests/limits to the pods.
-
-# Worker pods will mount host volume /shared, as a convenient way to create
-# shared storage among workers during local tests.
-WORKER_RC = (
- """apiVersion: v1
-kind: ReplicationController
-metadata:
- name: {name_prefix}-worker{worker_id}
-spec:
- replicas: 1
- template:
- metadata:
- labels:
- tf-worker: "{worker_id}"
- name-prefix: "{name_prefix}"
- job: "worker"
- spec:
- containers:
- - name: tf-worker{worker_id}
- image: {docker_image}
- args:
- - --cluster_spec={cluster_spec}
- - --job_name=worker
- - --task_id={worker_id}
- ports:
- - containerPort: {port}
- env:
- - name: POD_NAME_PREFIX
- value: {name_prefix}
- volumeMounts: [{volume_mounts}]
- volumes: [{volumes}]
-""")
-WORKER_SVC = (
- """apiVersion: v1
-kind: Service
-metadata:
- name: {name_prefix}-worker{worker_id}
- labels:
- tf-worker: "{worker_id}"
-spec:
- ports:
- - port: {port}
- targetPort: {port}
- selector:
- tf-worker: "{worker_id}"
-""")
-WORKER_LB_SVC = (
- """apiVersion: v1
-kind: Service
-metadata:
- name: {name_prefix}-worker{worker_id}
- labels:
- tf-worker: "{worker_id}"
-spec:
- type: LoadBalancer
- ports:
- - port: {port}
- selector:
- tf-worker: "{worker_id}"
-""")
-PARAM_SERVER_RC = (
- """apiVersion: v1
-kind: ReplicationController
-metadata:
- name: {name_prefix}-ps{param_server_id}
-spec:
- replicas: 1
- template:
- metadata:
- labels:
- tf-ps: "{param_server_id}"
- name-prefix: "{name_prefix}"
- job: "ps"
- spec:
- containers:
- - name: tf-ps{param_server_id}
- image: {docker_image}
- args:
- - --cluster_spec={cluster_spec}
- - --job_name=ps
- - --task_id={param_server_id}
- ports:
- - containerPort: {port}
- env:
- - name: POD_NAME_PREFIX
- value: {name_prefix}
- volumeMounts: [{volume_mounts}]
- volumes: [{volumes}]
-""")
-PARAM_SERVER_SVC = (
- """apiVersion: v1
-kind: Service
-metadata:
- name: {name_prefix}-ps{param_server_id}
- labels:
- tf-ps: "{param_server_id}"
-spec:
- ports:
- - port: {port}
- selector:
- tf-ps: "{param_server_id}"
-""")
-PARAM_LB_SVC = ("""apiVersion: v1
-kind: Service
-metadata:
- name: {name_prefix}-ps{param_server_id}
- labels:
- tf-ps: "{param_server_id}"
-spec:
- type: LoadBalancer
- ports:
- - port: {port}
- selector:
- tf-ps: "{param_server_id}"
-""")
-VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
-VOLUMES = '{name: shared, hostPath: {path: /shared}}'
-
def main():
"""Do arg parsing."""
@@ -204,108 +87,17 @@ def main():
sys.exit(1)
# Generate contents of yaml config
- yaml_config = GenerateConfig(args.num_workers,
- args.num_parameter_servers,
- args.grpc_port,
- args.request_load_balancer,
- args.docker_image,
- args.name_prefix,
- args.use_shared_volume)
+ yaml_config = k8s_tensorflow_lib.GenerateConfig(
+ args.num_workers,
+ args.num_parameter_servers,
+ args.grpc_port,
+ args.request_load_balancer,
+ args.docker_image,
+ args.name_prefix,
+ env_vars=None,
+ use_shared_volume=args.use_shared_volume)
print(yaml_config) # pylint: disable=superfluous-parens
-def GenerateConfig(num_workers,
- num_param_servers,
- port,
- request_load_balancer,
- docker_image,
- name_prefix,
- use_shared_volume):
- """Generate configuration strings."""
- config = ''
- for worker in range(num_workers):
- config += WORKER_RC.format(
- port=port,
- worker_id=worker,
- docker_image=docker_image,
- name_prefix=name_prefix,
- volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
- volumes=VOLUMES if use_shared_volume else '',
- cluster_spec=WorkerClusterSpecString(num_workers,
- num_param_servers,
- port,
- name_prefix))
- config += '---\n'
- if request_load_balancer:
- config += WORKER_LB_SVC.format(port=port,
- worker_id=worker,
- name_prefix=name_prefix)
- else:
- config += WORKER_SVC.format(port=port,
- worker_id=worker,
- name_prefix=name_prefix)
- config += '---\n'
-
- for param_server in range(num_param_servers):
- config += PARAM_SERVER_RC.format(
- port=port,
- param_server_id=param_server,
- docker_image=docker_image,
- name_prefix=name_prefix,
- volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
- volumes=VOLUMES if use_shared_volume else '',
- cluster_spec=ParamServerClusterSpecString(num_workers,
- num_param_servers,
- port,
- name_prefix))
- config += '---\n'
- if request_load_balancer:
- config += PARAM_LB_SVC.format(
- port=port, param_server_id=param_server, name_prefix=name_prefix)
- else:
- config += PARAM_SERVER_SVC.format(
- port=port, param_server_id=param_server, name_prefix=name_prefix)
- config += '---\n'
-
- return config
-
-
-def WorkerClusterSpecString(num_workers,
- num_param_servers,
- port,
- name_prefix):
- """Generates worker cluster spec."""
- return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
-
-
-def ParamServerClusterSpecString(num_workers,
- num_param_servers,
- port,
- name_prefix):
- """Generates parameter server spec."""
- return ClusterSpecString(num_workers, num_param_servers, port,
- name_prefix)
-
-
-def ClusterSpecString(num_workers,
- num_param_servers,
- port,
- name_prefix):
- """Generates general cluster spec."""
- spec = 'worker|'
- for worker in range(num_workers):
- spec += '%s-worker%d:%d' % (name_prefix, worker, port)
- if worker != num_workers-1:
- spec += ';'
-
- spec += ',ps|'
- for param_server in range(num_param_servers):
- spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
- if param_server != num_param_servers-1:
- spec += ';'
-
- return spec
-
-
if __name__ == '__main__':
main()
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py
new file mode 100644
index 0000000000..8adbe387ba
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py
@@ -0,0 +1,309 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generates YAML configuration files for distributed TensorFlow workers.
+
+The workers will be run in a Kubernetes (k8s) container cluster.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# Note: It is intentional that we do not import tensorflow in this script. The
+# machine that launches a TensorFlow k8s cluster does not have to have the
+# Python package of TensorFlow installed on it.
+
+# TODO(cais): Consider adding resource requests/limits to the pods.
+
+# Worker pods will mount host volume /shared, as a convenient way to create
+# shared storage among workers during local tests.
+WORKER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: {name_prefix}-worker{worker_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-worker: "{worker_id}"
+ name-prefix: "{name_prefix}"
+ job: "worker"
+ spec:
+ containers:
+ - name: tf-worker{worker_id}
+ image: {docker_image}
+ args: [{args}]
+ ports:
+ - containerPort: {port}
+ env: [{env_vars}]
+ volumeMounts: [{volume_mounts}]
+ volumes: [{volumes}]
+""")
+WORKER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: {name_prefix}-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ ports:
+ - port: {port}
+ targetPort: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+WORKER_LB_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: {name_prefix}-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ type: LoadBalancer
+ ports:
+ - port: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+PARAM_SERVER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: {name_prefix}-ps{param_server_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-ps: "{param_server_id}"
+ name-prefix: "{name_prefix}"
+ job: "ps"
+ spec:
+ containers:
+ - name: tf-ps{param_server_id}
+ image: {docker_image}
+ args: [{args}]
+ ports:
+ - containerPort: {port}
+ env: [{env_vars}]
+ volumeMounts: [{volume_mounts}]
+ volumes: [{volumes}]
+""")
+PARAM_SERVER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: {name_prefix}-ps{param_server_id}
+ labels:
+ tf-ps: "{param_server_id}"
+spec:
+ ports:
+ - port: {port}
+ selector:
+ tf-ps: "{param_server_id}"
+""")
+PARAM_LB_SVC = ("""apiVersion: v1
+kind: Service
+metadata:
+ name: {name_prefix}-ps{param_server_id}
+ labels:
+ tf-ps: "{param_server_id}"
+spec:
+ type: LoadBalancer
+ ports:
+ - port: {port}
+ selector:
+ tf-ps: "{param_server_id}"
+""")
+VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
+VOLUMES = '{name: shared, hostPath: {path: /shared}}'
+_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}'
+_ARG_TEMPLATE = '"--%s=%s"'
+
+
+def GenerateConfig(num_workers,
+ num_param_servers,
+ port,
+ request_load_balancer,
+ docker_image,
+ name_prefix,
+ env_vars=None,
+ use_shared_volume=True,
+ use_cluster_spec=True):
+ """Generate configuration strings.
+
+ Args:
+ num_workers: number of worker jobs.
+ num_param_servers: number of ps server jobs.
+ port: GRPC server port.
+ request_load_balancer: whether to expose the worker and ps services on
+ public IP addresses via external load balancers.
+ docker_image: docker image to use.
+ name_prefix: name to prepend to pod job names.
+ env_vars: dictionary of environment variables to set.
+ use_shared_volume: whether to add a hostPath volume for the /shared
+ directory to the Kubernetes config.
+ use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs.
+ If false, pass --worker_hosts and --ps_hosts to worker and ps jobs.
+
+ Returns:
+ Kubernetes yaml config.
+ """
+ if env_vars is None:
+ env_vars = {}
+ env_str = ', '.join([_ENV_VAR_TEMPLATE % (name, value)
+ for name, value in env_vars.items()])
+ config = ''
+ common_args = GetCommonArgs(
+ num_workers, num_param_servers, port, name_prefix, use_cluster_spec)
+ for worker in range(num_workers):
+ worker_args = {
+ 'job_name': 'worker',
+ 'task_id': worker
+ }
+ worker_args.update(common_args)
+ arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
+ for name, value in worker_args.items()])
+ config += WORKER_RC.format(
+ port=port,
+ worker_id=worker,
+ docker_image=docker_image,
+ name_prefix=name_prefix,
+ volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
+ volumes=VOLUMES if use_shared_volume else '',
+ args=arg_str,
+ env_vars=env_str)
+ config += '---\n'
+ if request_load_balancer:
+ config += WORKER_LB_SVC.format(port=port,
+ worker_id=worker,
+ name_prefix=name_prefix)
+ else:
+ config += WORKER_SVC.format(port=port,
+ worker_id=worker,
+ name_prefix=name_prefix)
+ config += '---\n'
+
+ for param_server in range(num_param_servers):
+ ps_args = {
+ 'job_name': 'ps',
+ 'task_id': param_server
+ }
+ ps_args.update(common_args)
+ arg_str = ', '.join([_ARG_TEMPLATE % (name, value)
+ for name, value in ps_args.items()])
+ config += PARAM_SERVER_RC.format(
+ port=port,
+ param_server_id=param_server,
+ docker_image=docker_image,
+ name_prefix=name_prefix,
+ volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
+ volumes=VOLUMES if use_shared_volume else '',
+ args=arg_str,
+ env_vars=env_str)
+ config += '---\n'
+ if request_load_balancer:
+ config += PARAM_LB_SVC.format(
+ port=port, param_server_id=param_server, name_prefix=name_prefix)
+ else:
+ config += PARAM_SERVER_SVC.format(
+ port=port, param_server_id=param_server, name_prefix=name_prefix)
+ config += '---\n'
+
+ return config
+
+
+def WorkerClusterSpecString(num_workers,
+ num_param_servers,
+ port,
+ name_prefix):
+ """Generates worker cluster spec."""
+ return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)
+
+
+def ParamServerClusterSpecString(num_workers,
+ num_param_servers,
+ port,
+ name_prefix):
+ """Generates parameter server spec."""
+ return ClusterSpecString(num_workers, num_param_servers, port,
+ name_prefix)
+
+
+def ClusterSpecString(num_workers,
+ num_param_servers,
+ port,
+ name_prefix):
+ """Generates general cluster spec."""
+ spec = 'worker|'
+ for worker in range(num_workers):
+ spec += '%s-worker%d:%d' % (name_prefix, worker, port)
+ if worker != num_workers-1:
+ spec += ';'
+
+ spec += ',ps|'
+ for param_server in range(num_param_servers):
+ spec += '%s-ps%d:%d' % (name_prefix, param_server, port)
+ if param_server != num_param_servers-1:
+ spec += ';'
+
+ return spec
+
+
+def GetCommonArgs(num_workers,
+ num_param_servers,
+ port,
+ name_prefix,
+ use_cluster_spec):
+ """Get arguments common to both worker and ps jobs.
+
+ Args:
+ num_workers: number of workers.
+ num_param_servers: number of ps servers.
+ port: worker and ps port number.
+ name_prefix: prefix to prepend to job names.
+ use_cluster_spec: if true, pass the --cluster_spec argument.
+ If false, pass the --worker_hosts and --ps_hosts arguments.
+
+ Returns:
+ A dictionary of argument names mapping to argument values.
+ """
+ common_args = {}
+ if use_cluster_spec:
+ common_args['cluster_spec'] = WorkerClusterSpecString(
+ num_workers,
+ num_param_servers,
+ port,
+ name_prefix)
+ else:
+ common_args['worker_hosts'] = WorkerHosts(num_workers, port, name_prefix)
+ common_args['ps_hosts'] = PsHosts(num_param_servers, port, name_prefix)
+ return common_args
+
+
+def WorkerHosts(num_workers, port, name_prefix):
+ """Returns a comma-separated string of worker host:port pairs."""
+ worker_hosts = ['%s-worker%d:%d' % (name_prefix, i, port)
+ for i in range(num_workers)]
+ return ','.join(worker_hosts)
+
+
+def PsHosts(num_ps, port, name_prefix):
+ """Returns a comma-separated string of ps host:port pairs."""
+ ps_hosts = ['%s-ps%d:%d' % (name_prefix, i, port)
+ for i in range(num_ps)]
+ return ','.join(ps_hosts)
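
The host-list helpers at the bottom of the file are pure string builders; an illustrative check of what they return:

    from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib

    print(k8s_tensorflow_lib.WorkerClusterSpecString(2, 1, 2222, 'tf'))
    # worker|tf-worker0:2222;tf-worker1:2222,ps|tf-ps0:2222
    print(k8s_tensorflow_lib.WorkerHosts(2, 2222, 'tf'))
    # tf-worker0:2222,tf-worker1:2222
    print(k8s_tensorflow_lib.PsHosts(1, 2222, 'tf'))
    # tf-ps0:2222
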
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py
new file mode 100644
index 0000000000..7d9b3f83f5
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py
@@ -0,0 +1,132 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.platform import googletest
+from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib
+
+
+class K8sTensorflowTest(googletest.TestCase):
+
+ def testGenerateConfig_LoadBalancer(self):
+ # Use loadbalancer
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=True,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False)
+ self.assertTrue('LoadBalancer' in config)
+
+ # Don't use loadbalancer
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=False,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False)
+ self.assertFalse('LoadBalancer' in config)
+
+ def testGenerateConfig_SharedVolume(self):
+ # Use shared directory
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=False,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=True)
+ self.assertTrue('/shared' in config)
+
+ # Don't use shared directory
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=False,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False)
+ self.assertFalse('/shared' in config)
+
+ def testEnvVar(self):
+ # Pass custom environment variables to the pods.
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=True,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False,
+ env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
+ self.assertTrue('{name: "test1", value: "test1_value"}' in config)
+ self.assertTrue('{name: "test2", value: "test2_value"}' in config)
+
+ def testClusterSpec(self):
+ # Use cluster_spec
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=True,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False,
+ use_cluster_spec=True)
+ self.assertFalse('worker_hosts' in config)
+ self.assertFalse('ps_hosts' in config)
+ self.assertTrue(
+ '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"' in config)
+
+ # Don't use cluster_spec
+ config = k8s_tensorflow_lib.GenerateConfig(
+ num_workers=1,
+ num_param_servers=1,
+ port=5000,
+ request_load_balancer=True,
+ docker_image='test_image',
+ name_prefix='abc',
+ use_shared_volume=False,
+ use_cluster_spec=False)
+ self.assertFalse('cluster_spec' in config)
+ self.assertTrue('"--worker_hosts=abc-worker0:5000"' in config)
+ self.assertTrue('"--ps_hosts=abc-ps0:5000"' in config)
+
+ def testWorkerHosts(self):
+ self.assertEquals(
+ 'test_prefix-worker0:1234',
+ k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
+ self.assertEquals(
+ 'test_prefix-worker0:1234,test_prefix-worker1:1234',
+ k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))
+
+ def testPsHosts(self):
+ self.assertEquals(
+ 'test_prefix-ps0:1234,test_prefix-ps1:1234',
+ k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))
+
+
+if __name__ == '__main__':
+ googletest.main()