diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-02-17 15:58:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-02-17 16:10:08 -0800 |
commit | b06281ba47595d07f58766d6477f896e39eb8a5e (patch) | |
tree | 5dd68f1a1d35cfadcf8b792d217974c83b4ee5c3 /tensorflow/tools/dist_test | |
parent | a6421c4dda1a83ea975bae545df1de16d38726b0 (diff) |
Split out most of k8s_tensorflow into a library and add a way to pass any
environment variables. Add benchmark_util library that would use environment
variable to decide on a storage location.
Change: 147890534
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r-- | tensorflow/tools/dist_test/python/BUILD | 23 | ||||
-rw-r--r-- | tensorflow/tools/dist_test/python/benchmark_util.py | 77 | ||||
-rw-r--r-- | tensorflow/tools/dist_test/python/benchmark_util_test.py | 58 | ||||
-rw-r--r-- | tensorflow/tools/dist_test/scripts/BUILD | 21 | ||||
-rwxr-xr-x | tensorflow/tools/dist_test/scripts/k8s_tensorflow.py | 230 | ||||
-rw-r--r-- | tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py | 309 | ||||
-rw-r--r-- | tensorflow/tools/dist_test/scripts/k8s_tensorflow_test.py | 132 |
7 files changed, 631 insertions, 219 deletions
diff --git a/tensorflow/tools/dist_test/python/BUILD b/tensorflow/tools/dist_test/python/BUILD new file mode 100644 index 0000000000..a2927c5bfc --- /dev/null +++ b/tensorflow/tools/dist_test/python/BUILD @@ -0,0 +1,23 @@ +# Python tools for running distributed benchmarks. + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "benchmark_util_lib", + srcs = ["benchmark_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:platform", + ], +) + +py_test( + name = "benchmark_util_test", + srcs = ["benchmark_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":benchmark_util_lib", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/tools/dist_test/python/benchmark_util.py b/tensorflow/tools/dist_test/python/benchmark_util.py new file mode 100644 index 0000000000..5727908d6f --- /dev/null +++ b/tensorflow/tools/dist_test/python/benchmark_util.py @@ -0,0 +1,77 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Provides helper functions for distributed benchmarks running on Jenkins.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import calendar +from collections import namedtuple +import os + +from google.protobuf import json_format + +from tensorflow.core.util import test_log_pb2 +from tensorflow.python.platform import gfile + + +_OUTPUT_FILE_ENV_VAR = 'TF_DIST_BENCHMARK_RESULTS_FILE' +_TEST_NAME_ENV_VAR = 'TF_DIST_BENCHMARK_NAME' + + +# Represents a single timing entry where +# - name is a string +# - timing is the latency to track (for e.g. mean time per iter) +# - iters is the number of iterations +TimingEntry = namedtuple( + 'TimingEntry', ['name', 'timing', 'iters']) + + +def store_data_in_json(timing_entries, start_time, output_file=None): + """Stores benchmark results in JSON format. + + Args: + timing_entries: list of TimingEntry objects. + start_time: (datetime) start time of the test run. + output_file: if specified, writes benchmark results to output_file. + If not specified, writes results to the file specified by + BENCHMARK_RESULTS_FILE environment variable. + + Raises: + ValueError: when neither output_file is passed in nor + BENCHMARK_RESULTS_FILE is set. 
+ """ + test_result = test_log_pb2.TestResults( + start_time=calendar.timegm(start_time.timetuple())) + if not output_file: + if _OUTPUT_FILE_ENV_VAR not in os.environ: + raise ValueError('Could not determine location to store results at.') + output_file = os.environ[_OUTPUT_FILE_ENV_VAR] + + with gfile.Open(output_file, 'wb') as jsonfile: + if _TEST_NAME_ENV_VAR in os.environ: + test_result.name = os.environ['POD_NAME_PREFIX'] + else: + test_result.name = 'TestBenchmark' + + for timing_entry in timing_entries: + test_result.entries.entry.add( + name=timing_entry.name, + iters=timing_entry.iters, + wall_time=timing_entry.timing + ) + json_test_results = json_format.MessageToJson(test_result) + jsonfile.write(json_test_results) diff --git a/tensorflow/tools/dist_test/python/benchmark_util_test.py b/tensorflow/tools/dist_test/python/benchmark_util_test.py new file mode 100644 index 0000000000..e8a71f3153 --- /dev/null +++ b/tensorflow/tools/dist_test/python/benchmark_util_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.tools.dist_test.python.benchmark_util.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import datetime +import json +import os + +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.tools.dist_test.python import benchmark_util + + +class BenchmarkUtilTest(googletest.TestCase): + + def testStoreDataWithNoEntries(self): + output_file = os.path.join(test.get_temp_dir(), 'test_output1.json') + timing_entries = [] + benchmark_util.store_data_in_json( + timing_entries, datetime.date(2017, 1, 1), output_file) + json_output = json.loads(open(output_file, 'r').read()) + self.assertEquals('TestBenchmark', json_output['name']) + self.assertEquals(u'1483228800', json_output['startTime']) + + def testStoreDataWithEntries(self): + output_file = os.path.join(test.get_temp_dir(), 'test_output2.json') + timing_entries = [ + benchmark_util.TimingEntry('test', 0.1, 1)] + benchmark_util.store_data_in_json( + timing_entries, datetime.date(2017, 1, 1), output_file) + json_output = json.loads(open(output_file, 'r').read()) + + self.assertEquals(1, len(json_output['entries']['entry'])) + self.assertEquals('test', json_output['entries']['entry'][0]['name']) + self.assertEquals(0.1, json_output['entries']['entry'][0]['wallTime']) + self.assertEquals(u'1', json_output['entries']['entry'][0]['iters']) + self.assertEquals(u'1483228800', json_output['startTime']) + self.assertEquals('TestBenchmark', json_output['name']) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/dist_test/scripts/BUILD b/tensorflow/tools/dist_test/scripts/BUILD new file mode 100644 index 0000000000..abcd521a26 --- /dev/null +++ b/tensorflow/tools/dist_test/scripts/BUILD @@ -0,0 +1,21 @@ +# Tools for running distributed benchmarks. 
+ +licenses(["notice"]) # Apache 2.0 + +exports_files(["k8s_tensorflow.py"]) + +py_library( + name = "k8s_tensorflow_lib", + srcs = ["k8s_tensorflow_lib.py"], + srcs_version = "PY2AND3", +) + +py_test( + name = "k8s_tensorflow_test", + srcs = ["k8s_tensorflow_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":k8s_tensorflow_lib", + "//tensorflow/python:client_testlib", + ], +) diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py index 3693c177e2..b325f030e3 100755 --- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py @@ -25,6 +25,8 @@ from __future__ import print_function import argparse import sys +import k8s_tensorflow_lib + # Note: It is intentional that we do not import tensorflow in this script. The # machine that launches a TensorFlow k8s cluster does not have to have the # Python package of TensorFlow installed on it. @@ -33,125 +35,6 @@ import sys DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server' DEFAULT_PORT = 2222 -# TODO(cais): Consider adding resource requests/limits to the pods. - -# Worker pods will mount host volume /shared, as a convenient way to create -# shared storage among workers during local tests. 
-WORKER_RC = ( - """apiVersion: v1 -kind: ReplicationController -metadata: - name: {name_prefix}-worker{worker_id} -spec: - replicas: 1 - template: - metadata: - labels: - tf-worker: "{worker_id}" - name-prefix: "{name_prefix}" - job: "worker" - spec: - containers: - - name: tf-worker{worker_id} - image: {docker_image} - args: - - --cluster_spec={cluster_spec} - - --job_name=worker - - --task_id={worker_id} - ports: - - containerPort: {port} - env: - - name: POD_NAME_PREFIX - value: {name_prefix} - volumeMounts: [{volume_mounts}] - volumes: [{volumes}] -""") -WORKER_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: {name_prefix}-worker{worker_id} - labels: - tf-worker: "{worker_id}" -spec: - ports: - - port: {port} - targetPort: {port} - selector: - tf-worker: "{worker_id}" -""") -WORKER_LB_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: {name_prefix}-worker{worker_id} - labels: - tf-worker: "{worker_id}" -spec: - type: LoadBalancer - ports: - - port: {port} - selector: - tf-worker: "{worker_id}" -""") -PARAM_SERVER_RC = ( - """apiVersion: v1 -kind: ReplicationController -metadata: - name: {name_prefix}-ps{param_server_id} -spec: - replicas: 1 - template: - metadata: - labels: - tf-ps: "{param_server_id}" - name-prefix: "{name_prefix}" - job: "ps" - spec: - containers: - - name: tf-ps{param_server_id} - image: {docker_image} - args: - - --cluster_spec={cluster_spec} - - --job_name=ps - - --task_id={param_server_id} - ports: - - containerPort: {port} - env: - - name: POD_NAME_PREFIX - value: {name_prefix} - volumeMounts: [{volume_mounts}] - volumes: [{volumes}] -""") -PARAM_SERVER_SVC = ( - """apiVersion: v1 -kind: Service -metadata: - name: {name_prefix}-ps{param_server_id} - labels: - tf-ps: "{param_server_id}" -spec: - ports: - - port: {port} - selector: - tf-ps: "{param_server_id}" -""") -PARAM_LB_SVC = ("""apiVersion: v1 -kind: Service -metadata: - name: {name_prefix}-ps{param_server_id} - labels: - tf-ps: "{param_server_id}" -spec: - 
type: LoadBalancer - ports: - - port: {port} - selector: - tf-ps: "{param_server_id}" -""") -VOLUME_MOUNTS = '{name: shared, mountPath: /shared}' -VOLUMES = '{name: shared, hostPath: {path: /shared}}' - def main(): """Do arg parsing.""" @@ -204,108 +87,17 @@ def main(): sys.exit(1) # Generate contents of yaml config - yaml_config = GenerateConfig(args.num_workers, - args.num_parameter_servers, - args.grpc_port, - args.request_load_balancer, - args.docker_image, - args.name_prefix, - args.use_shared_volume) + yaml_config = k8s_tensorflow_lib.GenerateConfig( + args.num_workers, + args.num_parameter_servers, + args.grpc_port, + args.request_load_balancer, + args.docker_image, + args.name_prefix, + env_vars=None, + use_shared_volume=args.use_shared_volume) print(yaml_config) # pylint: disable=superfluous-parens -def GenerateConfig(num_workers, - num_param_servers, - port, - request_load_balancer, - docker_image, - name_prefix, - use_shared_volume): - """Generate configuration strings.""" - config = '' - for worker in range(num_workers): - config += WORKER_RC.format( - port=port, - worker_id=worker, - docker_image=docker_image, - name_prefix=name_prefix, - volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', - volumes=VOLUMES if use_shared_volume else '', - cluster_spec=WorkerClusterSpecString(num_workers, - num_param_servers, - port, - name_prefix)) - config += '---\n' - if request_load_balancer: - config += WORKER_LB_SVC.format(port=port, - worker_id=worker, - name_prefix=name_prefix) - else: - config += WORKER_SVC.format(port=port, - worker_id=worker, - name_prefix=name_prefix) - config += '---\n' - - for param_server in range(num_param_servers): - config += PARAM_SERVER_RC.format( - port=port, - param_server_id=param_server, - docker_image=docker_image, - name_prefix=name_prefix, - volume_mounts=VOLUME_MOUNTS if use_shared_volume else '', - volumes=VOLUMES if use_shared_volume else '', - cluster_spec=ParamServerClusterSpecString(num_workers, - 
num_param_servers, - port, - name_prefix)) - config += '---\n' - if request_load_balancer: - config += PARAM_LB_SVC.format( - port=port, param_server_id=param_server, name_prefix=name_prefix) - else: - config += PARAM_SERVER_SVC.format( - port=port, param_server_id=param_server, name_prefix=name_prefix) - config += '---\n' - - return config - - -def WorkerClusterSpecString(num_workers, - num_param_servers, - port, - name_prefix): - """Generates worker cluster spec.""" - return ClusterSpecString(num_workers, num_param_servers, port, name_prefix) - - -def ParamServerClusterSpecString(num_workers, - num_param_servers, - port, - name_prefix): - """Generates parameter server spec.""" - return ClusterSpecString(num_workers, num_param_servers, port, - name_prefix) - - -def ClusterSpecString(num_workers, - num_param_servers, - port, - name_prefix): - """Generates general cluster spec.""" - spec = 'worker|' - for worker in range(num_workers): - spec += '%s-worker%d:%d' % (name_prefix, worker, port) - if worker != num_workers-1: - spec += ';' - - spec += ',ps|' - for param_server in range(num_param_servers): - spec += '%s-ps%d:%d' % (name_prefix, param_server, port) - if param_server != num_param_servers-1: - spec += ';' - - return spec - - if __name__ == '__main__': main() diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py new file mode 100644 index 0000000000..8adbe387ba --- /dev/null +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow_lib.py @@ -0,0 +1,309 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Generates YAML configuration files for distributed TensorFlow workers. + +The workers will be run in a Kubernetes (k8s) container cluster. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Note: It is intentional that we do not import tensorflow in this script. The +# machine that launches a TensorFlow k8s cluster does not have to have the +# Python package of TensorFlow installed on it. + +# TODO(cais): Consider adding resource requests/limits to the pods. + +# Worker pods will mount host volume /shared, as a convenient way to create +# shared storage among workers during local tests. 
# NOTE(review): the YAML template indentation below was reconstructed from a
# whitespace-collapsed dump of this file — verify against the original source.
WORKER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-worker{worker_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-worker: "{worker_id}"
        name-prefix: "{name_prefix}"
        job: "worker"
    spec:
      containers:
      - name: tf-worker{worker_id}
        image: {docker_image}
        args: [{args}]
        ports:
        - containerPort: {port}
        env: [{env_vars}]
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
WORKER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  ports:
  - port: {port}
    targetPort: {port}
  selector:
    tf-worker: "{worker_id}"
""")
WORKER_LB_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-worker{worker_id}
  labels:
    tf-worker: "{worker_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-worker: "{worker_id}"
