aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/dist_test/scripts/dist_mnist_test.sh')
-rwxr-xr-xtensorflow/tools/dist_test/scripts/dist_mnist_test.sh96
1 files changed, 58 insertions, 38 deletions
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
index d95f524486..4f2cab22d9 100755
--- a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
@@ -19,24 +19,28 @@
# grpc pods and service set up.
#
# Usage:
-# dist_mnist_test.sh [--ps-hosts <PS_HOSTS>]
-# [--worker-hosts <WORKER_HOSTS>]
-# [--num-gpus <NUM_GPUS>]
-# [--sync-replicas]
+# dist_mnist_test.sh [--existing_servers (True|False)]
+# [--ps_hosts <PS_HOSTS>]
+# [--worker_hosts <WORKER_HOSTS>]
+# [--num_gpus <NUM_GPUS>]
+# [--sync_replicas]
#
-# --sync-replicas
+# --existing_servers
+# Use TensorFlow GRPC servers that are already created and running.
+#
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
-# ps-hosts/worker-hosts is the list of IP addresses or the GRPC URLs of the ps/worker of
+# ps_hosts/worker_hosts is the list of IP addresses or the GRPC URLs of the ps/worker of
# the worker sessions, separated with ","
# e.g., "localhost:2222,localhost:2223"
#
-# --num-gpus <NUM_GPUS>:
+# --num_gpus <NUM_GPUS>:
# Specifies the number of gpus to use
#
-# NOTES:
+# NOTES:
# If you have the error "$'\r': command not found"
# Please run the command below to remove trailing '\r' character that causes the error:
# sed -i 's/\r$//' dist_mnist_test.sh
@@ -52,25 +56,33 @@ die() {
}
if [[ $# == "0" ]]; then
- die "Usage: $0 [--ps-hosts <PS_HOSTS>] [--worker-hosts <WORKER_HOSTS>] "\
-"[--num-gpus <NUM_GPUS>] [--sync-replicas]"
+ die "Usage: $0 [--existing_servers (True|False)] [--ps_hosts <PS_HOSTS>] "\
+"[--worker_hosts <WORKER_HOSTS>] [--num_gpus <NUM_GPUS>] [--sync_replicas]"
fi
# Process additional input arguments
SYNC_REPLICAS=0
+N_GPUS=0
+EXISTING_SERVERS=False
while true; do
- if [[ "$1" == "--ps-hosts" ]]; then
+ if [[ "$1" == "--ps_hosts" ]]; then
PS_HOSTS=$2
- elif [[ "$1" == "--worker-hosts" ]]; then
+ elif [[ "$1" == "--worker_hosts" ]]; then
WORKER_HOSTS=$2
- elif [[ "$1" == "--num-gpus" ]]; then
+ elif [[ "$1" == "--existing_servers" ]]; then
+ EXISTING_SERVERS=$2
+ if [[ "${EXISTING_SERVERS}" != "True" ]] && \
+ [[ "${EXISTING_SERVERS}" != "False" ]]; then
+ die "Invalid value for --existing_servers: should be (True|False)"
+ fi
+ elif [[ "$1" == "--num_gpus" ]]; then
N_GPUS=$2
- elif [[ "$1" == "--sync-replicas" ]]; then
+ elif [[ "$1" == "--sync_replicas" ]]; then
SYNC_REPLICAS="1"
- die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\
+ die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\
"supported by this test yet."
- # TODO(cais): Remove error message once sync-replicas is fully supported
+ # TODO(cais): Remove error message once sync_replicas is fully supported.
fi
shift 2
@@ -86,6 +98,7 @@ else
SYNC_REPLICAS_FLAG="False"
fi
+echo "EXISTING_SERVERS = ${EXISTING_SERVERS}"
echo "PS_HOSTS = ${PS_HOSTS}"
echo "WORKER_HOSTS = ${WORKER_HOSTS}"
echo "NUM_GPUS = ${N_GPUS}"
@@ -105,6 +118,7 @@ PS_LOG_PREFIX="/tmp/ps"
# First, download the data from a single process, to avoid race-condition
# during data downloading
+# Pre-download data files.
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
--ps_hosts="${PS_HOSTS}" \
--worker_hosts="${WORKER_HOSTS}" \
@@ -123,25 +137,30 @@ PS_ARRAY=$(echo ${PS_HOSTS} | awk -F "," '{for(i=1;i<=NF;i++){printf $i" "}}')
# Run a number of ps in parallel. In general, we only set 1 ps.
echo "${N_PS} ps process(es) running in parallel..."
-IDX=0
-PS=($PS_HOSTS)
-while true; do
- timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
- --ps_hosts="${PS_HOSTS}" \
- --worker_hosts="${WORKER_HOSTS}" \
- --job_name="ps" \
- --task_index=${IDX} \
- --num_gpus=${N_GPUS} \
- --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${PS_LOG_PREFIX}${IDX}.log" &
- echo "PS ${IDX}: "
- echo " PS HOST: ${PS_ARRAY[IDX]}"
- echo " log file: ${PS_LOG_PREFIX}${IDX}.log"
-
- ((IDX++))
- if [[ "${IDX}" == "${N_PS}" ]]; then
- break
- fi
-done
+if [[ ${EXISTING_SERVERS} == "False" ]]; then
+ echo "Hello"
+ # Create parameter servers.
+ IDX=0
+ PS=($PS_HOSTS)
+ while true; do
+ python "${MNIST_REPLICA}" \
+ --existing_servers="${EXISTING_SERVERS}" \
+ --ps_hosts="${PS_HOSTS}" \
+ --worker_hosts="${WORKER_HOSTS}" \
+ --job_name="ps" \
+ --task_index=${IDX} \
+ --num_gpus=${N_GPUS} \
+ --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${PS_LOG_PREFIX}${IDX}.log" &
+ echo "PS ${IDX}: "
+ echo " PS HOST: ${PS_ARRAY[IDX]}"
+ echo " log file: ${PS_LOG_PREFIX}${IDX}.log"
+
+ ((IDX++))
+ if [[ "${IDX}" == "${N_PS}" ]]; then
+ break
+ fi
+ done
+fi
# Get N_WORKERS by WORKER_HOSTS
@@ -155,12 +174,14 @@ INDICES=""
IDX=0
while true; do
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
+ --existing_servers="${EXISTING_SERVERS}" \
--ps_hosts="${PS_HOSTS}" \
--worker_hosts="${WORKER_HOSTS}" \
--job_name="worker" \
--task_index=${IDX} \
--num_gpus=${N_GPUS} \
- --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${WKR_LOG_PREFIX}${IDX}.log" &
+ --train_steps=500 \
+ --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${WKR_LOG_PREFIX}${IDX}.log" &
echo "Worker ${IDX}: "
echo " WORKER HOST: ${WORKER_ARRAY[IDX]}"
echo " log file: ${WKR_LOG_PREFIX}${IDX}.log"
@@ -171,9 +192,8 @@ while true; do
if [[ "${IDX}" == "${N_WORKERS}" ]]; then
break
fi
-done
-
+done
# Poll until all final validation cross entropy values become available or