diff options
Diffstat (limited to 'tensorflow/tools/dist_test/scripts/dist_test.sh')
-rwxr-xr-x | tensorflow/tools/dist_test/scripts/dist_test.sh | 63 |
1 files changed, 41 insertions, 22 deletions
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh index 1d60aa518f..080ce1df5f 100755 --- a/tensorflow/tools/dist_test/scripts/dist_test.sh +++ b/tensorflow/tools/dist_test/scripts/dist_test.sh @@ -25,25 +25,25 @@ # TensorFlow ops. # # Usage: -# dist_test.sh [--setup-cluster-only] -# [--model-name (MNIST | CENSUS_WIDENDEEP)] -# [--num-workers <NUM_WORKERS>] -# [--num-parameter-servers <NUM_PARAMETER_SERVERS>] -# [--sync-replicas] +# dist_test.sh [--setup_cluster_only] +# [--model_name (MNIST | CENSUS_WIDENDEEP)] +# [--num_workers <NUM_WORKERS>] +# [--num_parameter_servers <NUM_PARAMETER_SERVERS>] +# [--sync_replicas] # -# --setup-cluster-only: +# --setup_cluster_only: # Lets the script only set up the k8s container network # -# --model-name +# --model_name # Name of the model to test. Default is MNIST. # # --num-workers <NUM_WORKERS>: # Specifies the number of worker pods to start # -# --num-parameter-server <NUM_PARAMETER_SERVERS>: +# --num_parameter_servers <NUM_PARAMETER_SERVERS>: # Specifies the number of parameter servers to start # -# --sync-replicas +# --sync_replicas # Use the synchronized-replica mode. The parameter updates from the replicas # (workers) will be aggregated before applied, which avoids stale parameter # updates. @@ -72,15 +72,15 @@ SYNC_REPLICAS=0 SETUP_CLUSTER_ONLY=0 while true; do - if [[ "$1" == "--model-name" ]]; then + if [[ "$1" == "--model_name" ]]; then MODEL_NAME=$2 - elif [[ "$1" == "--num-workers" ]]; then + elif [[ "$1" == "--num_workers" ]]; then NUM_WORKERS=$2 - elif [[ "$1" == "--num-parameter-servers" ]]; then + elif [[ "$1" == "--num_parameter_servers" ]]; then NUM_PARAMETER_SERVERS=$2 - elif [[ "$1" == "--sync-replicas" ]]; then + elif [[ "$1" == "--sync_replicas" ]]; then SYNC_REPLICAS=1 - elif [[ "$1" == "--setup-cluster-only" ]]; then + elif [[ "$1" == "--setup_cluster_only" ]]; then SETUP_CLUSTER_ONLY=1 fi shift @@ -132,17 +132,32 @@ else tee "${TMP}" || \ die "Creation of TensorFlow k8s cluster FAILED" - GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-workers: .*" | \ - sed -e 's/GRPC URLs of tf-workers://g') + GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-worker instances: .*" | \ + sed -e 's/GRPC URLs of tf-worker instances://g') + + GRPC_PS_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-ps instances: .*" | \ + sed -e 's/GRPC URLs of tf-ps instances://g') if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then die "FAILED to determine GRPC server URLs of all workers" fi + if [[ $(echo ${GRPC_PS_URLS} | wc -w) != ${NUM_PARAMETER_SERVERS} ]]; then + die "FAILED to determine GRPC server URLs of all parameter servers" + fi + + WORKER_HOSTS=$(echo "${GRPC_SERVER_URLS}" | sed -e 's/^[[:space:]]*//' | \ + sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g') + PS_HOSTS=$(echo "${GRPC_PS_URLS}" | sed -e 's/^[[:space:]]*//' | \ + sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g') + + echo "WORKER_HOSTS = ${WORKER_HOSTS}" + echo "PS_HOSTS = ${PS_HOSTS}" + rm -f ${TMP} if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then echo "Skipping testing of distributed runtime due to "\ -"option flag --setup-cluster-only" +"option flag --setup_cluster_only" exit 0 fi fi @@ -158,17 +173,21 @@ test_MNIST() { return 1 fi - echo "Performing distributed MNIST training through grpc sessions @ "\ + echo "Performing distributed MNIST training through worker grpc sessions @ "\ "${GRPC_SERVER_URLS}..." + echo "and ps grpc sessions @ ${GRPC_PS_URLS}" + SYNC_REPLICAS_FLAG="" if [[ ${SYNC_REPLICAS} == "1" ]]; then - SYNC_REPLICAS_FLAG="--sync-replicas" + SYNC_REPLICAS_FLAG="--sync_replicas" fi - "${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \ - --num-workers "${NUM_WORKERS}" \ - --num-parameter-servers "${NUM_PARAMETER_SERVERS}" \ + "${MNIST_DIST_TEST_BIN}" \ + --existing_servers True \ + --ps_hosts "${PS_HOSTS}" \ + --worker_hosts "${WORKER_HOSTS}" \ + --num_gpus 0 \ ${SYNC_REPLICAS_FLAG} if [[ $? == "0" ]]; then |