""")
PARAM_SERVER_RC = (
    """apiVersion: v1
kind: ReplicationController
metadata:
  name: {name_prefix}-ps{param_server_id}
spec:
  replicas: 1
  template:
    metadata:
      labels:
        tf-ps: "{param_server_id}"
        name-prefix: "{name_prefix}"
        job: "ps"
    spec:
      containers:
      - name: tf-ps{param_server_id}
        image: {docker_image}
        args: [{args}]
        ports:
        - containerPort: {port}
        env: [{env_vars}]
        volumeMounts: [{volume_mounts}]
      volumes: [{volumes}]
""")
PARAM_SERVER_SVC = (
    """apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
PARAM_LB_SVC = ("""apiVersion: v1
kind: Service
metadata:
  name: {name_prefix}-ps{param_server_id}
  labels:
    tf-ps: "{param_server_id}"
spec:
  type: LoadBalancer
  ports:
  - port: {port}
  selector:
    tf-ps: "{param_server_id}"
""")
VOLUME_MOUNTS = '{name: shared, mountPath: /shared}'
VOLUMES = '{name: shared, hostPath: {path: /shared}}'
_ENV_VAR_TEMPLATE = '{name: "%s", value: "%s"}'
_ARG_TEMPLATE = '"--%s=%s"'


