aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-10-27 12:52:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-27 14:07:44 -0700
commit21a7ae05e04f4f060938db08015cb47896970dd1 (patch)
tree3728f4d39a8a3c7150b19d85f8c3a418492efe6b /tensorflow/tools/dist_test
parente692686087722a54f4b48af94cd73a7d57eb56bc (diff)
Upgrade SyncReplicasOptimizer to V2 in dist_test
Also removing some obsolete and unused files in dist_test/local. Change: 137436178
Diffstat (limited to 'tensorflow/tools/dist_test')
-rw-r--r--tensorflow/tools/dist_test/Dockerfile1
-rwxr-xr-xtensorflow/tools/dist_test/local/start_local_k8s_service.sh118
-rwxr-xr-xtensorflow/tools/dist_test/local/start_tf_cluster_container.sh91
-rwxr-xr-xtensorflow/tools/dist_test/local/test_local_tf_cluster.sh142
-rwxr-xr-xtensorflow/tools/dist_test/local_test.sh4
-rw-r--r--tensorflow/tools/dist_test/python/mnist_replica.py47
-rwxr-xr-xtensorflow/tools/dist_test/remote_test.sh5
-rwxr-xr-xtensorflow/tools/dist_test/scripts/create_tf_cluster.sh3
-rwxr-xr-xtensorflow/tools/dist_test/scripts/dist_mnist_test.sh21
-rwxr-xr-xtensorflow/tools/dist_test/scripts/dist_test.sh18
10 files changed, 70 insertions, 380 deletions
diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile
index 9888cfd14f..65d7e1717e 100644
--- a/tensorflow/tools/dist_test/Dockerfile
+++ b/tensorflow/tools/dist_test/Dockerfile
@@ -24,6 +24,7 @@ MAINTAINER Shanqing Cai <cais@google.com>
RUN apt-get update
RUN apt-get install -y --no-install-recommends \
+ curl \
python \
python-numpy \
python-pip \
diff --git a/tensorflow/tools/dist_test/local/start_local_k8s_service.sh b/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
deleted file mode 100755
index 6d12ed7b3c..0000000000
--- a/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
+++ /dev/null
@@ -1,118 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Start a Kubernetes (k8s) cluster on the local machine.
-#
-# This script assumes that git, docker, and golang are installed and on
-# the path. It will attempt to install the version of etcd recommended by the
-# kubernetes source.
-#
-# Usage: start_local_k8s_service.sh
-#
-# This script obeys the following environment variables:
-# TF_DIST_K8S_SRC_DIR: Overrides the default directory for k8s source code.
-# TF_DIST_K8S_SRC_BRANCH: Overrides the default branch to run the local k8s
-# cluster with.
-
-
-# Configurations
-K8S_SRC_REPO=https://github.com/kubernetes/kubernetes.git
-K8S_SRC_DIR=${TF_DIST_K8S_SRC_DIR:-/local/kubernetes}
-K8S_SRC_BRANCH=${TF_DIST_K8S_SRC_BRANCH:-release-1.2}
-
-# Helper functions
-die() {
- echo $@
- exit 1
-}
-
-# Start docker service. Try multiple times if necessary.
-COUNTER=0
-while true; do
- ((COUNTER++))
- service docker start
- sleep 1
-
- service docker status
- if [[ $? == "0" ]]; then
- echo "Docker service started successfully."
- break;
- else
- echo "Docker service failed to start"
-
- # 23 is the exit code to signal failure to start docker service in the dind
- # container.
- exit 23
-
- fi
-done
-
-# Wait for docker0 net interface to appear
-echo "Waiting for docker0 network interface to appear..."
-while true; do
- if [[ -z $(netstat -i | grep "^docker0") ]]; then
- sleep 1
- else
- break
- fi
-done
-echo "docker0 interface has appeared."
-
-# Set docker0 to promiscuous mode
-ip link set docker0 promisc on || \
- die "FAILED to set docker0 to promiscuous"
-echo "Turned promisc on for docker0"
-
-# Check promiscuous mode of docker0
-netstat -i
-
-umask 000
-if [[ ! -d "${K8S_SRC_DIR}/.git" ]]; then
- mkdir -p ${K8S_SRC_DIR}
- git clone ${K8S_SRC_REPO} ${K8S_SRC_DIR} || \
- die "FAILED to clone k8s source from GitHub from: ${K8S_SRC_REPO}"
-fi
-
-pushd ${K8S_SRC_DIR}
-git checkout ${K8S_SRC_BRANCH} || \
- die "FAILED to checkout k8s source branch: ${K8S_SRC_BRANCH}"
-git pull origin ${K8S_SRC_BRANCH} || \
- die "FAILED to pull from k8s source branch: ${K8S_SRC_BRANCH}"
-
-# Create kubectl binary
-
-# Install etcd
-hack/install-etcd.sh
-
-export PATH=$(pwd)/third_party/etcd:${PATH}
-
-# Setup golang
-export PATH=/usr/local/go/bin:${PATH}
-
-echo "etcd path: $(which etcd)"
-echo "go path: $(which go)"
-
-# Create shortcut to kubectl
-echo '#!/bin/bash' > /usr/local/bin/kubectl
-echo "$(pwd)/cluster/kubectl.sh \\" >> /usr/local/bin/kubectl
-echo ' $@' >> /usr/local/bin/kubectl
-chmod +x /usr/local/bin/kubectl
-
-# Bring up local cluster
-export KUBE_ENABLE_CLUSTER_DNS=true
-hack/local-up-cluster.sh
-
-popd
diff --git a/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh b/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
deleted file mode 100755
index 49578d3051..0000000000
--- a/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
+++ /dev/null
@@ -1,91 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Starts a docker-in-docker (dind) container that is capable of running docker
-# service and Kubernetes (k8s) cluster inside.
-#
-# Usage: start_tf_cluster_container.sh <local_k8s_dir> <docker_img_name>
-#
-# local_k8s_dir: Kubernetes (k8s) source directory on the host
-# docker_img_name: Name of the docker image to start
-#
-# In addition, this script obeys the following environment variables:
-# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
-# TensorFlow (GRPC) servers with
-
-# Parse input arguments
-if [[ $# != "2" ]]; then
- echo "Usage: $0 <host_k8s_dir> <docker_img_name>"
- exit 1
-fi
-
-HOST_K8S_DIR=$1
-DOCKER_IMG_NAME=$2
-
-# Helper functions
-die() {
- echo $@
- exit 1
-}
-
-# Maximum number of tries to start the docker container with docker running
-# inside
-MAX_ATTEMPTS=100
-
-# Map environment variables into the docker-in-docker (dind) container
-DOCKER_ENV=""
-if [[ ! -z "${TF_DIST_SERVER_DOCKER_IMAGE}" ]]; then
- DOCKER_ENV="-e TF_DIST_SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE}"
-fi
-
-# Verify that the promisc (promiscuous mode) flag is set on docker0 network
-# interface
-if [[ -z $(netstat -i | grep "^docker0" | awk '{print $NF}' | grep -o P) ]];
-then
- die "FAILED: Cannot proceed with dind k8s container creation because "\
-"network interface 'docker0' is not set to promisc on the host."
-fi
-
-# Create cache for k8s source
-if [[ ! -d ${HOST_K8S_DIR} ]]; then
- umask 000
- mkdir -p ${HOST_K8S_DIR} || die "FAILED to create directory for k8s source"
-fi
-
-# Attempt to start docker service in docker container.
-# Try multiple times if necessary.
-COUNTER=1
-while true; do
- ((COUNTER++))
- docker run --rm --net=host --privileged ${DOCKER_ENV} \
- -v ${HOST_K8S_DIR}:/local/kubernetes \
- ${DOCKER_IMG_NAME} \
- /var/tf-k8s/local/start_local_k8s_service.sh
-
- if [[ $? == "23" ]]; then
- if [[ "${COUNTER}" -ge "${MAX_ATTEMPTS}" ]]; then
- echo "Reached maximum number of attempts (${MAX_ATTEMPTS}) "\
-"while attempting to start docker-in-docker for local k8s TensorFlow cluster"
- exit 1
- fi
-
- echo "Docker service failed to start."
- echo "Will make another attempt (#${COUNTER}) to start it..."
- sleep 1
- else
- break
- fi
-done
diff --git a/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh b/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
deleted file mode 100755
index 402f7b5f55..0000000000
--- a/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
+++ /dev/null
@@ -1,142 +0,0 @@
-#!/usr/bin/env bash
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-#
-# Launch a Kubernetes (k8s) TensorFlow cluster on the local machine and run
-# the distributed test suite.
-#
-# This script assumes that a TensorFlow cluster is already running on the
-# local machine and can be controlled by the "kubectl" binary.
-#
-# Usage: test_local_tf_cluster.sh <NUM_WORKERS> <NUM_PARAMETER_SERVERS>
-# [--model-name <MODEL_NAME>]
-# [--sync-replicas]
-#
-# --sync-replicas
-# Use the synchronized-replica mode. The parameter updates from the replicas
-# (workers) will be aggregated before applied, which avoids stale parameter
-# updates.
-
-export GCLOUD_BIN=/usr/local/bin/gcloud
-export TF_DIST_LOCAL_CLUSTER=1
-
-# Parse input arguments
-if [[ $# == 0 ]] || [[ $# == 1 ]]; then
- echo "Usage: $0 <NUM_WORKERS> <NUM_PARAMETER_SERVERS>"
- exit 1
-fi
-
-NUM_WORKERS=$1
-NUM_PARAMETER_SERVERS=$2
-shift
-shift
-
-# Process optional command-line flags
-MODEL_NAME=""
-MODEL_NAME_FLAG=""
-SYNC_REPLICAS_FLAG=""
-while true; do
- if [[ "$1" == "--model-name" ]]; then
- MODEL_NAME="$2"
- MODEL_NAME_FLAG="--model-name ${MODEL_NAME}"
- elif [[ "$1" == "--sync-replicas" ]]; then
- SYNC_REPLICAS_FLAG="--sync-replicas"
- fi
- shift
-
- if [[ -z "$1" ]]; then
- break
- fi
-done
-
-echo "NUM_WORKERS: ${NUM_WORKERS}"
-echo "NUM_PARAMETER_SERVERS: ${NUM_PARAMETER_SERVERS}"
-echo "MODEL_NAME: \"${MODEL_NAME}\""
-echo "MODEL_NAME_FLAG: \"${MODEL_NAME_FLAG}\""
-echo "SYNC_REPLICAS_FLAG: \"${SYNC_REPLICAS_FLAG}\""
-
-# Get current script directory
-DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
-
-# Get utility functions
-source "${DIR}/../scripts/utils.sh"
-
-# Wait for the kube-system pods to be running
-KUBECTL_BIN=$(which kubectl)
-if [[ -z ${KUBECTL_BIN} ]]; then
- die "FAILED to find path to kubectl"
-fi
-
-echo "Waiting for kube-system pods to be all running..."
-echo ""
-
-MAX_ATTEMPTS=360
-COUNTER=0
-while true; do
- sleep 1
- ((COUNTER++))
- if [[ "${COUNTER}" -gt "${MAX_ATTEMPTS}" ]]; then
- die "Reached maximum polling attempts while waiting for all pods in "\
-"kube-system to be running in local k8s TensorFlow cluster"
- fi
-
- if [[ $(are_all_pods_running "${KUBECTL_BIN}" "kube-system") == "1" ]]; then
- break
- fi
-done
-
-# Create the local k8s tf cluster
-${DIR}/../scripts/create_tf_cluster.sh \
- ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} | \
- tee /tmp/tf_cluster.log || \
- die "FAILED to create local tf cluster"
-
-DOCKER_CONTAINER_ID=$(cat /tmp/tf_cluster.log | \
- grep "Docker container ID" |
- awk '{print $NF}')
-if [[ -z "${DOCKER_CONTAINER_ID}" ]]; then
- die "FAILED to determine worker0 Docker container ID"
-fi
-
-WORKER_URLS=""
-IDX=0
-while true; do
- WORKER_URLS="${WORKER_URLS},grpc://tf-worker${IDX}:2222"
-
- ((IDX++))
- if [[ ${IDX} == ${NUM_WORKERS} ]]; then
- break
- fi
-done
-
-echo "Worker URLs: ${WORKER_URLS}"
-
-export TF_DIST_GRPC_SERVER_URLS="${WORKER_URLS}"
-GRPC_ENV="TF_DIST_GRPC_SERVER_URLS=${TF_DIST_GRPC_SERVER_URLS}"
-
-# Command to launch clients from worker0
-CMD="${GRPC_ENV} /var/tf-k8s/scripts/dist_test.sh "\
-"--num-workers ${NUM_WORKERS} "\
-"--num-parameter-servers ${NUM_PARAMETER_SERVERS} "\
-"${MODEL_NAME_FLAG} ${SYNC_REPLICAS_FLAG}"
-
-# Launch clients from worker0
-docker exec ${DOCKER_CONTAINER_ID} /bin/bash -c "${CMD}"
-
-if [[ $? != "0" ]]; then
- die "Test of local k8s TensorFlow cluster FAILED"
-else
- echo "Test of local k8s TensorFlow cluster PASSED"
-fi
diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh
index e46e60dd81..f9f37ff0e1 100755
--- a/tensorflow/tools/dist_test/local_test.sh
+++ b/tensorflow/tools/dist_test/local_test.sh
@@ -56,6 +56,10 @@
# In addition, this script obeys the following environment variables:
# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
+die() {
+ echo $@
+ exit 1
+}
# Configurations
DOCKER_IMG_NAME="tensorflow/tf-dist-test-local-cluster"
diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py
index 0f642d5e69..b57cbfc79c 100644
--- a/tensorflow/tools/dist_test/python/mnist_replica.py
+++ b/tensorflow/tools/dist_test/python/mnist_replica.py
@@ -177,28 +177,44 @@ def main(unused_argv):
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
- opt = tf.train.SyncReplicasOptimizer(
+ opt = tf.train.SyncReplicasOptimizerV2(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=num_workers,
- replica_id=FLAGS.task_index,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy, global_step=global_step)
- if FLAGS.sync_replicas and is_chief:
+ if FLAGS.sync_replicas:
+ local_init_op = opt.local_step_init_op
+ if is_chief:
+ local_init_op = opt.chief_init_op
+
+ ready_for_local_init_op = opt.ready_for_local_init_op
+
# Initial token and chief queue runners required by the sync_replicas mode
chief_queue_runner = opt.get_chief_queue_runner()
- init_tokens_op = opt.get_init_tokens_op()
+ sync_init_op = opt.get_init_tokens_op()
init_op = tf.initialize_all_variables()
train_dir = tempfile.mkdtemp()
- sv = tf.train.Supervisor(
- is_chief=is_chief,
- logdir=train_dir,
- init_op=init_op,
- recovery_wait_secs=1,
- global_step=global_step)
+
+ if FLAGS.sync_replicas:
+ sv = tf.train.Supervisor(
+ is_chief=is_chief,
+ logdir=train_dir,
+ init_op=init_op,
+ local_init_op=local_init_op,
+ ready_for_local_init_op=ready_for_local_init_op,
+ recovery_wait_secs=1,
+ global_step=global_step)
+ else:
+ sv = tf.train.Supervisor(
+ is_chief=is_chief,
+ logdir=train_dir,
+ init_op=init_op,
+ recovery_wait_secs=1,
+ global_step=global_step)
sess_config = tf.ConfigProto(
allow_soft_placement=True,
@@ -217,18 +233,17 @@ def main(unused_argv):
server_grpc_url = "grpc://" + worker_spec[FLAGS.task_index]
print("Using existing server at: %s" % server_grpc_url)
- sess = sv.prepare_or_wait_for_session(server_grpc_url, config=sess_config)
- else:
- sess = sv.prepare_or_wait_for_session(server.target,
+ sess = sv.prepare_or_wait_for_session(server_grpc_url,
config=sess_config)
+ else:
+ sess = sv.prepare_or_wait_for_session(server.target, config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.task_index)
if FLAGS.sync_replicas and is_chief:
- # Chief worker will start the chief queue runner and call the init op
- print("Starting chief queue runner and running init_tokens_op")
+ # Chief worker will start the chief queue runner and call the init op.
+ sess.run(sync_init_op)
sv.start_queue_runners(sess, [chief_queue_runner])
- sess.run(init_tokens_op)
# Perform training
time_begin = time.time()
diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh
index b1e6b1e71e..935535312d 100755
--- a/tensorflow/tools/dist_test/remote_test.sh
+++ b/tensorflow/tools/dist_test/remote_test.sh
@@ -66,6 +66,11 @@
# servers
# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
+die() {
+ echo $@
+ exit 1
+}
+
DOCKER_IMG_NAME="tensorflow/tf-dist-test-client"
# Get current script directory
diff --git a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
index 69c459ec8c..1da6a540f1 100755
--- a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
+++ b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
@@ -102,6 +102,9 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
# Activate gcloud service account
"${GCLOUD_BIN}" auth activate-service-account --key-file "${GCLOUD_KEY_FILE}"
+ # See: https://github.com/kubernetes/kubernetes/issues/30617
+ "${GCLOUD_BIN}" config set container/use_client_certificate True
+
# Set gcloud project
"${GCLOUD_BIN}" config set project "${GCLOUD_PROJECT}"
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
index 7ebe80db1b..ea4906588d 100755
--- a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
@@ -67,30 +67,37 @@ EXISTING_SERVERS=False
while true; do
if [[ "$1" == "--ps_hosts" ]]; then
- PS_HOSTS=$2
+ PS_HOSTS=$2
+ shift 2
elif [[ "$1" == "--worker_hosts" ]]; then
WORKER_HOSTS=$2
+ shift 2
elif [[ "$1" == "--existing_servers" ]]; then
EXISTING_SERVERS=$2
+ shift 2
if [[ "${EXISTING_SERVERS}" != "True" ]] && \
[[ "${EXISTING_SERVERS}" != "False" ]]; then
die "Invalid value for --existing_servers: should be (True|False)"
fi
elif [[ "$1" == "--num_gpus" ]]; then
N_GPUS=$2
+ shift 2
elif [[ "$1" == "--sync_replicas" ]]; then
SYNC_REPLICAS="1"
- die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\
-"supported by this test yet."
- # TODO(cais): Remove error message once sync_replicas is fully supported.
+ shift 1
fi
- shift 2
if [[ -z "$1" ]]; then
break
fi
done
+if [[ ${SYNC_REPLICAS} == "1" ]] && [[ EXISTING_SERVERS == "1" ]]; then
+ die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\
+"supported under the --existing_servers mode yet."
+ # TODO(cais): Remove error message once sync_replicas is fully supported.
+fi
+
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
SYNC_REPLICAS_FLAG="True"
@@ -150,7 +157,7 @@ if [[ ${EXISTING_SERVERS} == "False" ]]; then
--job_name="ps" \
--task_index=${IDX} \
--num_gpus=${N_GPUS} \
- --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${PS_LOG_PREFIX}${IDX}.log" &
+ --sync_replicas=${SYNC_REPLICAS_FLAG} 2>&1 | tee "${PS_LOG_PREFIX}${IDX}.log" &
echo "PS ${IDX}: "
echo " PS HOST: ${PS_ARRAY[IDX]}"
echo " log file: ${PS_LOG_PREFIX}${IDX}.log"
@@ -181,7 +188,7 @@ while true; do
--task_index=${IDX} \
--num_gpus=${N_GPUS} \
--train_steps=500 \
- --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${WKR_LOG_PREFIX}${IDX}.log" &
+ --sync_replicas=${SYNC_REPLICAS_FLAG} 2>&1 | tee "${WKR_LOG_PREFIX}${IDX}.log" &
echo "Worker ${IDX}: "
echo " WORKER HOST: ${WORKER_ARRAY[IDX]}"
echo " log file: ${WKR_LOG_PREFIX}${IDX}.log"
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh
index 080ce1df5f..5c107fb030 100755
--- a/tensorflow/tools/dist_test/scripts/dist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_test.sh
@@ -191,11 +191,12 @@ test_MNIST() {
${SYNC_REPLICAS_FLAG}
if [[ $? == "0" ]]; then
- echo "MNIST-replica test PASSED\n"
+ echo "MNIST-replica test PASSED"
else
- echo "MNIST-replica test FAILED\n"
+ echo "MNIST-replica test FAILED"
return 1
fi
+ echo ""
}
# Test routine for model "CENSUS_WIDENDEEP"
@@ -231,8 +232,9 @@ if [[ $(type -t "test_${MODEL_NAME}") != "function" ]]; then
fi
# Invoke test routine according to model name
-"test_${MODEL_NAME}" || \
- die "Test of distributed training of model ${MODEL_NAME} FAILED"
+"test_${MODEL_NAME}" && \
+ FAILED=0 || \
+ FAILED=1
# Tear down current k8s TensorFlow cluster
if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
@@ -242,5 +244,9 @@ if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
die "Cluster tear-down FAILED"
fi
-echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
-echo "" \ No newline at end of file
+if [[ "${FAILED}" == 1 ]]; then
+ die "Test of distributed training of model ${MODEL_NAME} FAILED"
+else
+ echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
+ echo ""
+fi