diff options
Diffstat (limited to 'tensorflow/tools/dist_test/scripts/dist_mnist_test.sh')
-rwxr-xr-x | tensorflow/tools/dist_test/scripts/dist_mnist_test.sh | 96 |
1 files changed, 58 insertions, 38 deletions
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh index d95f524486..4f2cab22d9 100755 --- a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh +++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh @@ -19,24 +19,28 @@ # grpc pods and service set up. # # Usage: -# dist_mnist_test.sh [--ps-hosts <PS_HOSTS>] -# [--worker-hosts <WORKER_HOSTS>] -# [--num-gpus <NUM_GPUS>] -# [--sync-replicas] +# dist_mnist_test.sh [--existing_servers (True|False)] +# [--ps_hosts <PS_HOSTS>] +# [--worker_hosts <WORKER_HOSTS>] +# [--num_gpus <NUM_GPUS>] +# [--sync_replicas] # -# --sync-replicas +# --existing_servers +# Use TensorFlow GRPC servers that are already created and running. +# +# --sync_replicas # Use the synchronized-replica mode. The parameter updates from the replicas # (workers) will be aggregated before applied, which avoids stale parameter # updates. # -# ps-hosts/worker-hosts is the list of IP addresses or the GRPC URLs of the ps/worker of +# ps_hosts/worker_hosts is the list of IP addresses or the GRPC URLs of the ps/worker of # the worker sessions, separated with "," # e.g., "localhost:2222,localhost:2223" # -# --num-gpus <NUM_GPUS>: +# --num_gpus <NUM_GPUS>: # Specifies the number of gpus to use # -# NOTES: +# NOTES: # If you have the error "$'\r': command not found" # Please run the command below to remove trailing '\r' character that causes the error: # sed -i 's/\r$//' dist_mnist_test.sh @@ -52,25 +56,33 @@ die() { } if [[ $# == "0" ]]; then - die "Usage: $0 [--ps-hosts <PS_HOSTS>] [--worker-hosts <WORKER_HOSTS>] "\ -"[--num-gpus <NUM_GPUS>] [--sync-replicas]" + die "Usage: $0 [--ps_hosts <PS_HOSTS>] [--worker_hosts <WORKER_HOSTS>] "\ +"[--num_gpus <NUM_GPUS>] [--sync_replicas]" fi # Process additional input arguments SYNC_REPLICAS=0 +N_GPUS=0 +EXISTING_SERVERS=False while true; do - if [[ "$1" == "--ps-hosts" ]]; then + if [[ "$1" == "--ps_hosts" ]]; then PS_HOSTS=$2 - elif [[ "$1" == "--worker-hosts" ]]; then + elif [[ "$1" == "--worker_hosts" ]]; then WORKER_HOSTS=$2 - elif [[ "$1" == "--num-gpus" ]]; then + elif [[ "$1" == "--existing_servers" ]]; then + EXISTING_SERVERS=$2 + if [[ "${EXISTING_SERVERS}" != "True" ]] && \ + [[ "${EXISTING_SERVERS}" != "False" ]]; then + die "Invalid value for --existing_servers: should be (True|False)" + fi + elif [[ "$1" == "--num_gpus" ]]; then N_GPUS=$2 - elif [[ "$1" == "--sync-replicas" ]]; then + elif [[ "$1" == "--sync_replicas" ]]; then SYNC_REPLICAS="1" - die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\ + die "ERROR: --sync_replicas (synchronized-replicas) mode is not fully "\ "supported by this test yet." - # TODO(cais): Remove error message once sync-replicas is fully supported + # TODO(cais): Remove error message once sync_replicas is fully supported. fi shift 2 @@ -86,6 +98,7 @@ else SYNC_REPLICAS_FLAG="False" fi +echo "EXISTING_SERVERS = ${EXISTING_SERVERS}" echo "PS_HOSTS = ${PS_HOSTS}" echo "WORKER_HOSTS = ${WORKER_HOSTS}" echo "NUM_GPUS = ${N_GPUS}" @@ -105,6 +118,7 @@ PS_LOG_PREFIX="/tmp/ps" # First, download the data from a single process, to avoid race-condition # during data downloading +# Pre-download data files. timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ --ps_hosts="${PS_HOSTS}" \ --worker_hosts="${WORKER_HOSTS}" \ @@ -123,25 +137,30 @@ PS_ARRAY=$(echo ${PS_HOSTS} | awk -F "," '{for(i=1;i<=NF;i++){printf $i" "}}') # Run a number of ps in parallel. In general, we only set 1 ps. echo "${N_PS} ps process(es) running in parallel..." -IDX=0 -PS=($PS_HOSTS) -while true; do - timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ - --ps_hosts="${PS_HOSTS}" \ - --worker_hosts="${WORKER_HOSTS}" \ - --job_name="ps" \ - --task_index=${IDX} \ - --num_gpus=${N_GPUS} \ - --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${PS_LOG_PREFIX}${IDX}.log" & - echo "PS ${IDX}: " - echo " PS HOST: ${PS_ARRAY[IDX]}" - echo " log file: ${PS_LOG_PREFIX}${IDX}.log" - - ((IDX++)) - if [[ "${IDX}" == "${N_PS}" ]]; then - break - fi -done +if [[ ${EXISTING_SERVERS} == "False" ]]; then + echo "Hello" + # Create parameter servers. + IDX=0 + PS=($PS_HOSTS) + while true; do + python "${MNIST_REPLICA}" \ + --existing_servers="${EXISTING_SERVERS}" \ + --ps_hosts="${PS_HOSTS}" \ + --worker_hosts="${WORKER_HOSTS}" \ + --job_name="ps" \ + --task_index=${IDX} \ + --num_gpus=${N_GPUS} \ + --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${PS_LOG_PREFIX}${IDX}.log" & + echo "PS ${IDX}: " + echo " PS HOST: ${PS_ARRAY[IDX]}" + echo " log file: ${PS_LOG_PREFIX}${IDX}.log" + + ((IDX++)) + if [[ "${IDX}" == "${N_PS}" ]]; then + break + fi + done +fi # Get N_WORKERS by WORKER_HOSTS @@ -155,12 +174,14 @@ INDICES="" IDX=0 while true; do timeout ${TIMEOUT} python "${MNIST_REPLICA}" \ + --existing_servers="${EXISTING_SERVERS}" \ --ps_hosts="${PS_HOSTS}" \ --worker_hosts="${WORKER_HOSTS}" \ --job_name="worker" \ --task_index=${IDX} \ --num_gpus=${N_GPUS} \ - --sync_replicas=${SYNC_REPLICAS_FLAG} \ | tee "${WKR_LOG_PREFIX}${IDX}.log" & + --train_steps=500 \ + --sync_replicas=${SYNC_REPLICAS_FLAG} | tee "${WKR_LOG_PREFIX}${IDX}.log" & echo "Worker ${IDX}: " echo " WORKER HOST: ${WORKER_ARRAY[IDX]}" echo " log file: ${WKR_LOG_PREFIX}${IDX}.log" @@ -171,9 +192,8 @@ while true; do if [[ "${IDX}" == "${N_WORKERS}" ]]; then break fi -done - +done # Poll until all final validation cross entropy values become available or |