def _FormatEnvEntries(env_vars):
  """Renders a {name: value} dict as a comma-joined k8s env list body."""
  return ', '.join([_ENV_VAR_TEMPLATE % (name, value)
                    for name, value in env_vars.items()])


def _FormatArgEntries(args):
  """Renders a {flag: value} dict as comma-joined "--flag=value" strings."""
  return ', '.join([_ARG_TEMPLATE % (name, value)
                    for name, value in args.items()])


def GenerateConfig(num_workers,
                   num_param_servers,
                   port,
                   request_load_balancer,
                   docker_image,
                   name_prefix,
                   env_vars=None,
                   use_shared_volume=True,
                   use_cluster_spec=True):
  """Generate configuration strings.

  Args:
    num_workers: number of worker jobs.
    num_param_servers: number of ps server jobs.
    port: GRPC server port.
    request_load_balancer: request worker0 to be exposed on a public IP
      address via an external load balancer.
    docker_image: docker image to use.
    name_prefix: name to prepend to pod job names.
    env_vars: dictionary of environment variables to set.
    use_shared_volume: whether to add hostPath to /shared directory
      to the kubernetes config.
    use_cluster_spec: if true, pass --cluster_spec to worker and ps jobs.
      If false, pass --worker_hosts and --ps_hosts to worker and ps jobs.

  Returns:
    Kubernetes yaml config.
  """
  if env_vars is None:
    env_vars = {}
  # env string is identical for every pod, so compute it once.
  env_str = _FormatEnvEntries(env_vars)
  config = ''
  common_args = GetCommonArgs(
      num_workers, num_param_servers, port, name_prefix, use_cluster_spec)
  for worker in range(num_workers):
    worker_args = {
        'job_name': 'worker',
        'task_id': worker
    }
    worker_args.update(common_args)
    config += WORKER_RC.format(
        port=port,
        worker_id=worker,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=_FormatArgEntries(worker_args),
        env_vars=env_str)
    config += '---\n'
    if request_load_balancer:
      config += WORKER_LB_SVC.format(port=port,
                                     worker_id=worker,
                                     name_prefix=name_prefix)
    else:
      config += WORKER_SVC.format(port=port,
                                  worker_id=worker,
                                  name_prefix=name_prefix)
    config += '---\n'

  for param_server in range(num_param_servers):
    ps_args = {
        'job_name': 'ps',
        'task_id': param_server
    }
    ps_args.update(common_args)
    config += PARAM_SERVER_RC.format(
        port=port,
        param_server_id=param_server,
        docker_image=docker_image,
        name_prefix=name_prefix,
        volume_mounts=VOLUME_MOUNTS if use_shared_volume else '',
        volumes=VOLUMES if use_shared_volume else '',
        args=_FormatArgEntries(ps_args),
        env_vars=env_str)
    config += '---\n'
    if request_load_balancer:
      config += PARAM_LB_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    else:
      config += PARAM_SERVER_SVC.format(
          port=port, param_server_id=param_server, name_prefix=name_prefix)
    config += '---\n'

  return config


