From 65b010308c2ab3f365b5b9b40dd56591b179b996 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 14 Sep 2016 19:00:00 -0800 Subject: Update & fix OSS distributed TF tests: mnist_replica 1) Replace the old and breaking docker-in-docker local test with a single-instance, multi-process test, built upon GitHub PR https://github.com/tensorflow/tensorflow/pull/3935 This simplifies the local test and makes it less susceptible to future changes in docker-in-docker support by docker. 2) Adding --existing_servers flag to mnist_replica.py and associated bash scripts, so that we can distinguish a) the case in which we want to create in-process servers and supervisors (as in the new local_test.sh), and b) the case in which GRPC TF servers are already created and we just want to connect to the workers (as in remote_test.sh). 3) Rename some flags in bash script to improve consistency with the mnist_replica.py. 4) Related doc changes in README.md. Change: 133209130 --- tensorflow/tools/dist_test/Dockerfile.local | 53 ++-- tensorflow/tools/dist_test/README.md | 31 +-- tensorflow/tools/dist_test/local_test.sh | 137 ++-------- tensorflow/tools/dist_test/python/mnist_replica.py | 289 +++++++++++---------- tensorflow/tools/dist_test/remote_test.sh | 24 +- .../tools/dist_test/scripts/create_tf_cluster.sh | 51 +++- .../tools/dist_test/scripts/dist_mnist_test.sh | 96 ++++--- tensorflow/tools/dist_test/scripts/dist_test.sh | 63 +++-- .../tools/dist_test/scripts/k8s_tensorflow.py | 19 +- 9 files changed, 390 insertions(+), 373 deletions(-) (limited to 'tensorflow/tools/dist_test') diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local index e23fa034a3..05da1e92d2 100644 --- a/tensorflow/tools/dist_test/Dockerfile.local +++ b/tensorflow/tools/dist_test/Dockerfile.local @@ -1,24 +1,41 @@ -FROM jpetazzo/dind +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Docker image for testing distributed (GRPC) TensorFlow on a single machine. +# +# See ./local_test.sh for usage example. -MAINTAINER Shanqing Cai +FROM ubuntu:16.04 -RUN apt-get update +MAINTAINER Shanqing Cai -RUN apt-get install -y --no-install-recommends \ - build-essential \ - dbus \ - git \ - software-properties-common +# Pick up some TF dependencies. +RUN apt-get update && apt-get install -y \ + curl \ + python-numpy \ + python-pip \ + && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* -# Install the latest golang -RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz -RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz -RUN rm -f go1.4.2.linux-amd64.tar.gz -RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc +RUN curl -O https://bootstrap.pypa.io/get-pip.py && \ + python get-pip.py && \ + rm get-pip.py -# Create shared storage on host. k8s pods (docker containers) created on the -# host can share it and all have read/write access. 
-RUN mkdir /shared -RUN chmod 666 /shared +# Install TensorFlow CPU version from nightly build. +RUN pip --no-cache-dir install \ + https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl -ADD . /var/tf-k8s +ADD . /var/tf_dist_test diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md index b042bcf3a2..91f64dd9c3 100644 --- a/tensorflow/tools/dist_test/README.md +++ b/tensorflow/tools/dist_test/README.md @@ -4,29 +4,18 @@ runtime in TensorFlow. There are three general modes of testing: -**1) Launch a local Kubernetes (k8s) cluster and run the test suites on it** +**1) Launch a docker container and run parameters servers and workers as + separate processes therein.** For example: ./local_test.sh -This option makes use of the docker-in-docker (dind) containers. It requires -the docker0 network interface to be set to the promiscuous mode on the host: - - sudo ip link set docker0 promisc on - -The environment variable "TF_DIST_SERVER_DOCKER_IMAGE" can be used to override -the Docker image used to generate the TensorFlow GRPC server pods -("tensorflow/tf_grpc_test_server"). For example: - - export TF_DIST_SERVER_DOCKER_IMAGE= - ./local_test.sh - By default, local_test.sh runs the MNIST-with-replicas model as a test. -However, you can use the --model-name flag to run the tf-learn/wide&deep +However, you can use the --model_name flag to run the tf-learn/wide&deep cesnsu model: - ./local_test.sh --model-name CENSUS_WIDENDEEP + ./local_test.sh --model_name CENSUS_WIDENDEEP **2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the test suite on it** @@ -36,7 +25,7 @@ For example: export TF_DIST_GCLOUD_PROJECT="tensorflow-testing" export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f" export TF_DIST_CONTAINER_CLUSTER="test-cluster-1" - export TF_DIST_GCLOUD_KEY_FILE_DIR="/tmp/gcloud-secrets" + export TF_DIST_GCLOUD_KEY_FILE="/var/gcloud-secrets/my-gcloud-key.json" ./remote_test.sh Here you specify the Google Compute Engine (GCE) project, compute zone and @@ -46,7 +35,7 @@ the JSON service account key file named "tensorflow-testing.json" is located. You can use the flag "--setup-cluster-only" to perform only the cluster setup step and skip the testing step: - ./remote_test.sh --setup-cluster-only + ./remote_test.sh --setup_cluster_only **3) Run the test suite on an existing k8s TensorFlow cluster** @@ -73,10 +62,10 @@ from the model replicas before the update is applied to the model parameters. 
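Under the hood, mnist_replica.py implements this mode by wrapping its optimizer in tf.train.SyncReplicasOptimizer, the same TF 0.10-era API that appears later in this patch. A minimal graph-construction sketch; the two-worker count and the tiny model are illustrative only, and actually running it also needs the chief queue runner and init-tokens op that mnist_replica.py sets up:

    import tensorflow as tf

    # Illustrative values: aggregate updates from 2 workers before applying them.
    num_workers = 2
    replicas_to_aggregate = num_workers

    x = tf.placeholder(tf.float32, [None, 1])
    w = tf.Variable(tf.zeros([1, 1]), name="w")
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
    global_step = tf.Variable(0, name="global_step", trainable=False)

    # Gradients from the replicas are accumulated and averaged before a
    # single update is applied, which avoids stale parameter updates.
    opt = tf.train.SyncReplicasOptimizer(
        tf.train.AdamOptimizer(0.01),
        replicas_to_aggregate=replicas_to_aggregate,
        total_num_replicas=num_workers,
        replica_id=0,  # this worker's task index
        name="mnist_sync_replicas")
    train_step = opt.minimize(loss, global_step=global_step)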
To use this mode, do: # For remote testing - ./remote_test.sh --sync-replicas + ./remote_test.sh --sync_replicas # For local testing - ./local_test.sh --sync-replicas + ./local_test.sh --sync_replicas **Specifying the number of workers** @@ -85,10 +74,10 @@ You can specify the number of workers by using the --num-workers option flag, e.g., # For remote testing - ./remote_test.sh --num-workers 4 + ./remote_test.sh --num_workers 4 # For local testing - ./local_test.sh --num-workers 4 + ./local_test.sh --num_workers 4 **Building the GRPC server Docker image** diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh index be616b7e24..727258c6d8 100755 --- a/tensorflow/tools/dist_test/local_test.sh +++ b/tensorflow/tools/dist_test/local_test.sh @@ -24,33 +24,31 @@ # 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container # and run the distributed test suite. # -# Usage: local_test.sh [--leave-container-running] -# [--model-name ] -# [--num-workers ] -# [--num-parameter-servers ] -# [--sync-replicas] +# Usage: local_test.sh [--leave_container_running] +# [--model_name ] +# [--num_workers ] +# [--num_parameter_servers ] +# [--sync_replicas] # -# E.g., local_test.sh --model-name CENSUS_WIDENDEEP -# local_test.sh --num-workers 3 --num-parameter-servers 3 +# E.g., local_test.sh --model_name CENSUS_WIDENDEEP +# local_test.sh --num_workers 3 --num_parameter_servers 3 # # Arguments: -# --leave-container-running: Do not stop the docker-in-docker container after +# --leave_container_running: Do not stop the docker-in-docker container after # the termination of the tests, e.g., for debugging # -# --num-workers : +# --num_workers : # Specifies the number of worker pods to start # -# --num-parameter-server : +# --num_parameter_server : # Specifies the number of parameter servers to start # -# --sync-replicas +# --sync_replicas # Use the synchronized-replica mode. The parameter updates from the replicas # (workers) will be aggregated before applied, which avoids stale parameter # updates. 
# # In addition, this script obeys the following environment variables: -# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch -# TensorFlow (GRPC) servers with # TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images @@ -72,20 +70,20 @@ MODEL_NAME="" MODEL_NAME_FLAG="" NUM_WORKERS=2 NUM_PARAMETER_SERVERS=2 -SYNC_REPLICAS=0 +SYNC_REPLICAS_FLAG="" while true; do - if [[ $1 == "--leave-container-running" ]]; then + if [[ $1 == "--leave_container_running" ]]; then LEAVE_CONTAINER_RUNNING=1 - elif [[ $1 == "--model-name" ]]; then + elif [[ $1 == "--model_name" ]]; then MODEL_NAME="$2" - MODEL_NAME_FLAG="--model-name ${MODEL_NAME}" - elif [[ $1 == "--num-workers" ]]; then + MODEL_NAME_FLAG="--model_name ${MODEL_NAME}" + elif [[ $1 == "--num_workers" ]]; then NUM_WORKERS=$2 - elif [[ $1 == "--num-parameter-servers" ]]; then + elif [[ $1 == "--num_parameter_servers" ]]; then NUM_PARAMETER_SERVERS=$2 - elif [[ $1 == "--sync-replicas" ]]; then - SYNC_REPLICAS=1 + elif [[ $1 == "--sync_replicas" ]]; then + SYNC_REPLICAS_FLAG="--sync_replicas" fi shift @@ -98,7 +96,7 @@ echo "LEAVE_CONTAINER_RUNNING: ${LEAVE_CONTAINER_RUNNING}" echo "MODEL_NAME: \"${MODEL_NAME}\"" echo "NUM_WORKERS: ${NUM_WORKERS}" echo "NUM_PARAMETER_SERVERS: ${NUM_PARAMETER_SERVERS}" -echo "SYNC_REPLICAS: \"${SYNC_REPLICAS}\"" +echo "SYNC_REPLICAS_FLAG: \"${SYNC_REPLICAS_FLAG}\"" # Current script directory DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -122,90 +120,11 @@ if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] && fi docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \ - -f ${DIR}/Dockerfile.local ${DIR} - - -# Attempt to start the docker container with docker, which will run the k8s -# cluster inside. - -# Get current script directory -CONTAINER_START_LOG=$(mktemp --suffix=.log) -echo "Log file for starting cluster container: ${CONTAINER_START_LOG}" -echo "" - -${DIR}/local/start_tf_cluster_container.sh \ - ${LOCAL_K8S_CACHE} \ - ${DOCKER_IMG_NAME} | \ - tee ${CONTAINER_START_LOG} & - -# Poll start log until the k8s service is started properly or when maximum -# attempt count is reached. -MAX_SERVER_POLLING_ATTEMPTS=600 - -echo "Waiting for docker-in-docker container for local k8s TensorFlow "\ -"cluster to start and launch Kubernetes..." - -COUNTER=0 -while true; do - sleep 1 - - ((COUNTER++)) - if [[ "${COUNTER}" -ge "${MAX_SERVER_POLLING_ATTEMPTS}" ]]; then - die "Reached maximum number of attempts (${MAX_SERVER_POLLING_ATTEMPTS}) "\ -"while waiting for docker-in-docker for local k8s TensorFlow cluster to start" - fi - - # Check for hitting max attempt while trying to start docker-in-docker - if [[ $(grep -i "Reached maximum number of attempts" \ - "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then - die "Docker-in-docker container for local k8s TensorFlow cluster "\ -"FAILED to start" - fi - - if [[ $(grep -i "Local Kubernetes cluster is running" \ - "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then - break - fi -done - -# Determine the id of the docker-in-docker container -DIND_ID=$(get_container_id_by_image_name ${DOCKER_IMG_NAME}) - -echo "Docker-in-docker container for local k8s TensorFlow cluster has been "\ -"started successfully." -echo "Docker-in-docker container ID: ${DIND_ID}" -echo "Launching k8s tf cluster and tests in container ${DIND_ID} ..." 
-echo "" - -# Launch k8s tf cluster in the docker-in-docker container and perform tests -SYNC_REPLICAS_FLAG="" -if [[ ${SYNC_REPLICAS} == "1" ]]; then - SYNC_REPLICAS_FLAG="--sync-replicas" -fi - -docker exec ${DIND_ID} \ - /var/tf-k8s/local/test_local_tf_cluster.sh \ - ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} \ - ${MODEL_NAME_FLAG} ${SYNC_REPLICAS_FLAG} -TEST_RES=$? - -# Tear down: stop docker-in-docker container -if [[ ${LEAVE_CONTAINER_RUNNING} == "0" ]]; then - echo "" - echo "Stopping docker-in-docker container ${DIND_ID}" - - docker stop --time=1 ${DIND_ID} || \ - echo "WARNING: Failed to stop container ${DIND_ID} !!" - - echo "" -else - echo "Will NOT terminate DIND container ${DIND_ID}" -fi - -if [[ "${TEST_RES}" != "0" ]]; then - die "Test of distributed TensorFlow runtime on docker-in-docker local "\ -"k8s cluster FAILED" -else - echo "Test of distributed TensorFlow runtime on docker-in-docker local "\ -"k8s cluster PASSED" -fi + -f ${DIR}/Dockerfile.local ${DIR} || \ + die "Failed to build docker image: ${DOCKER_IMG_NAME}" + +docker run ${DOCKER_IMG_NAME} \ + /var/tf_dist_test/scripts/dist_mnist_test.sh \ + --ps_hosts "localhost:2000,localhost:2001" \ + --worker_hosts "localhost:3000,localhost:3001" \ + --num_gpus 0 ${SYNC_REPLICAS_FLAG} diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py index 9bd79c8e9c..0f642d5e69 100644 --- a/tensorflow/tools/dist_test/python/mnist_replica.py +++ b/tensorflow/tools/dist_test/python/mnist_replica.py @@ -73,9 +73,14 @@ flags.DEFINE_boolean("sync_replicas", False, "Use the sync_replicas (synchronized replicas) mode, " "wherein the parameter updates from workers are aggregated " "before applied to avoid stale gradients") +flags.DEFINE_boolean( + "existing_servers", False, "Whether servers already exists. If True, " + "will use the worker hosts via their GRPC URLs (one client process " + "per worker host). Otherwise, will create an in-process TensorFlow " + "server.") flags.DEFINE_string("ps_hosts","localhost:2222", "Comma-separated list of hostname:port pairs") -flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", +flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "Comma-separated list of hostname:port pairs") flags.DEFINE_string("job_name", None,"job name: worker or ps") @@ -97,156 +102,164 @@ def main(unused_argv): print("job name = %s" % FLAGS.job_name) print("task index = %d" % FLAGS.task_index) - + #Construct the cluster and start the server ps_spec = FLAGS.ps_hosts.split(",") worker_spec = FLAGS.worker_hosts.split(",") - # Get the number of workers + # Get the number of workers. 
num_workers = len(worker_spec) cluster = tf.train.ClusterSpec({ "ps": ps_spec, "worker": worker_spec}) - server = tf.train.Server(cluster, - job_name=FLAGS.job_name, - task_index=FLAGS.task_index) - - if FLAGS.job_name == "ps": - server.join() - elif FLAGS.job_name == "worker": - is_chief = (FLAGS.task_index == 0) - if FLAGS.num_gpus > 0: - if FLAGS.num_gpus < num_workers: - raise ValueError("number of gpus is less than number of workers") - # Avoid gpu allocation conflict: now allocate task_num -> #gpu - # for each worker in the corresponding machine - gpu = (FLAGS.task_index % FLAGS.num_gpus) - worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) - elif FLAGS.num_gpus == 0: - # Just allocate the CPU to worker server - cpu = 0 - worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) - # The device setter will automatically place Variables ops on separate - # parameter servers (ps). The non-Variable ops will be placed on the workers. - # The ps use CPU and workers use corresponding GPU - with tf.device(tf.train.replica_device_setter( - worker_device=worker_device, - ps_device="/job:ps/cpu:0", - cluster=cluster)): - global_step = tf.Variable(0, name="global_step", trainable=False) - - # Variables of the hidden layer - hid_w = tf.Variable( - tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], - stddev=1.0 / IMAGE_PIXELS), name="hid_w") - hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") - - # Variables of the softmax layer - sm_w = tf.Variable( - tf.truncated_normal([FLAGS.hidden_units, 10], - stddev=1.0 / math.sqrt(FLAGS.hidden_units)), - name="sm_w") - sm_b = tf.Variable(tf.zeros([10]), name="sm_b") - - # Ops: located on the worker specified with FLAGS.task_index - x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) - y_ = tf.placeholder(tf.float32, [None, 10]) - - hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) - hid = tf.nn.relu(hid_lin) - - y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) - cross_entropy = -tf.reduce_sum(y_ * - tf.log(tf.clip_by_value(y, 1e-10, 1.0))) - - opt = tf.train.AdamOptimizer(FLAGS.learning_rate) - - if FLAGS.sync_replicas: - if FLAGS.replicas_to_aggregate is None: - replicas_to_aggregate = num_workers - else: - replicas_to_aggregate = FLAGS.replicas_to_aggregate - - opt = tf.train.SyncReplicasOptimizer( - opt, - replicas_to_aggregate=replicas_to_aggregate, - total_num_replicas=num_workers, - replica_id=FLAGS.task_index, - name="mnist_sync_replicas") - - train_step = opt.minimize(cross_entropy, - global_step=global_step) - - if FLAGS.sync_replicas and is_chief: - # Initial token and chief queue runners required by the sync_replicas mode - chief_queue_runner = opt.get_chief_queue_runner() - init_tokens_op = opt.get_init_tokens_op() - - init_op = tf.initialize_all_variables() - train_dir = tempfile.mkdtemp() - sv = tf.train.Supervisor(is_chief=is_chief, - logdir=train_dir, - init_op=init_op, - recovery_wait_secs=1, - global_step=global_step) - - sess_config = tf.ConfigProto( - allow_soft_placement=True, - log_device_placement=False, - device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]) - - # The chief worker (task_index==0) session will prepare the session, - # while the remaining workers will wait for the preparation to complete. - if is_chief: - print("Worker %d: Initializing session..." % FLAGS.task_index) - else: - print("Worker %d: Waiting for session to be initialized..." % - FLAGS.task_index) + if not FLAGS.existing_servers: + # Not using existing servers. 
Create an in-process server. + server = tf.train.Server( + cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index) + if FLAGS.job_name == "ps": + server.join() + + is_chief = (FLAGS.task_index == 0) + if FLAGS.num_gpus > 0: + if FLAGS.num_gpus < num_workers: + raise ValueError("number of gpus is less than number of workers") + # Avoid gpu allocation conflict: now allocate task_num -> #gpu + # for each worker in the corresponding machine + gpu = (FLAGS.task_index % FLAGS.num_gpus) + worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu) + elif FLAGS.num_gpus == 0: + # Just allocate the CPU to worker server + cpu = 0 + worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu) + # The device setter will automatically place Variables ops on separate + # parameter servers (ps). The non-Variable ops will be placed on the workers. + # The ps use CPU and workers use corresponding GPU + with tf.device( + tf.train.replica_device_setter( + worker_device=worker_device, + ps_device="/job:ps/cpu:0", + cluster=cluster)): + global_step = tf.Variable(0, name="global_step", trainable=False) + + # Variables of the hidden layer + hid_w = tf.Variable( + tf.truncated_normal( + [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units], + stddev=1.0 / IMAGE_PIXELS), + name="hid_w") + hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b") + + # Variables of the softmax layer + sm_w = tf.Variable( + tf.truncated_normal( + [FLAGS.hidden_units, 10], + stddev=1.0 / math.sqrt(FLAGS.hidden_units)), + name="sm_w") + sm_b = tf.Variable(tf.zeros([10]), name="sm_b") + + # Ops: located on the worker specified with FLAGS.task_index + x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS]) + y_ = tf.placeholder(tf.float32, [None, 10]) + + hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b) + hid = tf.nn.relu(hid_lin) + + y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b)) + cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))) + + opt = tf.train.AdamOptimizer(FLAGS.learning_rate) + + if FLAGS.sync_replicas: + if FLAGS.replicas_to_aggregate is None: + replicas_to_aggregate = num_workers + else: + replicas_to_aggregate = FLAGS.replicas_to_aggregate + + opt = tf.train.SyncReplicasOptimizer( + opt, + replicas_to_aggregate=replicas_to_aggregate, + total_num_replicas=num_workers, + replica_id=FLAGS.task_index, + name="mnist_sync_replicas") + + train_step = opt.minimize(cross_entropy, global_step=global_step) + + if FLAGS.sync_replicas and is_chief: + # Initial token and chief queue runners required by the sync_replicas mode + chief_queue_runner = opt.get_chief_queue_runner() + init_tokens_op = opt.get_init_tokens_op() + + init_op = tf.initialize_all_variables() + train_dir = tempfile.mkdtemp() + sv = tf.train.Supervisor( + is_chief=is_chief, + logdir=train_dir, + init_op=init_op, + recovery_wait_secs=1, + global_step=global_step) + + sess_config = tf.ConfigProto( + allow_soft_placement=True, + log_device_placement=False, + device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index]) + + # The chief worker (task_index==0) session will prepare the session, + # while the remaining workers will wait for the preparation to complete. + if is_chief: + print("Worker %d: Initializing session..." % FLAGS.task_index) + else: + print("Worker %d: Waiting for session to be initialized..." 
% + FLAGS.task_index) + + if FLAGS.existing_servers: + server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index] + print("Using existing server at: %s" % server_grpc_url) + + sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config) + else: sess = sv.prepare_or_wait_for_session(server.target, config=sess_config) - print("Worker %d: Session initialization complete." % FLAGS.task_index) - - if FLAGS.sync_replicas and is_chief: - # Chief worker will start the chief queue runner and call the init op - print("Starting chief queue runner and running init_tokens_op") - sv.start_queue_runners(sess, [chief_queue_runner]) - sess.run(init_tokens_op) - - # Perform training - time_begin = time.time() - print("Training begins @ %f" % time_begin) - - local_step = 0 - while True: - # Training feed - batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) - train_feed = {x: batch_xs, - y_: batch_ys} - - _, step = sess.run([train_step, global_step], feed_dict=train_feed) - local_step += 1 - - now = time.time() - print("%f: Worker %d: training step %d done (global step: %d)" % - (now, FLAGS.task_index, local_step, step)) - - if step >= FLAGS.train_steps: - break - - time_end = time.time() - print("Training ends @ %f" % time_end) - training_time = time_end - time_begin - print("Training elapsed time: %f s" % training_time) - - # Validation feed - val_feed = {x: mnist.validation.images, - y_: mnist.validation.labels} - val_xent = sess.run(cross_entropy, feed_dict=val_feed) - print("After %d training step(s), validation cross entropy = %g" % - (FLAGS.train_steps, val_xent)) + print("Worker %d: Session initialization complete." % FLAGS.task_index) + + if FLAGS.sync_replicas and is_chief: + # Chief worker will start the chief queue runner and call the init op + print("Starting chief queue runner and running init_tokens_op") + sv.start_queue_runners(sess, [chief_queue_runner]) + sess.run(init_tokens_op) + + # Perform training + time_begin = time.time() + print("Training begins @ %f" % time_begin) + + local_step = 0 + while True: + # Training feed + batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size) + train_feed = {x: batch_xs, y_: batch_ys} + + _, step = sess.run([train_step, global_step], feed_dict=train_feed) + local_step += 1 + + now = time.time() + print("%f: Worker %d: training step %d done (global step: %d)" % + (now, FLAGS.task_index, local_step, step)) + + if step >= FLAGS.train_steps: + break + + time_end = time.time() + print("Training ends @ %f" % time_end) + training_time = time_end - time_begin + print("Training elapsed time: %f s" % training_time) + + # Validation feed + val_feed = {x: mnist.validation.images, y_: mnist.validation.labels} + val_xent = sess.run(cross_entropy, feed_dict=val_feed) + print("After %d training step(s), validation cross entropy = %g" % + (FLAGS.train_steps, val_xent)) if __name__ == "__main__": diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh index b662572418..1d4a52c4c2 100755 --- a/tensorflow/tools/dist_test/remote_test.sh +++ b/tensorflow/tools/dist_test/remote_test.sh @@ -20,23 +20,23 @@ # runs from within a container based on the image. 
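The session-creation pattern shown above for mnist_replica.py, isolated for reference: a tf.train.Supervisor prepares (or waits for) a session against either server.target or, with --existing_servers, a grpc:// URL built from the worker host. In this sketch the host is a placeholder and the final call blocks until a server is actually reachable there:

    import tensorflow as tf

    worker_spec = ["localhost:3000"]  # placeholder --worker_hosts entry
    task_index = 0
    is_chief = (task_index == 0)

    w = tf.Variable(tf.zeros([10]), name="w")
    init_op = tf.initialize_all_variables()  # TF 0.10-era initializer, as in the script

    sv = tf.train.Supervisor(is_chief=is_chief, init_op=init_op, recovery_wait_secs=1)
    sess_config = tf.ConfigProto(
        allow_soft_placement=True,
        device_filters=["/job:ps", "/job:worker/task:%d" % task_index])

    # With existing servers, connect by GRPC URL; blocks until the server is up.
    server_grpc_url = "grpc://" + worker_spec[task_index]
    sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)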
 #
 # Usage:
-#   remote_test.sh [--setup-cluster-only]
-#                  [--num-workers ]
-#                  [--num-parameter-servers ]
-#                  [--sync-replicas]
+#   remote_test.sh [--setup_cluster_only]
+#                  [--num_workers ]
+#                  [--num_parameter_servers ]
+#                  [--sync_replicas]
 #
 # Arguments:
-# --setup-cluster-only:
+# --setup_cluster_only:
 #   Setup the TensorFlow k8s cluster only, and do not perform testing of
 #   the distributed runtime.
 #
-# --num-workers :
+# --num_workers :
 #   Specifies the number of worker pods to start
 #
-# --num-parameter-server :
+# --num_parameter_servers :
 #   Specifies the number of parameter servers to start
 #
-# --sync-replicas
+# --sync_replicas
 #   Use the synchronized-replica mode. The parameter updates from the replicas
 #   (workers) will be aggregated before applied, which avoids stale parameter
 #   updates.
@@ -56,9 +56,7 @@
 #   TF_DIST_GRPC_SERVER_URL is empty, same below)
 # TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
 # TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
-# TF_DIST_GCLOUD_KEY_FILE_DIR: path to the host directory that contains
-#   the gloud service key file
-#   "tensorflow-testing.json"
+# TF_DIST_GCLOUD_KEY_FILE: path to the gcloud service JSON key file
 # TF_DIST_GRPC_PORT: port on which to create the TensorFlow GRPC
 #   servers
 # TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
@@ -99,9 +97,9 @@ fi
 docker build ${NO_CACHE_FLAG} \
   -t ${DOCKER_IMG_NAME} -f "${DIR}/Dockerfile" "${DIR}"
 
-KEY_FILE_DIR=${TF_DIST_GCLOUD_KEY_FILE_DIR:-"${HOME}/gcloud-secrets"}
+KEY_FILE=${TF_DIST_GCLOUD_KEY_FILE:-"${HOME}/gcloud-secrets/tensorflow-testing.json"}
 
-docker run --rm -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
+docker run --rm -v ${KEY_FILE}:/var/gcloud/secrets/tensorflow-testing.json \
   ${DOCKER_ENV_FLAGS} \
   ${DOCKER_IMG_NAME} \
   /var/tf-dist-test/scripts/dist_test.sh $@
diff --git a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
index b0e07588e8..69c459ec8c 100755
--- a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
+++ b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
@@ -167,10 +167,10 @@ fi
 "${KUBECTL_BIN}" create -f "${K8S_YAML}"
 
 # Wait for external IP of worker services to become available
-get_tf_worker_external_ip() {
-  # Usage: gen_tf_worker_external_ip
-  # E.g., gen_tf_worker_external_ip 2
-  echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker${1}" | \
+get_tf_external_ip() {
+  # Usage: get_tf_external_ip
+  # E.g., get_tf_external_ip ps 2
+  echo $("${KUBECTL_BIN}" get svc | grep "^tf-${1}${2}" | \
         awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
 }
 
@@ -187,16 +187,16 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
 "of tf-worker0 service to emerge"
     fi
 
-    EXTERN_IPS=""
+    WORKER_EXTERN_IPS=""
     WORKER_INDEX=0
-    N_AVAILABLE_EXTERNAL_IPS=0
+    N_AVAILABLE_WORKER_EXTERNAL_IPS=0
     while true; do
-      SVC_EXTERN_IP=$(get_tf_worker_external_ip ${WORKER_INDEX})
+      SVC_EXTERN_IP=$(get_tf_external_ip worker ${WORKER_INDEX})
 
       if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
-        EXTERN_IPS="${EXTERN_IPS} ${SVC_EXTERN_IP}"
+        WORKER_EXTERN_IPS="${WORKER_EXTERN_IPS} ${SVC_EXTERN_IP}"
 
-        ((N_AVAILABLE_EXTERNAL_IPS++))
+        ((N_AVAILABLE_WORKER_EXTERNAL_IPS++))
       fi
 
       ((WORKER_INDEX++))
@@ -205,16 +205,42 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
       fi
     done
 
-    if [[ ${N_AVAILABLE_EXTERNAL_IPS} == ${NUM_WORKERS} ]]; then
+    PS_EXTERN_IPS=""
+    PS_INDEX=0
+    N_AVAILABLE_PS_EXTERNAL_IPS=0
+    while true; do
+      SVC_EXTERN_IP=$(get_tf_external_ip ps ${PS_INDEX})
+
+      if [[ !
-z "${SVC_EXTERN_IP}" ]]; then + PS_EXTERN_IPS="${PS_EXTERN_IPS} ${SVC_EXTERN_IP}" + + ((N_AVAILABLE_PS_EXTERNAL_IPS++)) + fi + + ((PS_INDEX++)) + if [[ ${PS_INDEX} == ${NUM_PARAMETER_SERVERS} ]]; then + break; + fi + done + + if [[ ${N_AVAILABLE_WORKER_EXTERNAL_IPS} == ${NUM_WORKERS} ]] && \ + [[ ${N_AVAILABLE_PS_EXTERNAL_IPS} == ${NUM_PARAMETER_SERVERS} ]]; then break; fi done GRPC_SERVER_URLS="" - for IP in ${EXTERN_IPS}; do + for IP in ${WORKER_EXTERN_IPS}; do GRPC_SERVER_URLS="${GRPC_SERVER_URLS} grpc://${IP}:${GRPC_PORT}" done - echo "GRPC URLs of tf-workers: ${GRPC_SERVER_URLS}" + + GRPC_PS_URLS="" + for IP in ${PS_EXTERN_IPS}; do + GRPC_PS_URLS="${GRPC_PS_URLS} grpc://${IP}:${GRPC_PORT}" + done + + echo "GRPC URLs of tf-worker instances: ${GRPC_SERVER_URLS}" + echo "GRPC URLs of tf-ps instances: ${GRPC_PS_URLS}" else echo "Waiting for tf pods to be all running..." @@ -251,3 +277,4 @@ fi echo "Cluster setup complete." +echo "" diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh index d95f524486..4f2cab22d9 100755 --- a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh +++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh @@ -19,24 +19,28 @@ # grpc pods and service set up. # # Usage: -# dist_mnist_test.sh [--ps-hosts ] -# [--worker-hosts ] -# [--num-gpus ] -# [--sync-replicas] +# dist_mnist_test.sh [--existing_servers (True|False)] +# [--ps_hosts ] +# [--worker_hosts ] +# [--num_gpus ] +# [--sync_replicas] # -# --sync-replicas +# --existing_servers +# Use TensorFlow GRPC servers that are already created and running. +# +# --sync_replicas # Use the synchronized-replica mode. The parameter updates from the replicas # (workers) will be aggregated before applied, which avoids stale parameter # updates. # -# ps-hosts/worker-hosts is the list of IP addresses or the GRPC URLs of the ps/worker of +# ps_hosts/worker_hosts is the list of IP addresses or the GRPC URLs of the ps/worker of # the worker sessions, separated with "," # e.g., "localhost:2222,localhost:2223" # -# --num-gpus : +# --num_gpus : # Specifies the number of gpus to use # -# NOTES: +# NOTES: # If you have the error "$'\r': command not found" # Please run the command below to remove trailing '\r' character that causes the error: # sed -i 's/\r$//' dist_mnist_test.sh @@ -52,25 +56,33 @@ die() { } if [[ $# == "0" ]]; then - die "Usage: $0 [--ps-hosts ] [--worker-hosts ] "\ -"[--num-gpus ] [--sync-replicas]" + die "Usage: $0 [--ps_hosts ] [--worker_hosts ] "\ +"[--num_gpus ] [--sync_replicas]" fi # Process additional input arguments SYNC_REPLICAS=0 +N_GPUS=0 +EXISTING_SERVERS=False while true; do - if [[ "$1" == "--ps-hosts" ]]; then + if [[ "$1" == "--ps_hosts" ]]; then PS_HOSTS=$2 - elif [[ "$1" == "--worker-hosts" ]]; then + elif [[ "$1" == "--worker_hosts" ]]; then WORKER_HOSTS=$2 - elif [[ "$1" == "--num-gpus" ]]; then + elif [[ "$1" == "--existing_servers" ]]; then + EXISTING_SERVERS=$2 + if [[ "${EXISTING_SERVERS}" != "True" ]] && \ + [[ "${EXISTING_SERVERS}" != "False" ]]; then + die "Invalid value for --existing_servers: should be (True|False)" + fi + elif [[ "$1" == "--num_gpus" ]]; then N_GPUS=$2 - elif [[ "$1" == "--sync-replicas" ]]; then + elif [[ "$1" == "--sync_replicas" ]]; then SYNC_REPLICAS="1" - die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\ + die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\ "supported by this test yet." 
- # TODO(cais): Remove error message once sync-replicas is fully supported + # TODO(cais): Remove error message once sync_replicas is fully supported. fi shift 2 @@ -86,6 +98,7 @@ else SYNC_REPLICAS_FLAG="False" fi +echo "EXISTING_SERVERS = ${EXISTING_SERVERS}" echo "PS_HOSTS = ${PS_HOSTS}" echo "WORKER_HOSTS = ${WORKER_HOSTS}" echo "NUM_GPUS = ${N_GPUS}" @@ -105,6 +118,7 @@ PS_LOG_PREFIX="/tmp/ps" # First, download the data from a single process, to avoid race-condition # during data downloading +# Pre-download data files. timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ --ps_hosts="${PS_HOSTS}" \ --worker_hosts="${WORKER_HOSTS}" \ @@ -123,25 +137,30 @@ PS_ARRAY=$(echo ${PS_HOSTS} | awk -F "," '{for(i=1;i<=NF;i++){printf $i" "}}') # Run a number of ps in parallel. In general, we only set 1 ps. echo "${N_PS} ps process(es) running in parallel..." -IDX=0 -PS=($PS_HOSTS) -while true; do - timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ - --ps_hosts="${PS_HOSTS}" \ - --worker_hosts="${WORKER_HOSTS}" \ - --job_name="ps" \ - --task_index=${IDX} \ - --num_gpus=${N_GPUS} \ - --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${PS_LOG_PREFIX}${IDX}.log" & - echo "PS ${IDX}: " - echo " PS HOST: ${PS_ARRAY[IDX]}" - echo " log file: ${PS_LOG_PREFIX}${IDX}.log" - - ((IDX++)) - if [[ "${IDX}" == "${N_PS}" ]]; then - break - fi -done +if [[ ${EXISTING_SERVERS} == "False" ]]; then + echo "Hello" + # Create parameter servers. + IDX=0 + PS=($PS_HOSTS) + while true; do + python "${MNIST_REPLICA}" \ + --existing_servers="${EXISTING_SERVERS}" \ + --ps_hosts="${PS_HOSTS}" \ + --worker_hosts="${WORKER_HOSTS}" \ + --job_name="ps" \ + --task_index=${IDX} \ + --num_gpus=${N_GPUS} \ + --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${PS_LOG_PREFIX}${IDX}.log" & + echo "PS ${IDX}: " + echo " PS HOST: ${PS_ARRAY[IDX]}" + echo " log file: ${PS_LOG_PREFIX}${IDX}.log" + + ((IDX++)) + if [[ "${IDX}" == "${N_PS}" ]]; then + break + fi + done +fi # Get N_WORKERS by WORKER_HOSTS @@ -155,12 +174,14 @@ INDICES="" IDX=0 while true; do timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ + --existing_servers="${EXISTING_SERVERS}" \ --ps_hosts="${PS_HOSTS}" \ --worker_hosts="${WORKER_HOSTS}" \ --job_name="worker" \ --task_index=${IDX} \ --num_gpus=${N_GPUS} \ - --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${WKR_LOG_PREFIX}${IDX}.log" & + --train_steps=500 \ + --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${WKR_LOG_PREFIX}${IDX}.log" & echo "Worker ${IDX}: " echo " WORKER HOST: ${WORKER_ARRAY[IDX]}" echo " log file: ${WKR_LOG_PREFIX}${IDX}.log" @@ -171,9 +192,8 @@ while true; do if [[ "${IDX}" == "${N_WORKERS}" ]]; then break fi -done - +done # Poll until all final validation cross entropy values become available or diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh index 1d60aa518f..080ce1df5f 100755 --- a/tensorflow/tools/dist_test/scripts/dist_test.sh +++ b/tensorflow/tools/dist_test/scripts/dist_test.sh @@ -25,25 +25,25 @@ # TensorFlow ops. # # Usage: -# dist_test.sh [--setup-cluster-only] -# [--model-name (MNIST | CENSUS_WIDENDEEP)] -# [--num-workers ] -# [--num-parameter-servers ] -# [--sync-replicas] +# dist_test.sh [--setup_cluster_only] +# [--model_name (MNIST | CENSUS_WIDENDEEP)] +# [--num_workers ] +# [--num_parameter_servers ] +# [--sync_replicas] # -# --setup-cluster-only: +# --setup_cluster_only: # Lets the script only set up the k8s container network # -# --model-name +# --model_name # Name of the model to test. Default is MNIST. 
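The launch loops in dist_mnist_test.sh above amount to: one mnist_replica.py process per ps host (only when --existing_servers is False, since those processes block in server.join()), plus one per worker host, each given its own task index. A rough Python equivalent of what those bash loops do, with the script path and host lists as placeholder values; the real script additionally wraps the workers in timeout and greps their logs for the final validation cross entropy:

    import subprocess

    MNIST_REPLICA = "mnist_replica.py"          # placeholder path to the script
    PS_HOSTS = "localhost:2000,localhost:2001"  # example values
    WORKER_HOSTS = "localhost:3000,localhost:3001"
    EXISTING_SERVERS = False

    def launch(job_name, task_index):
      # Mirrors the flags dist_mnist_test.sh passes to every replica process.
      return subprocess.Popen([
          "python", MNIST_REPLICA,
          "--existing_servers=%s" % EXISTING_SERVERS,
          "--ps_hosts=%s" % PS_HOSTS,
          "--worker_hosts=%s" % WORKER_HOSTS,
          "--job_name=%s" % job_name,
          "--task_index=%d" % task_index,
          "--num_gpus=0",
      ])

    procs = []
    if not EXISTING_SERVERS:
      procs += [launch("ps", i) for i in range(len(PS_HOSTS.split(",")))]
    procs += [launch("worker", i) for i in range(len(WORKER_HOSTS.split(",")))]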
# # --num-workers : # Specifies the number of worker pods to start # -# --num-parameter-server : +# --num_parameter_servers : # Specifies the number of parameter servers to start # -# --sync-replicas +# --sync_replicas # Use the synchronized-replica mode. The parameter updates from the replicas # (workers) will be aggregated before applied, which avoids stale parameter # updates. @@ -72,15 +72,15 @@ SYNC_REPLICAS=0 SETUP_CLUSTER_ONLY=0 while true; do - if [[ "$1" == "--model-name" ]]; then + if [[ "$1" == "--model_name" ]]; then MODEL_NAME=$2 - elif [[ "$1" == "--num-workers" ]]; then + elif [[ "$1" == "--num_workers" ]]; then NUM_WORKERS=$2 - elif [[ "$1" == "--num-parameter-servers" ]]; then + elif [[ "$1" == "--num_parameter_servers" ]]; then NUM_PARAMETER_SERVERS=$2 - elif [[ "$1" == "--sync-replicas" ]]; then + elif [[ "$1" == "--sync_replicas" ]]; then SYNC_REPLICAS=1 - elif [[ "$1" == "--setup-cluster-only" ]]; then + elif [[ "$1" == "--setup_cluster_only" ]]; then SETUP_CLUSTER_ONLY=1 fi shift @@ -132,17 +132,32 @@ else tee "${TMP}" || \ die "Creation of TensorFlow k8s cluster FAILED" - GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-workers: .*" | \ - sed -e 's/GRPC URLs of tf-workers://g') + GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-worker instances: .*" | \ + sed -e 's/GRPC URLs of tf-worker instances://g') + + GRPC_PS_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-ps instances: .*" | \ + sed -e 's/GRPC URLs of tf-ps instances://g') if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then die "FAILED to determine GRPC server URLs of all workers" fi + if [[ $(echo ${GRPC_PS_URLS} | wc -w) != ${NUM_PARAMETER_SERVERS} ]]; then + die "FAILED to determine GRPC server URLs of all parameter servers" + fi + + WORKER_HOSTS=$(echo "${GRPC_SERVER_URLS}" | sed -e 's/^[[:space:]]*//' | \ + sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g') + PS_HOSTS=$(echo "${GRPC_PS_URLS}" | sed -e 's/^[[:space:]]*//' | \ + sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g') + + echo "WORKER_HOSTS = ${WORKER_HOSTS}" + echo "PS_HOSTS = ${PS_HOSTS}" + rm -f ${TMP} if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then echo "Skipping testing of distributed runtime due to "\ -"option flag --setup-cluster-only" +"option flag --setup_cluster_only" exit 0 fi fi @@ -158,17 +173,21 @@ test_MNIST() { return 1 fi - echo "Performing distributed MNIST training through grpc sessions @ "\ + echo "Performing distributed MNIST training through worker grpc sessions @ "\ "${GRPC_SERVER_URLS}..." + echo "and ps grpc sessions @ ${GRPC_PS_URLS}" + SYNC_REPLICAS_FLAG="" if [[ ${SYNC_REPLICAS} == "1" ]]; then - SYNC_REPLICAS_FLAG="--sync-replicas" + SYNC_REPLICAS_FLAG="--sync_replicas" fi - "${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \ - --num-workers "${NUM_WORKERS}" \ - --num-parameter-servers "${NUM_PARAMETER_SERVERS}" \ + "${MNIST_DIST_TEST_BIN}" \ + --existing_servers True \ + --ps_hosts "${PS_HOSTS}" \ + --worker_hosts "${WORKER_HOSTS}" \ + --num_gpus 0 \ ${SYNC_REPLICAS_FLAG} if [[ $? 
== "0" ]]; then diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py index 3a427a1d4e..854c6b832a 100755 --- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py +++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py @@ -136,6 +136,19 @@ spec: selector: tf-ps: "{param_server_id}" """) +PARAM_LB_SVC = ("""apiVersion: v1 +kind: Service +metadata: + name: tf-ps{param_server_id} + labels: + tf-ps: "{param_server_id}" +spec: + type: LoadBalancer + ports: + - port: {port} + selector: + tf-ps: "{param_server_id}" +""") def main(): @@ -218,8 +231,10 @@ def GenerateConfig(num_workers, num_param_servers, port)) config += '---\n' - config += PARAM_SERVER_SVC.format(port=port, - param_server_id=param_server) + if request_load_balancer: + config += PARAM_LB_SVC.format(port=port, param_server_id=param_server) + else: + config += PARAM_SERVER_SVC.format(port=port, param_server_id=param_server) config += '---\n' return config -- cgit v1.2.3