path: root/tensorflow/tools/dist_test
author    Shanqing Cai <cais@google.com>  2016-09-14 19:00:00 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2016-09-14 20:03:21 -0700
commit    65b010308c2ab3f365b5b9b40dd56591b179b996 (patch)
tree      ec6f50c7190b8d33fc1316fd6ad507ac0c949b2a /tensorflow/tools/dist_test
parent    57e23cadaed1cfd5245192ee44e8f89713ca01e5 (diff)
Update & fix OSS distributed TF tests: mnist_replica
1) Replace the old and breaking docker-in-docker local test with a single-instance, multi-process test, built upon GitHub PR https://github.com/tensorflow/tensorflow/pull/3935. This simplifies the local test and makes it less susceptible to future changes in docker's docker-in-docker support.
2) Add an --existing_servers flag to mnist_replica.py and the associated bash scripts, so that we can distinguish a) the case in which we want to create in-process servers and supervisors (as in the new local_test.sh), and b) the case in which GRPC TF servers are already created and we just want to connect to the workers (as in remote_test.sh).
3) Rename some flags in the bash scripts to improve consistency with mnist_replica.py.
4) Related doc changes in README.md.

Change: 133209130
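In practice, the two server modes map onto the two test entry points. A usage sketch based on the flags documented in the diffs below (values are illustrative):

    # In-process servers and supervisors, all inside one docker container:
    ./local_test.sh --num_workers 2 --num_parameter_servers 2

    # Connect to GRPC TF servers already running on a k8s (GKE) cluster:
    ./remote_test.sh --num_workers 2 --num_parameter_servers 2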
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--  tensorflow/tools/dist_test/Dockerfile.local             |  53
-rw-r--r--  tensorflow/tools/dist_test/README.md                    |  31
-rwxr-xr-x  tensorflow/tools/dist_test/local_test.sh                | 137
-rw-r--r--  tensorflow/tools/dist_test/python/mnist_replica.py      | 289
-rwxr-xr-x  tensorflow/tools/dist_test/remote_test.sh               |  24
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/create_tf_cluster.sh |  51
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/dist_mnist_test.sh   |  96
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/dist_test.sh         |  63
-rwxr-xr-x  tensorflow/tools/dist_test/scripts/k8s_tensorflow.py    |  19
9 files changed, 390 insertions(+), 373 deletions(-)
diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local
index e23fa034a3..05da1e92d2 100644
--- a/tensorflow/tools/dist_test/Dockerfile.local
+++ b/tensorflow/tools/dist_test/Dockerfile.local
@@ -1,24 +1,41 @@
-FROM jpetazzo/dind
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Docker image for testing distributed (GRPC) TensorFlow on a single machine.
+#
+# See ./local_test.sh for usage example.
-MAINTAINER Shanqing Cai <cais@google.com>
+FROM ubuntu:16.04
-RUN apt-get update
+MAINTAINER Shanqing Cai <cais@google.com>
-RUN apt-get install -y --no-install-recommends \
- build-essential \
- dbus \
- git \
- software-properties-common
+# Pick up some TF dependencies.
+RUN apt-get update && apt-get install -y \
+ curl \
+ python-numpy \
+ python-pip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
-# Install the latest golang
-RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz
-RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz
-RUN rm -f go1.4.2.linux-amd64.tar.gz
-RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
-# Create shared storage on host. k8s pods (docker containers) created on the
-# host can share it and all have read/write access.
-RUN mkdir /shared
-RUN chmod 666 /shared
+# Install TensorFlow CPU version from nightly build.
+RUN pip --no-cache-dir install \
+ https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl
-ADD . /var/tf-k8s
+ADD . /var/tf_dist_test
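The rebuilt image no longer hosts a docker-in-docker k8s cluster; local_test.sh just builds it and runs the MNIST test script inside it. A minimal manual equivalent (the image tag here is arbitrary; local_test.sh derives its own ${DOCKER_IMG_NAME}):

    docker build -t tf-dist-test-local -f Dockerfile.local .
    docker run tf-dist-test-local \
        /var/tf_dist_test/scripts/dist_mnist_test.sh \
        --ps_hosts localhost:2000,localhost:2001 \
        --worker_hosts localhost:3000,localhost:3001 \
        --num_gpus 0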
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index b042bcf3a2..91f64dd9c3 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -4,29 +4,18 @@ runtime in TensorFlow.
There are three general modes of testing:
-**1) Launch a local Kubernetes (k8s) cluster and run the test suites on it**
+**1) Launch a docker container and run parameter servers and workers as
+   separate processes therein.**
For example:
./local_test.sh
-This option makes use of the docker-in-docker (dind) containers. It requires
-the docker0 network interface to be set to the promiscuous mode on the host:
-
- sudo ip link set docker0 promisc on
-
-The environment variable "TF_DIST_SERVER_DOCKER_IMAGE" can be used to override
-the Docker image used to generate the TensorFlow GRPC server pods
-("tensorflow/tf_grpc_test_server"). For example:
-
- export TF_DIST_SERVER_DOCKER_IMAGE=<docker_image_name>
- ./local_test.sh
-
By default, local_test.sh runs the MNIST-with-replicas model as a test.
-However, you can use the --model-name flag to run the tf-learn/wide&deep
+However, you can use the --model_name flag to run the tf-learn/wide&deep
census model:
- ./local_test.sh --model-name CENSUS_WIDENDEEP
+ ./local_test.sh --model_name CENSUS_WIDENDEEP
**2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the
test suite on it**
@@ -36,7 +25,7 @@ For example:
export TF_DIST_GCLOUD_PROJECT="tensorflow-testing"
export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f"
export TF_DIST_CONTAINER_CLUSTER="test-cluster-1"
- export TF_DIST_GCLOUD_KEY_FILE_DIR="/tmp/gcloud-secrets"
+ export TF_DIST_GCLOUD_KEY_FILE="/var/gcloud-secrets/my-gcloud-key.json"
./remote_test.sh
Here you specify the Google Compute Engine (GCE) project, compute zone and
@@ -46,7 +35,7 @@ the JSON service account key file named "tensorflow-testing.json" is located.
You can use the flag "--setup_cluster_only" to perform only the cluster setup
step and skip the testing step:
- ./remote_test.sh --setup-cluster-only
+ ./remote_test.sh --setup_cluster_only
**3) Run the test suite on an existing k8s TensorFlow cluster**
@@ -73,10 +62,10 @@ from the model replicas before the update is applied to the model parameters.
To use this mode, do:
# For remote testing
- ./remote_test.sh --sync-replicas
+ ./remote_test.sh --sync_replicas
# For local testing
- ./local_test.sh --sync-replicas
+ ./local_test.sh --sync_replicas
**Specifying the number of workers**
@@ -85,10 +74,10 @@ You can specify the number of workers by using the --num-workers option flag,
e.g.,
# For remote testing
- ./remote_test.sh --num-workers 4
+ ./remote_test.sh --num_workers 4
# For local testing
- ./local_test.sh --num-workers 4
+ ./local_test.sh --num_workers 4
**Building the GRPC server Docker image**
diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh
index be616b7e24..727258c6d8 100755
--- a/tensorflow/tools/dist_test/local_test.sh
+++ b/tensorflow/tools/dist_test/local_test.sh
@@ -24,33 +24,31 @@
# 3) Call a script inside the container to launch parameter servers and
#    workers as separate processes and run the distributed test suite.
#
-# Usage: local_test.sh [--leave-container-running]
-# [--model-name <MODEL_NAME>]
-# [--num-workers <NUM_WORKERS>]
-# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
-# [--sync-replicas]
+# Usage: local_test.sh [--leave_container_running]
+# [--model_name <MODEL_NAME>]
+# [--num_workers <NUM_WORKERS>]
+# [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
+# [--sync_replicas]
#
-# E.g., local_test.sh --model-name CENSUS_WIDENDEEP
-# local_test.sh --num-workers 3 --num-parameter-servers 3
+# E.g., local_test.sh --model_name CENSUS_WIDENDEEP
+# local_test.sh --num_workers 3 --num_parameter_servers 3
#
# Arguments:
-# --leave-container-running: Do not stop the docker-in-docker container after
+# --leave_container_running: Do not stop the docker container after
# the termination of the tests, e.g., for debugging
#
-# --num-workers <NUM_WORKERS>:
+# --num_workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
-# --num-parameter-server <NUM_PARAMETER_SERVERS>:
+# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
-# --sync-replicas
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
# In addition, this script obeys the following environment variables:
-# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
-# TensorFlow (GRPC) servers with
# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
@@ -72,20 +70,20 @@ MODEL_NAME=""
MODEL_NAME_FLAG=""
NUM_WORKERS=2
NUM_PARAMETER_SERVERS=2
-SYNC_REPLICAS=0
+SYNC_REPLICAS_FLAG=""
while true; do
- if [[ $1 == "--leave-container-running" ]]; then
+ if [[ $1 == "--leave_container_running" ]]; then
LEAVE_CONTAINER_RUNNING=1
- elif [[ $1 == "--model-name" ]]; then
+ elif [[ $1 == "--model_name" ]]; then
MODEL_NAME="$2"
- MODEL_NAME_FLAG="--model-name ${MODEL_NAME}"
- elif [[ $1 == "--num-workers" ]]; then
+ MODEL_NAME_FLAG="--model_name ${MODEL_NAME}"
+ elif [[ $1 == "--num_workers" ]]; then
NUM_WORKERS=$2
- elif [[ $1 == "--num-parameter-servers" ]]; then
+ elif [[ $1 == "--num_parameter_servers" ]]; then
NUM_PARAMETER_SERVERS=$2
- elif [[ $1 == "--sync-replicas" ]]; then
- SYNC_REPLICAS=1
+ elif [[ $1 == "--sync_replicas" ]]; then
+ SYNC_REPLICAS_FLAG="--sync_replicas"
fi
shift
@@ -98,7 +96,7 @@ echo "LEAVE_CONTAINER_RUNNING: ${LEAVE_CONTAINER_RUNNING}"
echo "MODEL_NAME: \"${MODEL_NAME}\""
echo "NUM_WORKERS: ${NUM_WORKERS}"
echo "NUM_PARAMETER_SERVERS: ${NUM_PARAMETER_SERVERS}"
-echo "SYNC_REPLICAS: \"${SYNC_REPLICAS}\""
+echo "SYNC_REPLICAS_FLAG: \"${SYNC_REPLICAS_FLAG}\""
# Current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
@@ -122,90 +120,11 @@ if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] &&
fi
docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \
- -f ${DIR}/Dockerfile.local ${DIR}
-
-
-# Attempt to start the docker container with docker, which will run the k8s
-# cluster inside.
-
-# Get current script directory
-CONTAINER_START_LOG=$(mktemp --suffix=.log)
-echo "Log file for starting cluster container: ${CONTAINER_START_LOG}"
-echo ""
-
-${DIR}/local/start_tf_cluster_container.sh \
- ${LOCAL_K8S_CACHE} \
- ${DOCKER_IMG_NAME} | \
- tee ${CONTAINER_START_LOG} &
-
-# Poll start log until the k8s service is started properly or when maximum
-# attempt count is reached.
-MAX_SERVER_POLLING_ATTEMPTS=600
-
-echo "Waiting for docker-in-docker container for local k8s TensorFlow "\
-"cluster to start and launch Kubernetes..."
-
-COUNTER=0
-while true; do
- sleep 1
-
- ((COUNTER++))
- if [[ "${COUNTER}" -ge "${MAX_SERVER_POLLING_ATTEMPTS}" ]]; then
- die "Reached maximum number of attempts (${MAX_SERVER_POLLING_ATTEMPTS}) "\
-"while waiting for docker-in-docker for local k8s TensorFlow cluster to start"
- fi
-
- # Check for hitting max attempt while trying to start docker-in-docker
- if [[ $(grep -i "Reached maximum number of attempts" \
- "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
- die "Docker-in-docker container for local k8s TensorFlow cluster "\
-"FAILED to start"
- fi
-
- if [[ $(grep -i "Local Kubernetes cluster is running" \
- "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
- break
- fi
-done
-
-# Determine the id of the docker-in-docker container
-DIND_ID=$(get_container_id_by_image_name ${DOCKER_IMG_NAME})
-
-echo "Docker-in-docker container for local k8s TensorFlow cluster has been "\
-"started successfully."
-echo "Docker-in-docker container ID: ${DIND_ID}"
-echo "Launching k8s tf cluster and tests in container ${DIND_ID} ..."
-echo ""
-
-# Launch k8s tf cluster in the docker-in-docker container and perform tests
-SYNC_REPLICAS_FLAG=""
-if [[ ${SYNC_REPLICAS} == "1" ]]; then
- SYNC_REPLICAS_FLAG="--sync-replicas"
-fi
-
-docker exec ${DIND_ID} \
- /var/tf-k8s/local/test_local_tf_cluster.sh \
- ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} \
- ${MODEL_NAME_FLAG} ${SYNC_REPLICAS_FLAG}
-TEST_RES=$?
-
-# Tear down: stop docker-in-docker container
-if [[ ${LEAVE_CONTAINER_RUNNING} == "0" ]]; then
- echo ""
- echo "Stopping docker-in-docker container ${DIND_ID}"
-
- docker stop --time=1 ${DIND_ID} || \
- echo "WARNING: Failed to stop container ${DIND_ID} !!"
-
- echo ""
-else
- echo "Will NOT terminate DIND container ${DIND_ID}"
-fi
-
-if [[ "${TEST_RES}" != "0" ]]; then
- die "Test of distributed TensorFlow runtime on docker-in-docker local "\
-"k8s cluster FAILED"
-else
- echo "Test of distributed TensorFlow runtime on docker-in-docker local "\
-"k8s cluster PASSED"
-fi
+ -f ${DIR}/Dockerfile.local ${DIR} || \
+ die "Failed to build docker image: ${DOCKER_IMG_NAME}"
+
+docker run ${DOCKER_IMG_NAME} \
+ /var/tf_dist_test/scripts/dist_mnist_test.sh \
+ --ps_hosts "localhost:2000,localhost:2001" \
+ --worker_hosts "localhost:3000,localhost:3001" \
+ --num_gpus 0 ${SYNC_REPLICAS_FLAG}
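Inside the container, dist_mnist_test.sh fans out one mnist_replica.py process per ps/worker host, roughly as follows (a sketch; see the dist_mnist_test.sh diff below for the exact flags):

    # Parameter server 0 (an in-process GRPC server is created because
    # --existing_servers=False):
    python mnist_replica.py --existing_servers=False \
        --ps_hosts=localhost:2000,localhost:2001 \
        --worker_hosts=localhost:3000,localhost:3001 \
        --job_name=ps --task_index=0 &

    # Worker 0:
    python mnist_replica.py --existing_servers=False \
        --ps_hosts=localhost:2000,localhost:2001 \
        --worker_hosts=localhost:3000,localhost:3001 \
        --job_name=worker --task_index=0 --num_gpus=0 &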
diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py
index 9bd79c8e9c..0f642d5e69 100644
--- a/tensorflow/tools/dist_test/python/mnist_replica.py
+++ b/tensorflow/tools/dist_test/python/mnist_replica.py
@@ -73,9 +73,14 @@ flags.DEFINE_boolean("sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workers are aggregated "
"before applied to avoid stale gradients")
+flags.DEFINE_boolean(
+    "existing_servers", False, "Whether servers already exist. If True, "
+ "will use the worker hosts via their GRPC URLs (one client process "
+ "per worker host). Otherwise, will create an in-process TensorFlow "
+ "server.")
flags.DEFINE_string("ps_hosts","localhost:2222",
"Comma-separated list of hostname:port pairs")
-flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
+flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224",
"Comma-separated list of hostname:port pairs")
flags.DEFINE_string("job_name", None,"job name: worker or ps")
@@ -97,156 +102,164 @@ def main(unused_argv):
print("job name = %s" % FLAGS.job_name)
print("task index = %d" % FLAGS.task_index)
-
+
#Construct the cluster and start the server
ps_spec = FLAGS.ps_hosts.split(",")
worker_spec = FLAGS.worker_hosts.split(",")
- # Get the number of workers
+ # Get the number of workers.
num_workers = len(worker_spec)
cluster = tf.train.ClusterSpec({
"ps": ps_spec,
"worker": worker_spec})
- server = tf.train.Server(cluster,
- job_name=FLAGS.job_name,
- task_index=FLAGS.task_index)
-
- if FLAGS.job_name == "ps":
- server.join()
- elif FLAGS.job_name == "worker":
- is_chief = (FLAGS.task_index == 0)
- if FLAGS.num_gpus > 0:
- if FLAGS.num_gpus < num_workers:
- raise ValueError("number of gpus is less than number of workers")
- # Avoid gpu allocation conflict: now allocate task_num -> #gpu
- # for each worker in the corresponding machine
- gpu = (FLAGS.task_index % FLAGS.num_gpus)
- worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
- elif FLAGS.num_gpus == 0:
- # Just allocate the CPU to worker server
- cpu = 0
- worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
- # The device setter will automatically place Variables ops on separate
- # parameter servers (ps). The non-Variable ops will be placed on the workers.
- # The ps use CPU and workers use corresponding GPU
- with tf.device(tf.train.replica_device_setter(
- worker_device=worker_device,
- ps_device="/job:ps/cpu:0",
- cluster=cluster)):
- global_step = tf.Variable(0, name="global_step", trainable=False)
-
- # Variables of the hidden layer
- hid_w = tf.Variable(
- tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
- stddev=1.0 / IMAGE_PIXELS), name="hid_w")
- hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
-
- # Variables of the softmax layer
- sm_w = tf.Variable(
- tf.truncated_normal([FLAGS.hidden_units, 10],
- stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
- name="sm_w")
- sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
-
- # Ops: located on the worker specified with FLAGS.task_index
- x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
- y_ = tf.placeholder(tf.float32, [None, 10])
-
- hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
- hid = tf.nn.relu(hid_lin)
-
- y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
- cross_entropy = -tf.reduce_sum(y_ *
- tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
-
- opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
-
- if FLAGS.sync_replicas:
- if FLAGS.replicas_to_aggregate is None:
- replicas_to_aggregate = num_workers
- else:
- replicas_to_aggregate = FLAGS.replicas_to_aggregate
-
- opt = tf.train.SyncReplicasOptimizer(
- opt,
- replicas_to_aggregate=replicas_to_aggregate,
- total_num_replicas=num_workers,
- replica_id=FLAGS.task_index,
- name="mnist_sync_replicas")
-
- train_step = opt.minimize(cross_entropy,
- global_step=global_step)
-
- if FLAGS.sync_replicas and is_chief:
- # Initial token and chief queue runners required by the sync_replicas mode
- chief_queue_runner = opt.get_chief_queue_runner()
- init_tokens_op = opt.get_init_tokens_op()
-
- init_op = tf.initialize_all_variables()
- train_dir = tempfile.mkdtemp()
- sv = tf.train.Supervisor(is_chief=is_chief,
- logdir=train_dir,
- init_op=init_op,
- recovery_wait_secs=1,
- global_step=global_step)
-
- sess_config = tf.ConfigProto(
- allow_soft_placement=True,
- log_device_placement=False,
- device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
-
- # The chief worker (task_index==0) session will prepare the session,
- # while the remaining workers will wait for the preparation to complete.
- if is_chief:
- print("Worker %d: Initializing session..." % FLAGS.task_index)
- else:
- print("Worker %d: Waiting for session to be initialized..." %
- FLAGS.task_index)
+ if not FLAGS.existing_servers:
+ # Not using existing servers. Create an in-process server.
+ server = tf.train.Server(
+ cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
+ if FLAGS.job_name == "ps":
+ server.join()
+
+ is_chief = (FLAGS.task_index == 0)
+ if FLAGS.num_gpus > 0:
+ if FLAGS.num_gpus < num_workers:
+ raise ValueError("number of gpus is less than number of workers")
+ # Avoid gpu allocation conflict: now allocate task_num -> #gpu
+ # for each worker in the corresponding machine
+ gpu = (FLAGS.task_index % FLAGS.num_gpus)
+ worker_device = "/job:worker/task:%d/gpu:%d" % (FLAGS.task_index, gpu)
+ elif FLAGS.num_gpus == 0:
+ # Just allocate the CPU to worker server
+ cpu = 0
+ worker_device = "/job:worker/task:%d/cpu:%d" % (FLAGS.task_index, cpu)
+ # The device setter will automatically place Variables ops on separate
+ # parameter servers (ps). The non-Variable ops will be placed on the workers.
+ # The ps use CPU and workers use corresponding GPU
+ with tf.device(
+ tf.train.replica_device_setter(
+ worker_device=worker_device,
+ ps_device="/job:ps/cpu:0",
+ cluster=cluster)):
+ global_step = tf.Variable(0, name="global_step", trainable=False)
+
+ # Variables of the hidden layer
+ hid_w = tf.Variable(
+ tf.truncated_normal(
+ [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
+ stddev=1.0 / IMAGE_PIXELS),
+ name="hid_w")
+ hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
+
+ # Variables of the softmax layer
+ sm_w = tf.Variable(
+ tf.truncated_normal(
+ [FLAGS.hidden_units, 10],
+ stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
+ name="sm_w")
+ sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+
+ # Ops: located on the worker specified with FLAGS.task_index
+ x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+ hid = tf.nn.relu(hid_lin)
+
+ y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+ cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+
+ opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
+
+ if FLAGS.sync_replicas:
+ if FLAGS.replicas_to_aggregate is None:
+ replicas_to_aggregate = num_workers
+ else:
+ replicas_to_aggregate = FLAGS.replicas_to_aggregate
+
+ opt = tf.train.SyncReplicasOptimizer(
+ opt,
+ replicas_to_aggregate=replicas_to_aggregate,
+ total_num_replicas=num_workers,
+ replica_id=FLAGS.task_index,
+ name="mnist_sync_replicas")
+
+ train_step = opt.minimize(cross_entropy, global_step=global_step)
+
+ if FLAGS.sync_replicas and is_chief:
+ # Initial token and chief queue runners required by the sync_replicas mode
+ chief_queue_runner = opt.get_chief_queue_runner()
+ init_tokens_op = opt.get_init_tokens_op()
+
+ init_op = tf.initialize_all_variables()
+ train_dir = tempfile.mkdtemp()
+ sv = tf.train.Supervisor(
+ is_chief=is_chief,
+ logdir=train_dir,
+ init_op=init_op,
+ recovery_wait_secs=1,
+ global_step=global_step)
+
+ sess_config = tf.ConfigProto(
+ allow_soft_placement=True,
+ log_device_placement=False,
+ device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.task_index])
+
+ # The chief worker (task_index==0) session will prepare the session,
+ # while the remaining workers will wait for the preparation to complete.
+ if is_chief:
+ print("Worker %d: Initializing session..." % FLAGS.task_index)
+ else:
+ print("Worker %d: Waiting for session to be initialized..." %
+ FLAGS.task_index)
+
+ if FLAGS.existing_servers:
+ server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
+ print("Using existing server at: %s" % server_grpc_url)
+
+ sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
+ else:
sess = sv.prepare_or_wait_for_session(server.target,
config=sess_config)
- print("Worker %d: Session initialization complete." % FLAGS.task_index)
-
- if FLAGS.sync_replicas and is_chief:
- # Chief worker will start the chief queue runner and call the init op
- print("Starting chief queue runner and running init_tokens_op")
- sv.start_queue_runners(sess, [chief_queue_runner])
- sess.run(init_tokens_op)
-
- # Perform training
- time_begin = time.time()
- print("Training begins @ %f" % time_begin)
-
- local_step = 0
- while True:
- # Training feed
- batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
- train_feed = {x: batch_xs,
- y_: batch_ys}
-
- _, step = sess.run([train_step, global_step], feed_dict=train_feed)
- local_step += 1
-
- now = time.time()
- print("%f: Worker %d: training step %d done (global step: %d)" %
- (now, FLAGS.task_index, local_step, step))
-
- if step >= FLAGS.train_steps:
- break
-
- time_end = time.time()
- print("Training ends @ %f" % time_end)
- training_time = time_end - time_begin
- print("Training elapsed time: %f s" % training_time)
-
- # Validation feed
- val_feed = {x: mnist.validation.images,
- y_: mnist.validation.labels}
- val_xent = sess.run(cross_entropy, feed_dict=val_feed)
- print("After %d training step(s), validation cross entropy = %g" %
- (FLAGS.train_steps, val_xent))
+ print("Worker %d: Session initialization complete." % FLAGS.task_index)
+
+ if FLAGS.sync_replicas and is_chief:
+ # Chief worker will start the chief queue runner and call the init op
+ print("Starting chief queue runner and running init_tokens_op")
+ sv.start_queue_runners(sess, [chief_queue_runner])
+ sess.run(init_tokens_op)
+
+ # Perform training
+ time_begin = time.time()
+ print("Training begins @ %f" % time_begin)
+
+ local_step = 0
+ while True:
+ # Training feed
+ batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
+ train_feed = {x: batch_xs, y_: batch_ys}
+
+ _, step = sess.run([train_step, global_step], feed_dict=train_feed)
+ local_step += 1
+
+ now = time.time()
+ print("%f: Worker %d: training step %d done (global step: %d)" %
+ (now, FLAGS.task_index, local_step, step))
+
+ if step >= FLAGS.train_steps:
+ break
+
+ time_end = time.time()
+ print("Training ends @ %f" % time_end)
+ training_time = time_end - time_begin
+ print("Training elapsed time: %f s" % training_time)
+
+ # Validation feed
+ val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
+ val_xent = sess.run(cross_entropy, feed_dict=val_feed)
+ print("After %d training step(s), validation cross entropy = %g" %
+ (FLAGS.train_steps, val_xent))
if __name__ == "__main__":
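When the servers already exist (the remote/k8s case), each worker host instead gets one client process that connects over GRPC rather than creating a server, e.g. (a sketch with hypothetical cluster addresses):

    python mnist_replica.py --existing_servers=True \
        --ps_hosts=104.154.1.1:2222 \
        --worker_hosts=104.154.2.1:2222,104.154.2.2:2222 \
        --job_name=worker --task_index=0

Internally, the script builds grpc://104.154.2.1:2222 from worker_hosts[task_index] and hands it to sv.prepare_or_wait_for_session, as shown in the diff above.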
diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh
index b662572418..1d4a52c4c2 100755
--- a/tensorflow/tools/dist_test/remote_test.sh
+++ b/tensorflow/tools/dist_test/remote_test.sh
@@ -20,23 +20,23 @@
# runs from within a container based on the image.
#
# Usage:
-# remote_test.sh [--setup-cluster-only]
-# [--num-workers <NUM_WORKERS>]
-# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
-# [--sync-replicas]
+# remote_test.sh [--setup_cluster_only]
+# [--num_workers <NUM_WORKERS>]
+# [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
+# [--sync_replicas]
#
# Arguments:
-# --setup-cluster-only:
+# --setup_cluster_only:
# Setup the TensorFlow k8s cluster only, and do not perform testing of
# the distributed runtime.
#
-# --num-workers <NUM_WORKERS>:
+# --num_workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
-# --num-parameter-server <NUM_PARAMETER_SERVERS>:
+# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
-# --sync-replicas
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
@@ -56,9 +56,7 @@
# TF_DIST_GRPC_SERVER_URL is empty, same below)
# TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
# TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
-# TF_DIST_GCLOUD_KEY_FILE_DIR: path to the host directory that contains
-# the gloud service key file
-# "tensorflow-testing.json"
+# TF_DIST_GCLOUD_KEY_FILE: path to the gcloud service JSON key file
# TF_DIST_GRPC_PORT: port on which to create the TensorFlow GRPC
# servers
# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
@@ -99,9 +97,9 @@ fi
docker build ${NO_CACHE_FLAG} \
-t ${DOCKER_IMG_NAME} -f "${DIR}/Dockerfile" "${DIR}"
-KEY_FILE_DIR=${TF_DIST_GCLOUD_KEY_FILE_DIR:-"${HOME}/gcloud-secrets"}
+KEY_FILE=${TF_DIST_GCLOUD_KEY_FILE:-"${HOME}/gcloud-secrets/tensorflow-testing.json"}
-docker run --rm -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
+docker run --rm -v ${KEY_FILE}:/var/gcloud/secrets/tensorflow-testing.json \
${DOCKER_ENV_FLAGS} \
${DOCKER_IMG_NAME} \
/var/tf-dist-test/scripts/dist_test.sh $@
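Putting it together, a remote test run is driven by the environment variables plus the renamed flags; for example (values taken from the README example above):

    export TF_DIST_GCLOUD_PROJECT="tensorflow-testing"
    export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f"
    export TF_DIST_CONTAINER_CLUSTER="test-cluster-1"
    export TF_DIST_GCLOUD_KEY_FILE="/var/gcloud-secrets/my-gcloud-key.json"
    ./remote_test.sh --num_workers 4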
diff --git a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
index b0e07588e8..69c459ec8c 100755
--- a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
+++ b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
@@ -167,10 +167,10 @@ fi
"${KUBECTL_BIN}" create -f "${K8S_YAML}"
# Wait for external IP of worker services to become available
-get_tf_worker_external_ip() {
- # Usage: gen_tf_worker_external_ip <WORKER_INDEX>
- # E.g., gen_tf_worker_external_ip 2
- echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker${1}" | \
+get_tf_external_ip() {
+  # Usage: get_tf_external_ip <JOB_NAME> <TASK_INDEX>
+  # E.g., get_tf_external_ip ps 2
+ echo $("${KUBECTL_BIN}" get svc | grep "^tf-${1}${2}" | \
awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
}
@@ -187,16 +187,16 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
"of tf-worker0 service to emerge"
fi
- EXTERN_IPS=""
+ WORKER_EXTERN_IPS=""
WORKER_INDEX=0
- N_AVAILABLE_EXTERNAL_IPS=0
+ N_AVAILABLE_WORKER_EXTERNAL_IPS=0
while true; do
- SVC_EXTERN_IP=$(get_tf_worker_external_ip ${WORKER_INDEX})
+ SVC_EXTERN_IP=$(get_tf_external_ip worker ${WORKER_INDEX})
if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
- EXTERN_IPS="${EXTERN_IPS} ${SVC_EXTERN_IP}"
+ WORKER_EXTERN_IPS="${WORKER_EXTERN_IPS} ${SVC_EXTERN_IP}"
- ((N_AVAILABLE_EXTERNAL_IPS++))
+ ((N_AVAILABLE_WORKER_EXTERNAL_IPS++))
fi
((WORKER_INDEX++))
@@ -205,16 +205,42 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
fi
done
- if [[ ${N_AVAILABLE_EXTERNAL_IPS} == ${NUM_WORKERS} ]]; then
+ PS_EXTERN_IPS=""
+ PS_INDEX=0
+ N_AVAILABLE_PS_EXTERNAL_IPS=0
+ while true; do
+ SVC_EXTERN_IP=$(get_tf_external_ip ps ${PS_INDEX})
+
+ if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
+ PS_EXTERN_IPS="${PS_EXTERN_IPS} ${SVC_EXTERN_IP}"
+
+ ((N_AVAILABLE_PS_EXTERNAL_IPS++))
+ fi
+
+ ((PS_INDEX++))
+ if [[ ${PS_INDEX} == ${NUM_PARAMETER_SERVERS} ]]; then
+ break;
+ fi
+ done
+
+ if [[ ${N_AVAILABLE_WORKER_EXTERNAL_IPS} == ${NUM_WORKERS} ]] && \
+ [[ ${N_AVAILABLE_PS_EXTERNAL_IPS} == ${NUM_PARAMETER_SERVERS} ]]; then
break;
fi
done
GRPC_SERVER_URLS=""
- for IP in ${EXTERN_IPS}; do
+ for IP in ${WORKER_EXTERN_IPS}; do
GRPC_SERVER_URLS="${GRPC_SERVER_URLS} grpc://${IP}:${GRPC_PORT}"
done
- echo "GRPC URLs of tf-workers: ${GRPC_SERVER_URLS}"
+
+ GRPC_PS_URLS=""
+ for IP in ${PS_EXTERN_IPS}; do
+ GRPC_PS_URLS="${GRPC_PS_URLS} grpc://${IP}:${GRPC_PORT}"
+ done
+
+ echo "GRPC URLs of tf-worker instances: ${GRPC_SERVER_URLS}"
+ echo "GRPC URLs of tf-ps instances: ${GRPC_PS_URLS}"
else
echo "Waiting for tf pods to be all running..."
@@ -251,3 +277,4 @@ fi
echo "Cluster setup complete."
+echo ""
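The generalized helper resolves external IPs for both job types; within the script it is used along these lines (sketch):

    WORKER0_IP=$(get_tf_external_ip worker 0)
    PS0_IP=$(get_tf_external_ip ps 0)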
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
index d95f524486..4f2cab22d9 100755
--- a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
@@ -19,24 +19,28 @@
# grpc pods and service set up.
#
# Usage:
-# dist_mnist_test.sh [--ps-hosts <PS_HOSTS>]
-# [--worker-hosts <WORKER_HOSTS>]
-# [--num-gpus <NUM_GPUS>]
-# [--sync-replicas]
+# dist_mnist_test.sh [--existing_servers (True|False)]
+# [--ps_hosts <PS_HOSTS>]
+# [--worker_hosts <WORKER_HOSTS>]
+# [--num_gpus <NUM_GPUS>]
+# [--sync_replicas]
#
-# --sync-replicas
+# --existing_servers
+# Use TensorFlow GRPC servers that are already created and running.
+#
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
-# ps-hosts/worker-hosts is the list of IP addresses or the GRPC URLs of the ps/worker of
+# ps_hosts/worker_hosts is the comma-separated list of IP addresses or the
+# GRPC URLs of the parameter servers / workers,
# e.g., "localhost:2222,localhost:2223"
#
-# --num-gpus <NUM_GPUS>:
+# --num_gpus <NUM_GPUS>:
# Specifies the number of gpus to use
#
-# NOTES:
+# NOTES:
# If you have the error "$'\r': command not found"
# Please run the command below to remove trailing '\r' character that causes the error:
# sed -i 's/\r$//' dist_mnist_test.sh
@@ -52,25 +56,33 @@ die() {
}
if [[ $# == "0" ]]; then
- die "Usage: $0 [--ps-hosts <PS_HOSTS>] [--worker-hosts <WORKER_HOSTS>] "\
-"[--num-gpus <NUM_GPUS>] [--sync-replicas]"
+  die "Usage: $0 [--existing_servers (True|False)] [--ps_hosts <PS_HOSTS>] "\
+"[--worker_hosts <WORKER_HOSTS>] [--num_gpus <NUM_GPUS>] [--sync_replicas]"
fi
# Process additional input arguments
SYNC_REPLICAS=0
+N_GPUS=0
+EXISTING_SERVERS=False
while true; do
- if [[ "$1" == "--ps-hosts" ]]; then
+ if [[ "$1" == "--ps_hosts" ]]; then
PS_HOSTS=$2
- elif [[ "$1" == "--worker-hosts" ]]; then
+ elif [[ "$1" == "--worker_hosts" ]]; then
WORKER_HOSTS=$2
- elif [[ "$1" == "--num-gpus" ]]; then
+ elif [[ "$1" == "--existing_servers" ]]; then
+ EXISTING_SERVERS=$2
+ if [[ "${EXISTING_SERVERS}" != "True" ]] && \
+ [[ "${EXISTING_SERVERS}" != "False" ]]; then
+ die "Invalid value for --existing_servers: should be (True|False)"
+ fi
+ elif [[ "$1" == "--num_gpus" ]]; then
N_GPUS=$2
- elif [[ "$1" == "--sync-replicas" ]]; then
+ elif [[ "$1" == "--sync_replicas" ]]; then
SYNC_REPLICAS="1"
- die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\
+ die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\
"supported by this test yet."
- # TODO(cais): Remove error message once sync-replicas is fully supported
+ # TODO(cais): Remove error message once sync_replicas is fully supported.
fi
shift 2
@@ -86,6 +98,7 @@ else
SYNC_REPLICAS_FLAG="False"
fi
+echo "EXISTING_SERVERS = ${EXISTING_SERVERS}"
echo "PS_HOSTS = ${PS_HOSTS}"
echo "WORKER_HOSTS = ${WORKER_HOSTS}"
echo "NUM_GPUS = ${N_GPUS}"
@@ -105,6 +118,7 @@ PS_LOG_PREFIX="/tmp/ps"
# First, download the data from a single process, to avoid race-condition
# during data downloading
+# Pre-download data files.
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
--ps_hosts="${PS_HOSTS}" \
--worker_hosts="${WORKER_HOSTS}" \
@@ -123,25 +137,30 @@ PS_ARRAY=$(echo ${PS_HOSTS} | awk -F "," '{for(i=1;i<=NF;i++){printf $i" "}}')
# Run a number of ps in parallel. In general, we only set 1 ps.
echo "${N_PS} ps process(es) running in parallel..."
-IDX=0
-PS=($PS_HOSTS)
-while true; do
- timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
- --ps_hosts="${PS_HOSTS}" \
- --worker_hosts="${WORKER_HOSTS}" \
- --job_name="ps" \
- --task_index=${IDX} \
- --num_gpus=${N_GPUS} \
- --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${PS_LOG_PREFIX}${IDX}.log" &
- echo "PS ${IDX}: "
- echo " PS HOST: ${PS_ARRAY[IDX]}"
- echo " log file: ${PS_LOG_PREFIX}${IDX}.log"
-
- ((IDX++))
- if [[ "${IDX}" == "${N_PS}" ]]; then
- break
- fi
-done
+if [[ ${EXISTING_SERVERS} == "False" ]]; then
+ # Create parameter servers.
+ IDX=0
+ PS=($PS_HOSTS)
+ while true; do
+ python "${MNIST_REPLICA}" \
+ --existing_servers="${EXISTING_SERVERS}" \
+ --ps_hosts="${PS_HOSTS}" \
+ --worker_hosts="${WORKER_HOSTS}" \
+ --job_name="ps" \
+ --task_index=${IDX} \
+ --num_gpus=${N_GPUS} \
+ --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${PS_LOG_PREFIX}${IDX}.log" &
+ echo "PS ${IDX}: "
+ echo " PS HOST: ${PS_ARRAY[IDX]}"
+ echo " log file: ${PS_LOG_PREFIX}${IDX}.log"
+
+ ((IDX++))
+ if [[ "${IDX}" == "${N_PS}" ]]; then
+ break
+ fi
+ done
+fi
# Get N_WORKERS by WORKER_HOSTS
@@ -155,12 +174,14 @@ INDICES=""
IDX=0
while true; do
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
+ --existing_servers="${EXISTING_SERVERS}" \
--ps_hosts="${PS_HOSTS}" \
--worker_hosts="${WORKER_HOSTS}" \
--job_name="worker" \
--task_index=${IDX} \
--num_gpus=${N_GPUS} \
- --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${WKR_LOG_PREFIX}${IDX}.log" &
+ --train_steps=500 \
+ --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${WKR_LOG_PREFIX}${IDX}.log" &
echo "Worker ${IDX}: "
echo " WORKER HOST: ${WORKER_ARRAY[IDX]}"
echo " log file: ${WKR_LOG_PREFIX}${IDX}.log"
@@ -171,9 +192,8 @@ while true; do
if [[ "${IDX}" == "${N_WORKERS}" ]]; then
break
fi
-done
-
+done
# Poll until all final validation cross entropy values become available or
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh
index 1d60aa518f..080ce1df5f 100755
--- a/tensorflow/tools/dist_test/scripts/dist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_test.sh
@@ -25,25 +25,25 @@
# TensorFlow ops.
#
# Usage:
-# dist_test.sh [--setup-cluster-only]
-# [--model-name (MNIST | CENSUS_WIDENDEEP)]
-# [--num-workers <NUM_WORKERS>]
-# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
-# [--sync-replicas]
+# dist_test.sh [--setup_cluster_only]
+# [--model_name (MNIST | CENSUS_WIDENDEEP)]
+# [--num_workers <NUM_WORKERS>]
+# [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
+# [--sync_replicas]
#
-# --setup-cluster-only:
+# --setup_cluster_only:
# Lets the script only set up the k8s container network
#
-# --model-name
+# --model_name
# Name of the model to test. Default is MNIST.
#
# --num_workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
-# --num-parameter-server <NUM_PARAMETER_SERVERS>:
+# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
-# --sync-replicas
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
@@ -72,15 +72,15 @@ SYNC_REPLICAS=0
SETUP_CLUSTER_ONLY=0
while true; do
- if [[ "$1" == "--model-name" ]]; then
+ if [[ "$1" == "--model_name" ]]; then
MODEL_NAME=$2
- elif [[ "$1" == "--num-workers" ]]; then
+ elif [[ "$1" == "--num_workers" ]]; then
NUM_WORKERS=$2
- elif [[ "$1" == "--num-parameter-servers" ]]; then
+ elif [[ "$1" == "--num_parameter_servers" ]]; then
NUM_PARAMETER_SERVERS=$2
- elif [[ "$1" == "--sync-replicas" ]]; then
+ elif [[ "$1" == "--sync_replicas" ]]; then
SYNC_REPLICAS=1
- elif [[ "$1" == "--setup-cluster-only" ]]; then
+ elif [[ "$1" == "--setup_cluster_only" ]]; then
SETUP_CLUSTER_ONLY=1
fi
shift
@@ -132,17 +132,32 @@ else
tee "${TMP}" || \
die "Creation of TensorFlow k8s cluster FAILED"
- GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-workers: .*" | \
- sed -e 's/GRPC URLs of tf-workers://g')
+ GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-worker instances: .*" | \
+ sed -e 's/GRPC URLs of tf-worker instances://g')
+
+ GRPC_PS_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-ps instances: .*" | \
+ sed -e 's/GRPC URLs of tf-ps instances://g')
if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then
die "FAILED to determine GRPC server URLs of all workers"
fi
+ if [[ $(echo ${GRPC_PS_URLS} | wc -w) != ${NUM_PARAMETER_SERVERS} ]]; then
+ die "FAILED to determine GRPC server URLs of all parameter servers"
+ fi
+
+ WORKER_HOSTS=$(echo "${GRPC_SERVER_URLS}" | sed -e 's/^[[:space:]]*//' | \
+ sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
+ PS_HOSTS=$(echo "${GRPC_PS_URLS}" | sed -e 's/^[[:space:]]*//' | \
+ sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
+
+ echo "WORKER_HOSTS = ${WORKER_HOSTS}"
+ echo "PS_HOSTS = ${PS_HOSTS}"
+
rm -f ${TMP}
if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then
echo "Skipping testing of distributed runtime due to "\
-"option flag --setup-cluster-only"
+"option flag --setup_cluster_only"
exit 0
fi
fi
@@ -158,17 +173,21 @@ test_MNIST() {
return 1
fi
- echo "Performing distributed MNIST training through grpc sessions @ "\
+ echo "Performing distributed MNIST training through worker grpc sessions @ "\
"${GRPC_SERVER_URLS}..."
+ echo "and ps grpc sessions @ ${GRPC_PS_URLS}"
+
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
- SYNC_REPLICAS_FLAG="--sync-replicas"
+ SYNC_REPLICAS_FLAG="--sync_replicas"
fi
- "${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \
- --num-workers "${NUM_WORKERS}" \
- --num-parameter-servers "${NUM_PARAMETER_SERVERS}" \
+ "${MNIST_DIST_TEST_BIN}" \
+ --existing_servers True \
+ --ps_hosts "${PS_HOSTS}" \
+ --worker_hosts "${WORKER_HOSTS}" \
+ --num_gpus 0 \
${SYNC_REPLICAS_FLAG}
if [[ $? == "0" ]]; then
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
index 3a427a1d4e..854c6b832a 100755
--- a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
@@ -136,6 +136,19 @@ spec:
selector:
tf-ps: "{param_server_id}"
""")
+PARAM_LB_SVC = ("""apiVersion: v1
+kind: Service
+metadata:
+ name: tf-ps{param_server_id}
+ labels:
+ tf-ps: "{param_server_id}"
+spec:
+ type: LoadBalancer
+ ports:
+ - port: {port}
+ selector:
+ tf-ps: "{param_server_id}"
+""")
def main():
@@ -218,8 +231,10 @@ def GenerateConfig(num_workers,
num_param_servers,
port))
config += '---\n'
- config += PARAM_SERVER_SVC.format(port=port,
- param_server_id=param_server)
+ if request_load_balancer:
+ config += PARAM_LB_SVC.format(port=port, param_server_id=param_server)
+ else:
+ config += PARAM_SERVER_SVC.format(port=port, param_server_id=param_server)
config += '---\n'
return config
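With request_load_balancer set, each parameter-server service is now rendered from PARAM_LB_SVC instead of PARAM_SERVER_SVC. For param_server_id=0 and port=2222 (illustrative values), the emitted YAML would be:

    apiVersion: v1
    kind: Service
    metadata:
      name: tf-ps0
      labels:
        tf-ps: "0"
    spec:
      type: LoadBalancer
      ports:
      - port: 2222
      selector:
        tf-ps: "0"

This mirrors the existing worker LoadBalancer services and is what allows dist_test.sh to collect external GRPC URLs for the ps instances as well.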