def WorkerClusterSpecString(num_workers,
                            num_param_servers,
                            port,
                            name_prefix):
  """Generates worker cluster spec."""
  return ClusterSpecString(num_workers, num_param_servers, port, name_prefix)


def ParamServerClusterSpecString(num_workers,
                                 num_param_servers,
                                 port,
                                 name_prefix):
  """Generates parameter server spec."""
  return ClusterSpecString(num_workers, num_param_servers, port,
                           name_prefix)


def ClusterSpecString(num_workers,
                      num_param_servers,
                      port,
                      name_prefix):
  """Generates general cluster spec: 'worker|host;...,ps|host;...'."""
  # Joins made consistent with WorkerHosts/PsHosts instead of the manual
  # loop-with-trailing-separator-check the original used.
  workers = ';'.join('%s-worker%d:%d' % (name_prefix, worker, port)
                     for worker in range(num_workers))
  param_servers = ';'.join('%s-ps%d:%d' % (name_prefix, ps, port)
                           for ps in range(num_param_servers))
  return 'worker|' + workers + ',ps|' + param_servers


def GetCommonArgs(num_workers,
                  num_param_servers,
                  port,
                  name_prefix,
                  use_cluster_spec):
  """Get arguments common to both worker and ps jobs.

  Args:
    num_workers: number of workers.
    num_param_servers: number of ps servers.
    port: worker and ps port number.
    name_prefix: prefix to prepend to job names.
    use_cluster_spec: if true, pass --cluster_spec argument.
      If false, pass --worker_hosts and --ps_hosts arguments.

  Returns:
    A dictionary of argument names mapping to argument values.
  """
  common_args = {}
  if use_cluster_spec:
    common_args['cluster_spec'] = WorkerClusterSpecString(
        num_workers,
        num_param_servers,
        port,
        name_prefix)
  else:
    common_args['worker_hosts'] = WorkerHosts(num_workers, port, name_prefix)
    common_args['ps_hosts'] = PsHosts(num_param_servers, port, name_prefix)
  return common_args


