about summary refs log tree commit diff homepage
path: root/tensorflow/tools/dist_test/scripts/dist_test.sh
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/dist_test/scripts/dist_test.sh')
-rwxr-xr-x tensorflow/tools/dist_test/scripts/dist_test.sh 63
1 files changed, 41 insertions, 22 deletions
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh
index 1d60aa518f..080ce1df5f 100755
--- a/tensorflow/tools/dist_test/scripts/dist_test.sh
+++ b/tensorflow/tools/dist_test/scripts/dist_test.sh
@@ -25,25 +25,25 @@
# TensorFlow ops.
#
# Usage:
-# dist_test.sh [--setup-cluster-only]
-# [--model-name (MNIST | CENSUS_WIDENDEEP)]
-# [--num-workers <NUM_WORKERS>]
-# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
-# [--sync-replicas]
+# dist_test.sh [--setup_cluster_only]
+# [--model_name (MNIST | CENSUS_WIDENDEEP)]
+# [--num_workers <NUM_WORKERS>]
+# [--num_parameter_servers <NUM_PARAMETER_SERVERS>]
+# [--sync_replicas]
#
-# --setup-cluster-only:
+# --setup_cluster_only:
# Lets the script only set up the k8s container network
#
-# --model-name
+# --model_name
# Name of the model to test. Default is MNIST.
#
-# --num-workers <NUM_WORKERS>:
+# --num_workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
-# --num-parameter-server <NUM_PARAMETER_SERVERS>:
+# --num_parameter_servers <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
-# --sync-replicas
+# --sync_replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
@@ -72,15 +72,15 @@ SYNC_REPLICAS=0
SETUP_CLUSTER_ONLY=0
while true; do
- if [[ "$1" == "--model-name" ]]; then
+ if [[ "$1" == "--model_name" ]]; then
MODEL_NAME=$2
- elif [[ "$1" == "--num-workers" ]]; then
+ elif [[ "$1" == "--num_workers" ]]; then
NUM_WORKERS=$2
- elif [[ "$1" == "--num-parameter-servers" ]]; then
+ elif [[ "$1" == "--num_parameter_servers" ]]; then
NUM_PARAMETER_SERVERS=$2
- elif [[ "$1" == "--sync-replicas" ]]; then
+ elif [[ "$1" == "--sync_replicas" ]]; then
SYNC_REPLICAS=1
- elif [[ "$1" == "--setup-cluster-only" ]]; then
+ elif [[ "$1" == "--setup_cluster_only" ]]; then
SETUP_CLUSTER_ONLY=1
fi
shift
@@ -132,17 +132,32 @@ else
tee "${TMP}" || \
die "Creation of TensorFlow k8s cluster FAILED"
- GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-workers: .*" | \
- sed -e 's/GRPC URLs of tf-workers://g')
+ GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-worker instances: .*" | \
+ sed -e 's/GRPC URLs of tf-worker instances://g')
+
+ GRPC_PS_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-ps instances: .*" | \
+ sed -e 's/GRPC URLs of tf-ps instances://g')
if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then
die "FAILED to determine GRPC server URLs of all workers"
fi
+ if [[ $(echo ${GRPC_PS_URLS} | wc -w) != ${NUM_PARAMETER_SERVERS} ]]; then
+ die "FAILED to determine GRPC server URLs of all parameter servers"
+ fi
+
+ WORKER_HOSTS=$(echo "${GRPC_SERVER_URLS}" | sed -e 's/^[[:space:]]*//' | \
+ sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
+ PS_HOSTS=$(echo "${GRPC_PS_URLS}" | sed -e 's/^[[:space:]]*//' | \
+ sed -e 's/grpc:\/\///g' | sed -e 's/ /,/g')
+
+ echo "WORKER_HOSTS = ${WORKER_HOSTS}"
+ echo "PS_HOSTS = ${PS_HOSTS}"
+
rm -f ${TMP}
if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then
echo "Skipping testing of distributed runtime due to "\
-"option flag --setup-cluster-only"
+"option flag --setup_cluster_only"
exit 0
fi
fi
@@ -158,17 +173,21 @@ test_MNIST() {
return 1
fi
- echo "Performing distributed MNIST training through grpc sessions @ "\
+ echo "Performing distributed MNIST training through worker grpc sessions @ "\
"${GRPC_SERVER_URLS}..."
+ echo "and ps grpc sessions @ ${GRPC_PS_URLS}"
+
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
- SYNC_REPLICAS_FLAG="--sync-replicas"
+ SYNC_REPLICAS_FLAG="--sync_replicas"
fi
- "${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \
- --num-workers "${NUM_WORKERS}" \
- --num-parameter-servers "${NUM_PARAMETER_SERVERS}" \
+ "${MNIST_DIST_TEST_BIN}" \
+ --existing_servers True \
+ --ps_hosts "${PS_HOSTS}" \
+ --worker_hosts "${WORKER_HOSTS}" \
+ --num_gpus 0 \
${SYNC_REPLICAS_FLAG}
if [[ $? == "0" ]]; then