def WorkerHosts(num_workers, port, name_prefix):
  """Comma-joined '<prefix>-worker<i>:<port>' list for all workers."""
  worker_hosts = ['%s-worker%d:%d' % (name_prefix, i, port)
                  for i in range(num_workers)]
  return ','.join(worker_hosts)


def PsHosts(num_ps, port, name_prefix):
  """Comma-joined '<prefix>-ps<i>:<port>' list for all ps servers."""
  ps_hosts = ['%s-ps%d:%d' % (name_prefix, i, port)
              for i in range(num_ps)]
  return ','.join(ps_hosts)
# ==============================================================================
"""Tests for tensorflow.tools.dist_test.scripts.k8s_tensorflow_lib."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import googletest
from tensorflow.tools.dist_test.scripts import k8s_tensorflow_lib


class K8sTensorflowTest(googletest.TestCase):
  """Black-box checks on the generated kubernetes yaml config strings."""

  def testGenerateConfig_LoadBalancer(self):
    # Use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    # assertIn/assertNotIn report the haystack on failure, unlike the
    # original assertTrue('x' in config).
    self.assertIn('LoadBalancer', config)

    # Don't use loadbalancer
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertNotIn('LoadBalancer', config)

  def testGenerateConfig_SharedVolume(self):
    # Use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=True)
    self.assertIn('/shared', config)

    # Don't use shared directory
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=False,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False)
    self.assertNotIn('/shared', config)

  def testEnvVar(self):
    # Each env_vars entry must surface as a k8s env list item.
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        env_vars={'test1': 'test1_value', 'test2': 'test2_value'})
    self.assertIn('{name: "test1", value: "test1_value"}', config)
    self.assertIn('{name: "test2", value: "test2_value"}', config)

  def testClusterSpec(self):
    # Use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=True)
    self.assertNotIn('worker_hosts', config)
    self.assertNotIn('ps_hosts', config)
    self.assertIn(
        '"--cluster_spec=worker|abc-worker0:5000,ps|abc-ps0:5000"', config)

    # Don't use cluster_spec
    config = k8s_tensorflow_lib.GenerateConfig(
        num_workers=1,
        num_param_servers=1,
        port=5000,
        request_load_balancer=True,
        docker_image='test_image',
        name_prefix='abc',
        use_shared_volume=False,
        use_cluster_spec=False)
    self.assertNotIn('cluster_spec', config)
    self.assertIn('"--worker_hosts=abc-worker0:5000"', config)
    self.assertIn('"--ps_hosts=abc-ps0:5000"', config)

  def testWorkerHosts(self):
    # assertEqual: assertEquals is a deprecated alias (removed in Py 3.12).
    self.assertEqual(
        'test_prefix-worker0:1234',
        k8s_tensorflow_lib.WorkerHosts(1, 1234, 'test_prefix'))
    self.assertEqual(
        'test_prefix-worker0:1234,test_prefix-worker1:1234',
        k8s_tensorflow_lib.WorkerHosts(2, 1234, 'test_prefix'))

  def testPsHosts(self):
    self.assertEqual(
        'test_prefix-ps0:1234,test_prefix-ps1:1234',
        k8s_tensorflow_lib.PsHosts(2, 1234, 'test_prefix'))


if __name__ == '__main__':
  googletest.main()