authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-03-18 05:59:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-03-18 08:48:33 -0700
commit9742a2ed4e1e2015334d53dc2824c82812ad21e8 (patch)
tree41966c3ce0b15e7b0f2dc19ddfdd6c71930780a3
parentaaca1ccefab6307a0948c782440016389bf994f1 (diff)
Test for distributed (grpc) runtime in OSS TensorFlow
See README.md for detailed descriptions of the usage of the tools and tests in
this changeset. Three modes of testing are supported:
1) Launch a local Kubernetes (k8s) cluster and run the test suites on it
   (see local_test.sh)
2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the
   test suite on it (see remote_test.sh)
3) Run the test suite on an existing k8s TensorFlow cluster (also see
   remote_test.sh)
Taking the remote test as an example, the following steps are performed:
1) Build a Docker image with gcloud and Kubernetes tools and the latest
   TensorFlow pip package installed (see Dockerfile)
2) Launch a Docker container based on the said image (see test_distributed.sh)
3) From within the container, authenticate the gcloud user (with credential
   files mapped from outside the container), configure the k8s cluster, and
   launch a new k8s container cluster for TensorFlow workers
4) Generate a k8s (yaml) config file and use it to create a TensorFlow worker
   cluster consisting of a certain number of parameter servers (ps) and
   workers; the workers are exposed as external services with public IPs
   (see dist_test.sh)
5) Run a simple softmax MNIST model on multiple workers, with the model
   weights and biases located on the ps nodes; train the models in parallel
   and observe the final validation cross entropy (see dist_mnist_test.sh)
Change: 117543657
-rw-r--r--tensorflow/tools/dist_test/Dockerfile28
-rw-r--r--tensorflow/tools/dist_test/Dockerfile.local20
-rw-r--r--tensorflow/tools/dist_test/README.md76
-rwxr-xr-xtensorflow/tools/dist_test/build_server.sh44
-rw-r--r--tensorflow/tools/dist_test/local/Dockerfile20
-rwxr-xr-xtensorflow/tools/dist_test/local/start_local_k8s_service.sh118
-rwxr-xr-xtensorflow/tools/dist_test/local/start_tf_cluster_container.sh91
-rwxr-xr-xtensorflow/tools/dist_test/local/test_local_tf_cluster.sh88
-rwxr-xr-xtensorflow/tools/dist_test/local_test.sh152
-rwxr-xr-xtensorflow/tools/dist_test/python/mnist_replica.py144
-rwxr-xr-xtensorflow/tools/dist_test/remote_test.sh92
-rwxr-xr-xtensorflow/tools/dist_test/scripts/create_tf_cluster.sh231
-rwxr-xr-xtensorflow/tools/dist_test/scripts/delete_tf_cluster.sh87
-rwxr-xr-xtensorflow/tools/dist_test/scripts/dist_mnist_test.sh137
-rwxr-xr-xtensorflow/tools/dist_test/scripts/dist_test.sh118
-rwxr-xr-xtensorflow/tools/dist_test/scripts/k8s_tensorflow.py245
-rw-r--r--tensorflow/tools/dist_test/scripts/utils.sh56
-rw-r--r--tensorflow/tools/dist_test/server/Dockerfile59
-rwxr-xr-xtensorflow/tools/dist_test/server/grpc_tensorflow_server.py122
19 files changed, 1928 insertions, 0 deletions
diff --git a/tensorflow/tools/dist_test/Dockerfile b/tensorflow/tools/dist_test/Dockerfile
new file mode 100644
index 0000000000..fba23af55d
--- /dev/null
+++ b/tensorflow/tools/dist_test/Dockerfile
@@ -0,0 +1,28 @@
+FROM ubuntu:14.04
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+RUN apt-get install -y \
+ bc \
+ curl \
+ python \
+ python-numpy \
+ python-pip
+
+# Install Google Cloud SDK
+RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/install_google_cloud_sdk.bash
+RUN chmod +x install_google_cloud_sdk.bash
+RUN ./install_google_cloud_sdk.bash --disable-prompts --install-dir=/var/gcloud
+
+# Install kubectl
+RUN /var/gcloud/google-cloud-sdk/bin/gcloud components install kubectl
+
+# Install nightly TensorFlow pip
+# TODO(cais): Should we build it locally instead?
+RUN pip install \
+ http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
+
+# Copy test files
+COPY scripts /var/tf-dist-test/scripts
+COPY python /var/tf-dist-test/python
diff --git a/tensorflow/tools/dist_test/Dockerfile.local b/tensorflow/tools/dist_test/Dockerfile.local
new file mode 100644
index 0000000000..4d82904707
--- /dev/null
+++ b/tensorflow/tools/dist_test/Dockerfile.local
@@ -0,0 +1,20 @@
+FROM jpetazzo/dind
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+
+RUN apt-get install -y \
+ bc \
+ build-essential \
+ dbus \
+ git \
+ software-properties-common
+
+# Install the latest golang
+RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz
+RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz
+RUN rm -f go1.4.2.linux-amd64.tar.gz
+RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc
+
+ADD . /var/tf-k8s
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
new file mode 100644
index 0000000000..d986900bd6
--- /dev/null
+++ b/tensorflow/tools/dist_test/README.md
@@ -0,0 +1,76 @@
+# Testing Distributed Runtime in TensorFlow
+This folder contains tools and test suites for the GRPC-based distributed
+runtime in TensorFlow.
+
+There are three general modes of testing:
+
+**1) Launch a local Kubernetes (k8s) cluster and run the test suites on it**
+
+For example:
+
+ ./local_test.sh
+
+This option makes use of docker-in-docker (dind) containers. It requires
+the docker0 network interface to be set to promiscuous mode on the host:
+
+ sudo ip link set docker0 promisc on
+
+The environment variable "TF_DIST_SERVER_DOCKER_IMAGE" can be used to override
+the Docker image used to generate the TensorFlow GRPC server pods
+("tensorflow/tf_grpc_test_server"). For example:
+
+ export TF_DIST_SERVER_DOCKER_IMAGE=<docker_image_name>
+ ./local_test.sh
+
+**2) Launch a remote k8s cluster on Google Container Engine (GKE) and run the
+test suite on it**
+
+For example:
+
+ export TF_DIST_GCLOUD_PROJECT="tensorflow-testing"
+ export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f"
+  export TF_DIST_CONTAINER_CLUSTER="test-cluster-1"
+ export TF_DIST_GCLOUD_KEY_FILE_DIR="/tmp/gcloud-secrets"
+ ./remote_test.sh
+
+Here you specify the Google Compute Engine (GCE) project, compute zone and
+container cluster with the first three environment variables, in that order.
+The environment variable "TF_DIST_GCLOUD_KEY_FILE_DIR" is a directory in which
+the JSON service account key file named "tensorflow-testing.json" is located.
+You can use the flag "--setup-cluster-only" to perform only the cluster setup
+step and skip the testing step:
+
+ ./remote_test.sh --setup-cluster-only
+
+**3) Run the test suite on an existing k8s TensorFlow cluster**
+
+For example:
+
+ export TF_DIST_GRPC_SERVER_URL="grpc://11.22.33.44:2222"
+ ./remote_test.sh
+
+The IP address above is a dummy example. Such a cluster may have been set up
+using the command described at the end of the previous section.
+
+
+**Building the test server Docker image**
+
+To build the Docker image for a test server of the TensorFlow distributed
+runtime, run:
+
+ ./build_server.sh <docker_image_name>
+
+
+**Generating configuration file for TensorFlow k8s clusters**
+
+The script at "scripts/k8s_tensorflow.py" can be used to generate yaml
+configuration files for a TensorFlow k8s cluster consisting of a number of
+workers and parameter servers. For example:
+
+ scripts/k8s_tensorflow.py \
+ --num_workers 2 \
+ --num_parameter_servers 2 \
+ --grpc_port 2222 \
+ --request_load_balancer \
+ --docker_image "tensorflow/tf_grpc_test_server" \
+ > tf-k8s-with-lb.yaml
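
An illustrative sketch (not the actual `scripts/k8s_tensorflow.py`) of how a generator like the one above might derive pod names and gRPC targets from the same flags. The `tf-worker<i>` naming matches what this changeset uses (e.g., `tf-worker0:2222`); the `tf-ps<i>` naming for parameter servers is an assumption for illustration.

```python
def cluster_targets(num_workers, num_parameter_servers, grpc_port=2222):
    """Return (worker_targets, ps_targets) as host:port strings.

    Hypothetical helper mirroring the --num_workers, --num_parameter_servers
    and --grpc_port flags of k8s_tensorflow.py; not taken from that script.
    """
    workers = ["tf-worker%d:%d" % (i, grpc_port) for i in range(num_workers)]
    ps = ["tf-ps%d:%d" % (i, grpc_port)
          for i in range(num_parameter_servers)]
    return workers, ps


# Matches the example invocation above (2 workers, 2 ps, port 2222):
workers, ps = cluster_targets(2, 2)
print(workers)  # ['tf-worker0:2222', 'tf-worker1:2222']
print(ps)       # ['tf-ps0:2222', 'tf-ps1:2222']
```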
diff --git a/tensorflow/tools/dist_test/build_server.sh b/tensorflow/tools/dist_test/build_server.sh
new file mode 100755
index 0000000000..8679bde2dc
--- /dev/null
+++ b/tensorflow/tools/dist_test/build_server.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Builds the test server for distributed (GRPC) TensorFlow
+#
+# Usage: build_server.sh <docker_image_name>
+#
+# Note that the Dockerfile is located in ./server/ but the docker build should
+# use the current directory as the context.
+
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Check arguments
+if [[ $# != 1 ]]; then
+ die "Usage: $0 <docker_image_name>"
+fi
+
+DOCKER_IMG_NAME=$1
+
+# Current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Call docker build
+docker build --no-cache -t "${DOCKER_IMG_NAME}" \
+ -f "${DIR}/server/Dockerfile" \
+ "${DIR}"
diff --git a/tensorflow/tools/dist_test/local/Dockerfile b/tensorflow/tools/dist_test/local/Dockerfile
new file mode 100644
index 0000000000..dece508c0d
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/Dockerfile
@@ -0,0 +1,20 @@
+FROM jpetazzo/dind
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+RUN apt-get update
+
+RUN apt-get install -y \
+ build-essential \
+ git \
+ software-properties-common
+
+# Install the latest golang
+RUN wget https://storage.googleapis.com/golang/go1.4.2.linux-amd64.tar.gz
+RUN tar -C /usr/local -xzf go1.4.2.linux-amd64.tar.gz
+RUN rm -f go1.4.2.linux-amd64.tar.gz
+RUN echo 'PATH=/usr/local/go/bin:${PATH}' >> /root/.bashrc
+
+ADD start_local_k8s_service.sh /var/k8s/start_local_k8s_service.sh
+ADD ../scripts /var/k8s/dist_test/scripts
+ADD ../python /var/k8s/dist_test/python
diff --git a/tensorflow/tools/dist_test/local/start_local_k8s_service.sh b/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
new file mode 100755
index 0000000000..51f4805ee8
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/start_local_k8s_service.sh
@@ -0,0 +1,118 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Start a Kubernetes (k8s) cluster on the local machine.
+#
+# This script assumes that git, docker, and golang are installed and on
+# the path. It will attempt to install the version of etcd recommended by the
+# kubernetes source.
+#
+# Usage: start_local_k8s_service.sh
+#
+# This script obeys the following environment variables:
+# TF_DIST_K8S_SRC_DIR: Overrides the default directory for k8s source code.
+# TF_DIST_K8S_SRC_BRANCH: Overrides the default branch to run the local k8s
+# cluster with.
+
+
+# Configurations
+K8S_SRC_REPO=https://github.com/kubernetes/kubernetes.git
+K8S_SRC_DIR=${TF_DIST_K8S_SRC_DIR:-/local/kubernetes}
+K8S_SRC_BRANCH=${TF_DIST_K8S_SRC_BRANCH:-release-1.2}
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Start docker service. On failure, exit with code 23 so that the calling
+# script can retry.
+COUNTER=0
+while true; do
+ ((COUNTER++))
+ service docker start
+ sleep 1
+
+ service docker status
+ if [[ $? == "0" ]]; then
+ echo "Docker service started successfully."
+ break;
+ else
+ echo "Docker service failed to start"
+
+ # 23 is the exit code to signal failure to start docker service in the dind
+ # container.
+ exit 23
+
+ fi
+done
+
+# Wait for docker0 net interface to appear
+echo "Waiting for docker0 network interface to appear..."
+while true; do
+ if [[ -z $(netstat -i | grep "^docker0") ]]; then
+ sleep 1
+ else
+ break
+ fi
+done
+echo "docker0 interface has appeared."
+
+# Set docker0 to promiscuous mode
+ip link set docker0 promisc on || \
+ die "FAILED to set docker0 to promiscuous"
+echo "Turned promisc on for docker0"
+
+# Check promiscuous mode of docker0
+netstat -i
+
+umask 000
+if [[ ! -d "${K8S_SRC_DIR}/.git" ]]; then
+ mkdir -p ${K8S_SRC_DIR}
+ git clone ${K8S_SRC_REPO} ${K8S_SRC_DIR} || \
+ die "FAILED to clone k8s source from GitHub from: ${K8S_SRC_REPO}"
+fi
+
+pushd ${K8S_SRC_DIR}
+git checkout ${K8S_SRC_BRANCH} || \
+ die "FAILED to checkout k8s source branch: ${K8S_SRC_BRANCH}"
+git pull origin ${K8S_SRC_BRANCH} || \
+ die "FAILED to pull from k8s source branch: ${K8S_SRC_BRANCH}"
+
+# Create kubectl binary
+
+# Install etcd
+hack/install-etcd.sh
+
+export PATH=$(pwd)/third_party/etcd:${PATH}
+
+# Setup golang
+export PATH=/usr/local/go/bin:${PATH}
+
+echo "etcd path: $(which etcd)"
+echo "go path: $(which go)"
+
+# Create shortcut to kubectl
+echo '#!/bin/bash' > /usr/local/bin/kubectl
+echo "$(pwd)/cluster/kubectl.sh \\" >> /usr/local/bin/kubectl
+echo ' $@' >> /usr/local/bin/kubectl
+chmod +x /usr/local/bin/kubectl
+
+# Bring up local cluster
+export KUBE_ENABLE_CLUSTER_DNS=true
+hack/local-up-cluster.sh
+
+popd
diff --git a/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh b/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
new file mode 100755
index 0000000000..b8448624ef
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/start_tf_cluster_container.sh
@@ -0,0 +1,91 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Starts a docker-in-docker (dind) container that is capable of running a
+# docker service and a Kubernetes (k8s) cluster inside.
+#
+# Usage: start_tf_cluster_container.sh <local_k8s_dir> <docker_img_name>
+#
+# local_k8s_dir: Kubernetes (k8s) source directory on the host
+# docker_img_name: Name of the docker image to start
+#
+# In addition, this script obeys the following environment variables:
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
+
+# Parse input arguments
+if [[ $# != "2" ]]; then
+ echo "Usage: $0 <host_k8s_dir> <docker_img_name>"
+ exit 1
+fi
+
+HOST_K8S_DIR=$1
+DOCKER_IMG_NAME=$2
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Maximum number of tries to start the docker container with docker running
+# inside
+MAX_ATTEMPTS=100
+
+# Map environment variables into the docker-in-docker (dind) container
+DOCKER_ENV=""
+if [[ ! -z "${TF_DIST_SERVER_DOCKER_IMAGE}" ]]; then
+ DOCKER_ENV="-e TF_DIST_SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE}"
+fi
+
+# Verify that the promisc (promiscuous mode) flag is set on docker0 network
+# interface
+if [[ -z $(netstat -i | grep "^docker0" | awk '{print $NF}' | grep -o P) ]];
+then
+ die "FAILED: Cannot proceed with dind k8s container creation because "\
+"network interface 'docker0' is not set to promisc on the host."
+fi
+
+# Create cache for k8s source
+if [[ ! -d ${HOST_K8S_DIR} ]]; then
+ umask 000
+ mkdir -p ${HOST_K8S_DIR} || die "FAILED to create directory for k8s source"
+fi
+
+# Attempt to start docker service in docker container.
+# Try multiple times if necessary.
+COUNTER=1
+while true; do
+ ((COUNTER++))
+ docker run --net=host --privileged ${DOCKER_ENV} \
+ -v ${HOST_K8S_DIR}:/local/kubernetes \
+ ${DOCKER_IMG_NAME} \
+ /var/tf-k8s/local/start_local_k8s_service.sh
+
+ if [[ $? == "23" ]]; then
+ if [[ $(echo "${COUNTER}>=${MAX_ATTEMPTS}" | bc -l) == "1" ]]; then
+ echo "Reached maximum number of attempts (${MAX_ATTEMPTS}) "\
+"while attempting to start docker-in-docker for local k8s TensorFlow cluster"
+ exit 1
+ fi
+
+ echo "Docker service failed to start."
+ echo "Will make another attempt (#${COUNTER}) to start it..."
+ sleep 1
+ else
+ break
+ fi
+done
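
The retry protocol in the loop above relies on a distinguished exit code: the inner script exits with 23 when the docker service fails to start, and this outer loop retries up to `MAX_ATTEMPTS` times. A minimal sketch of that protocol in Python (the function and names are illustrative, not part of this changeset):

```python
# Exit code the inner script uses to signal a retryable docker startup
# failure (mirrors the "exit 23" convention in start_local_k8s_service.sh).
RETRYABLE_EXIT_CODE = 23


def start_with_retries(run_once, max_attempts=100):
    """Call run_once() until it returns a non-retryable exit code.

    run_once is a stand-in for launching the dind container; it returns an
    exit code. Returns the first non-retryable exit code, or raises
    RuntimeError after max_attempts retryable failures.
    """
    for attempt in range(1, max_attempts + 1):
        code = run_once()
        if code != RETRYABLE_EXIT_CODE:
            return code
    raise RuntimeError(
        "Reached maximum number of attempts (%d)" % max_attempts)


# Fails twice with the retryable code, then succeeds:
outcomes = iter([23, 23, 0])
print(start_with_retries(lambda: next(outcomes)))  # 0
```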
diff --git a/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh b/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
new file mode 100755
index 0000000000..895a2fe24c
--- /dev/null
+++ b/tensorflow/tools/dist_test/local/test_local_tf_cluster.sh
@@ -0,0 +1,88 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Launch a Kubernetes (k8s) TensorFlow cluster on the local machine and run
+# the distributed test suite.
+#
+# This script assumes that a TensorFlow cluster is already running on the
+# local machine and can be controlled by the "kubectl" binary.
+#
+# Usage: test_local_tf_cluster.sh
+#
+
+export GCLOUD_BIN=/usr/local/bin/gcloud
+export TF_DIST_LOCAL_CLUSTER=1
+
+# TODO(cais): Do not hard-code the numbers of workers and ps
+NUM_WORKERS=2
+NUM_PARAMETER_SERVERS=2
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source "${DIR}/../scripts/utils.sh"
+
+# Wait for the kube-system pods to be running
+KUBECTL_BIN=$(which kubectl)
+if [[ -z ${KUBECTL_BIN} ]]; then
+ die "FAILED to find path to kubectl"
+fi
+
+echo "Waiting for kube-system pods to be all running..."
+echo ""
+
+MAX_ATTEMPTS=360
+COUNTER=0
+while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${MAX_ATTEMPTS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling attempts while waiting for all pods in "\
+"kube-system to be running in local k8s TensorFlow cluster"
+ fi
+
+ if [[ $(are_all_pods_running "${KUBECTL_BIN}" "kube-system") == "1" ]]; then
+ break
+ fi
+done
+
+# Create the local k8s tf cluster
+${DIR}/../scripts/create_tf_cluster.sh \
+ ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} | \
+ tee /tmp/tf_cluster.log || \
+ die "FAILED to create local tf cluster"
+
+DOCKER_CONTAINER_ID=$(cat /tmp/tf_cluster.log | \
+ grep "Docker container ID" |
+ awk '{print $NF}')
+if [[ -z "${DOCKER_CONTAINER_ID}" ]]; then
+ die "FAILED to determine worker0 Docker container ID"
+fi
+
+export TF_DIST_GRPC_SERVER_URL="grpc://tf-worker0:2222"
+GRPC_ENV="TF_DIST_GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}"
+
+docker exec \
+ ${DOCKER_CONTAINER_ID} \
+ /bin/bash -c \
+ "${GRPC_ENV} /var/tf-k8s/scripts/dist_test.sh"
+
+if [[ $? != "0" ]]; then
+ die "Test of local k8s TensorFlow cluster FAILED"
+else
+ echo "Test of local k8s TensorFlow cluster PASSED"
+fi
diff --git a/tensorflow/tools/dist_test/local_test.sh b/tensorflow/tools/dist_test/local_test.sh
new file mode 100755
index 0000000000..d47324cbc3
--- /dev/null
+++ b/tensorflow/tools/dist_test/local_test.sh
@@ -0,0 +1,152 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Tests distributed TensorFlow on a locally running TF GRPC cluster.
+#
+# This script performs the following steps:
+# 1) Build the docker-in-docker (dind) image capable of running docker and a
+#    Kubernetes (k8s) cluster inside.
+# 2) Run a container from the aforementioned image and start docker service
+#    in it.
+# 3) Call a script to launch a k8s TensorFlow GRPC cluster inside the container
+#    and run the distributed test suite.
+#
+# Usage: local_test.sh [--leave-container-running]
+#
+# Arguments:
+# --leave-container-running: Do not stop the docker-in-docker container after
+# the termination of the tests, e.g., for debugging
+#
+# In addition, this script obeys the following environment variables:
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
+# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
+
+
+# Configurations
+DOCKER_IMG_NAME="tensorflow/tf-dist-test-local-cluster"
+LOCAL_K8S_CACHE=${HOME}/kubernetes
+
+# Helper function
+get_container_id_by_image_name() {
+ # Get the id of a container by image name
+  # Usage: get_container_id_by_image_name <img_name>
+
+ echo $(docker ps | grep $1 | awk '{print $1}')
+}
+
+# Current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source ${DIR}/scripts/utils.sh
+
+
+# First, make sure that no docker-in-docker container of the same image
+# is already running
+if [[ ! -z $(get_container_id_by_image_name ${DOCKER_IMG_NAME}) ]]; then
+ die "It appears that there is already at least one Docker container "\
+"of image name ${DOCKER_IMG_NAME} running. Please stop it before trying again"
+fi
+
+# Build docker-in-docker image for local k8s cluster
+NO_CACHE_FLAG=""
+if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] &&
+ [[ "${TF_DIST_DOCKER_NO_CACHE}" != "0" ]]; then
+ NO_CACHE_FLAG="--no-cache"
+fi
+
+docker build ${NO_CACHE_FLAG} -t ${DOCKER_IMG_NAME} \
+ -f ${DIR}/Dockerfile.local ${DIR}
+
+
+# Attempt to start the docker container with docker, which will run the k8s
+# cluster inside.
+
+# Get current script directory
+CONTAINER_START_LOG=$(mktemp --suffix=.log)
+echo "Log file for starting cluster container: ${CONTAINER_START_LOG}"
+echo ""
+
+${DIR}/local/start_tf_cluster_container.sh \
+ ${LOCAL_K8S_CACHE} \
+ ${DOCKER_IMG_NAME} | \
+ tee ${CONTAINER_START_LOG} &
+
+# Poll start log until the k8s service is started properly or when maximum
+# attempt count is reached.
+MAX_SERVER_POLLING_ATTEMPTS=600
+
+echo "Waiting for docker-in-docker container for local k8s TensorFlow "\
+"cluster to start and launch Kubernetes..."
+
+COUNTER=0
+while true; do
+ sleep 1
+
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>=${MAX_SERVER_POLLING_ATTEMPTS}" | bc -l) == "1" ]]; then
+ die "Reached maximum number of attempts (${MAX_SERVER_POLLING_ATTEMPTS}) "\
+"while waiting for docker-in-docker for local k8s TensorFlow cluster to start"
+ fi
+
+ # Check for hitting max attempt while trying to start docker-in-docker
+ if [[ $(grep -i "Reached maximum number of attempts" \
+ "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
+ die "Docker-in-docker container for local k8s TensorFlow cluster "\
+"FAILED to start"
+ fi
+
+ if [[ $(grep -i "Local Kubernetes cluster is running" \
+ "${CONTAINER_START_LOG}" | wc -l) == "1" ]]; then
+ break
+ fi
+done
+
+# Determine the id of the docker-in-docker container
+DIND_ID=$(get_container_id_by_image_name ${DOCKER_IMG_NAME})
+
+echo "Docker-in-docker container for local k8s TensorFlow cluster has been "\
+"started successfully."
+echo "Docker-in-docker container ID: ${DIND_ID}"
+echo "Launching k8s tf cluster and tests in container ${DIND_ID} ..."
+echo ""
+
+# Launch k8s tf cluster in the docker-in-docker container and perform tests
+docker exec ${DIND_ID} \
+ /var/tf-k8s/local/test_local_tf_cluster.sh
+TEST_RES=$?
+
+# Tear down: stop docker-in-docker container
+if [[ $1 != "--leave-container-running" ]]; then
+ echo ""
+ echo "Stopping docker-in-docker container ${DIND_ID}"
+
+ docker stop --time=1 ${DIND_ID} || \
+ echo "WARNING: Failed to stop container ${DIND_ID} !!"
+
+ echo ""
+else
+ echo "Will not terminate DIND container ${DIND_ID}"
+fi
+
+if [[ "${TEST_RES}" != "0" ]]; then
+ die "Test of distributed TensorFlow runtime on docker-in-docker local "\
+"k8s cluster FAILED"
+else
+ echo "Test of distributed TensorFlow runtime on docker-in-docker local "\
+"k8s cluster PASSED"
+fi
diff --git a/tensorflow/tools/dist_test/python/mnist_replica.py b/tensorflow/tools/dist_test/python/mnist_replica.py
new file mode 100755
index 0000000000..e40aae38c2
--- /dev/null
+++ b/tensorflow/tools/dist_test/python/mnist_replica.py
@@ -0,0 +1,144 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Distributed MNIST training and validation, with model replicas.
+
+A simple softmax model with one hidden layer is defined. The parameters
+(weights and biases) are located on two parameter servers (ps), while the
+ops are defined on a worker node. The TF sessions also run on the worker
+node.
+Multiple invocations of this script can be run in parallel, with different
+values of --worker_index. There should be exactly one invocation with
+--worker_index=0, which creates a master session that carries out the
+variable initialization. The other, non-master, sessions wait for the master
+session to finish the initialization before proceeding to the training stage.
+
+The coordination between the multiple worker invocations occurs because
+the parameters are defined on the same ps devices. The parameter updates
+from one worker are visible to all other workers. As such, the workers can
+perform forward computation and gradient calculation in parallel, which
+should lead to increased training speed for this simple model.
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import sys
+import tempfile
+import time
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
+
+
+flags = tf.app.flags
+flags.DEFINE_string("data_dir", "/tmp/mnist-data",
+ "Directory for storing mnist data")
+flags.DEFINE_boolean("download_only", False,
+ """Only perform downloading of data; Do not proceed to
+ model definition or training""")
+flags.DEFINE_integer("worker_index", 0,
+ """Worker task index, should be >= 0. worker_index=0 is
+                     the master worker task that performs the variable
+ initialization""")
+flags.DEFINE_integer("hidden_units", 100,
+ "Number of units in the hidden layer of the NN")
+flags.DEFINE_integer("train_steps", 50, "Number of training steps")
+flags.DEFINE_integer("batch_size", 100, "Training batch size")
+flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
+flags.DEFINE_string("worker_grpc_url", None,
+ "Worker GRPC URL (e.g., grpc://1.2.3.4:2222, or "
+ "grpc://tf-worker0:2222)")
+FLAGS = flags.FLAGS
+
+IMAGE_PIXELS = 28
+
+if __name__ == "__main__":
+ mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
+ if FLAGS.download_only:
+ sys.exit(0)
+
+ print("Worker GRPC URL: %s" % FLAGS.worker_grpc_url)
+ print("Worker index = %d" % FLAGS.worker_index)
+
+ with tf.Graph().as_default():
+ # Variables of the hidden layer
+ with tf.device("/job:ps/task:0"):
+ hid_w = tf.Variable(
+ tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
+ stddev=1.0 / IMAGE_PIXELS), name="hid_w")
+ hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
+
+ # Variables of the softmax layer
+ with tf.device("/job:ps/task:1"):
+ sm_w = tf.Variable(
+ tf.truncated_normal([FLAGS.hidden_units, 10],
+ stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
+ name="sm_w")
+ sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
+
+ # Ops: located on the worker specified with FLAGS.worker_index
+ with tf.device("/job:worker/task:%d" % FLAGS.worker_index):
+ x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
+ y_ = tf.placeholder(tf.float32, [None, 10])
+
+ hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
+ hid = tf.nn.relu(hid_lin)
+
+ y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
+ cross_entropy = -tf.reduce_sum(y_ *
+ tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
+ train_step = tf.train.AdamOptimizer(
+ FLAGS.learning_rate).minimize(cross_entropy)
+
+ train_dir = tempfile.mkdtemp()
+ print(FLAGS.worker_index)
+ sv = tf.train.Supervisor(logdir=train_dir,
+ is_chief=(FLAGS.worker_index == 0))
+
+ # The chief worker (worker_index==0) session will prepare the session,
+ # while the remaining workers will wait for the preparation to complete.
+ sess = sv.prepare_or_wait_for_session(FLAGS.worker_grpc_url)
+
+ # Perform training
+ time_begin = time.time()
+ print("Training begins @ %f" % time_begin)
+
+ # TODO(cais): terminate when a global step counter reaches FLAGS.train_steps
+ for i in xrange(FLAGS.train_steps):
+ # Training feed
+ batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
+ train_feed = {x: batch_xs,
+ y_: batch_ys}
+
+ sess.run(train_step, feed_dict=train_feed)
+
+ time_end = time.time()
+ print("Training ends @ %f" % time_end)
+ training_time = time_end - time_begin
+ print("Training elapsed time: %f s" % training_time)
+
+ # Validation feed
+ val_feed = {x: mnist.validation.images,
+ y_: mnist.validation.labels}
+ val_xent = sess.run(cross_entropy, feed_dict=val_feed)
+ print("After %d training step(s), validation cross entropy = %g" %
+ (FLAGS.train_steps, val_xent))
+
diff --git a/tensorflow/tools/dist_test/remote_test.sh b/tensorflow/tools/dist_test/remote_test.sh
new file mode 100755
index 0000000000..5f331c4cac
--- /dev/null
+++ b/tensorflow/tools/dist_test/remote_test.sh
@@ -0,0 +1,92 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This is the entry-point script for testing TensorFlow's distributed runtime.
+# It builds a docker image with the necessary gcloud and Kubernetes (k8s) tools
+# installed, and then executes k8s cluster preparation and distributed
+# TensorFlow runs from within a container based on the image.
+#
+# Usage:
+# remote_test.sh [--setup-cluster-only]
+# Arguments:
+# --setup-cluster-only:
+# Setup the TensorFlow k8s cluster only, and do not perform testing of
+# the distributed runtime.
+#
+#
+# If any of the following environment variables has a non-empty value, it will
+# be mapped into the docker container to override the default values (see
+# dist_test.sh)
+# TF_DIST_GRPC_SERVER_URL:      URL to an existing TensorFlow GRPC server.
+# If set to any non-empty and valid value (e.g.,
+# grpc://1.2.3.4:2222), it will cause the test
+# to bypass the k8s cluster setup and
+#                                teardown process, and just use this URL
+# as the master session.
+# TF_DIST_GCLOUD_PROJECT: gcloud project in which the GKE cluster
+# will be created (takes effect only if
+# TF_DIST_GRPC_SERVER_URL is empty, same below)
+# TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
+# TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
+# TF_DIST_GCLOUD_KEY_FILE_DIR: path to the host directory that contains
+#                                the gcloud service key file
+# "tensorflow-testing.json"
+# TF_DIST_GRPC_PORT: port on which to create the TensorFlow GRPC
+# servers
+# TF_DIST_DOCKER_NO_CACHE: do not use cache when building docker images
+
+DOCKER_IMG_NAME="tensorflow/tf-dist-test-client"
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Prepare environment variables for the docker container
+DOCKER_ENV_FLAGS=""
+if [[ ! -z "$TF_DIST_GRPC_SERVER_URL" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}"
+fi
+if [[ ! -z "$TF_DIST_GCLOUD_PROJECT" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GCLOUD_PROJECT=${TF_DIST_GCLOUD_PROJECT}"
+fi
+if [[ ! -z "$TF_DIST_GCLOUD_COMPUTE_ZONE" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GCLOUD_COMPUTE_ZONE=${TF_DIST_GCLOUD_COMPUTE_ZONE}"
+fi
+if [[ ! -z "$TF_DIST_CONTAINER_CLUSTER" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_CONTAINER_CLUSTER=${TF_DIST_CONTAINER_CLUSTER}"
+fi
+if [[ ! -z "$TF_DIST_GRPC_PORT" ]]; then
+ DOCKER_ENV_FLAGS="${DOCKER_ENV_FLAGS} "\
+"-e TF_DIST_GRPC_PORT=${TF_DIST_GRPC_PORT}"
+fi
+
+NO_CACHE_FLAG=""
+if [[ ! -z "${TF_DIST_DOCKER_NO_CACHE}" ]] &&
+ [[ "${TF_DIST_DOCKER_NO_CACHE}" != "0" ]]; then
+ NO_CACHE_FLAG="--no-cache"
+fi
+
+docker build ${NO_CACHE_FLAG} \
+ -t ${DOCKER_IMG_NAME} -f "${DIR}/Dockerfile" "${DIR}"
+KEY_FILE_DIR=${TF_DIST_GCLOUD_KEY_FILE_DIR:-"${HOME}/gcloud-secrets"}
+
+docker run -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
+ ${DOCKER_ENV_FLAGS} \
+ ${DOCKER_IMG_NAME} \
+ /var/tf-dist-test/scripts/dist_test.sh $@
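The repeated `if`-blocks above all follow one pattern: each non-empty `TF_DIST_*` variable becomes a `-e VAR=value` flag on the `docker run` command line. A table-driven sketch of the same idea (the `build_env_flags` helper is ours, not part of the script):

```shell
# Build "-e VAR=value" docker flags for every listed variable that is
# non-empty; uses bash indirect expansion (${!var}).
build_env_flags() {
  local flags="" var
  for var in "$@"; do
    if [[ -n "${!var}" ]]; then
      flags="${flags}-e ${var}=${!var} "
    fi
  done
  echo "${flags% }"  # trim the trailing space
}

TF_DIST_GCLOUD_PROJECT="my-project"  # placeholder value
TF_DIST_GRPC_PORT=""                 # empty: contributes no flag
build_env_flags TF_DIST_GCLOUD_PROJECT TF_DIST_GRPC_PORT
# prints: -e TF_DIST_GCLOUD_PROJECT=my-project
```

This keeps the forwarding list in one place instead of one `if`-block per variable.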
diff --git a/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
new file mode 100755
index 0000000000..22c0c43037
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/create_tf_cluster.sh
@@ -0,0 +1,231 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Create a Kubernetes (k8s) cluster of TensorFlow workers
+#
+# Usage:
+# create_tf_cluster.sh <num_workers> <num_parameter_servers>
+#
+# In addition, this script obeys values in the following environment variables:
+# TF_DIST_LOCAL_CLUSTER: create TensorFlow cluster on local machine
+# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
+# TensorFlow (GRPC) servers with
+# TF_DIST_GCLOUD_PROJECT:      gcloud project in which the GKE cluster
+#                              will be created (takes effect only if
+#                              TF_DIST_GRPC_SERVER_URL is empty; see dist_test.sh).
+# TF_DIST_GCLOUD_COMPUTE_ZONE: gcloud compute zone.
+# TF_DIST_CONTAINER_CLUSTER: name of the GKE cluster
+# TF_DIST_GCLOUD_KEY_FILE: if non-empty, will override GCLOUD_KEY_FILE
+# TF_DIST_GRPC_PORT: overrides the default port (2222)
+# to run the GRPC servers on
+
+# Configurations
+# gcloud operation timeout (steps)
+GCLOUD_OP_MAX_STEPS=360
+
+GRPC_PORT=${TF_DIST_GRPC_PORT:-2222}
+
+DEFAULT_GCLOUD_BIN=/var/gcloud/google-cloud-sdk/bin/gcloud
+GCLOUD_KEY_FILE=${TF_DIST_GCLOUD_KEY_FILE:-\
+"/var/gcloud/secrets/tensorflow-testing.json"}
+GCLOUD_PROJECT=${TF_DIST_GCLOUD_PROJECT:-"tensorflow-testing"}
+
+GCLOUD_COMPUTE_ZONE=${TF_DIST_GCLOUD_COMPUTE_ZONE:-"us-central1-f"}
+CONTAINER_CLUSTER=${TF_DIST_CONTAINER_CLUSTER:-"test-cluster"}
+
+SERVER_DOCKER_IMAGE=${TF_DIST_SERVER_DOCKER_IMAGE:-\
+"tensorflow/tf_grpc_test_server"}
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Get utility functions
+source "${DIR}/utils.sh"
+
+# Check input arguments
+if [[ $# != 2 ]]; then
+ die "Usage: $0 <num_workers> <num_parameter_servers>"
+fi
+
+NUM_WORKERS=$1
+NUM_PARAMETER_SERVERS=$2
+
+# Verify port string
+if [[ -z $(echo "${GRPC_PORT}" | grep -E "^[0-9]{1,5}$") ]]; then
+ die "Invalid GRPC port: \"${GRPC_PORT}\""
+fi
+echo "GRPC port to be used when creating the k8s TensorFlow cluster: "\
+"${GRPC_PORT}"
+
+if [[ -z "${TF_DIST_LOCAL_CLUSTER}" ]] ||
+ [[ "${TF_DIST_LOCAL_CLUSTER}" == "0" ]]; then
+ IS_LOCAL_CLUSTER="0"
+else
+ IS_LOCAL_CLUSTER="1"
+fi
+
+if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
+ # Locate gcloud binary path
+ GCLOUD_BIN=$(which gcloud)
+ if [[ -z "${GCLOUD_BIN}" ]]; then
+ GCLOUD_BIN="${DEFAULT_GCLOUD_BIN}"
+ fi
+
+ if [[ ! -f "${GCLOUD_BIN}" ]]; then
+ die "gcloud binary cannot be found at: ${GCLOUD_BIN}"
+ fi
+ echo "Path to gcloud binary: ${GCLOUD_BIN}"
+
+ # Path to gcloud service key file
+ if [[ ! -f "${GCLOUD_KEY_FILE}" ]]; then
+ die "gcloud service account key file cannot be found at: ${GCLOUD_KEY_FILE}"
+ fi
+ echo "Path to gcloud key file: ${GCLOUD_KEY_FILE}"
+
+ echo "GCLOUD_PROJECT: ${GCLOUD_PROJECT}"
+  echo "GCLOUD_COMPUTE_ZONE: ${GCLOUD_COMPUTE_ZONE}"
+ echo "CONTAINER_CLUSTER: ${CONTAINER_CLUSTER}"
+
+ # Activate gcloud service account
+ "${GCLOUD_BIN}" auth activate-service-account --key-file "${GCLOUD_KEY_FILE}"
+
+ # Set gcloud project
+ "${GCLOUD_BIN}" config set project "${GCLOUD_PROJECT}"
+
+ # Set compute zone
+ "${GCLOUD_BIN}" config set compute/zone "${GCLOUD_COMPUTE_ZONE}"
+
+ # Set container cluster
+ "${GCLOUD_BIN}" config set container/cluster "${CONTAINER_CLUSTER}"
+
+ # Get container cluster credentials
+ "${GCLOUD_BIN}" container clusters get-credentials "${CONTAINER_CLUSTER}"
+ if [[ $? != "0" ]]; then
+ die "FAILED to get credentials for container cluster: ${CONTAINER_CLUSTER}"
+ fi
+
+ # If there is any existing tf k8s cluster, delete it first
+ "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}"
+fi
+
+# Path to kubectl binary
+KUBECTL_BIN=$(which kubectl)
+if [[ -z "${KUBECTL_BIN}" ]]; then
+  # In local-cluster mode GCLOUD_BIN may be unset; fall back to the kubectl
+  # that ships alongside the gcloud SDK
+  KUBECTL_BIN=$(dirname "${GCLOUD_BIN:-${DEFAULT_GCLOUD_BIN}}")/kubectl
+fi
+if [[ ! -f "${KUBECTL_BIN}" ]]; then
+ die "kubectl binary cannot be found at: ${KUBECTL_BIN}"
+fi
+echo "Path to kubectl binary: ${KUBECTL_BIN}"
+
+# Create yaml file for k8s TensorFlow cluster creation
+# Path to the (Python) script for generating k8s yaml file
+K8S_GEN_TF_YAML="${DIR}/k8s_tensorflow.py"
+if [[ ! -f ${K8S_GEN_TF_YAML} ]]; then
+ die "FAILED to find yaml-generating script at: ${K8S_GEN_TF_YAML}"
+fi
+
+K8S_YAML="/tmp/k8s_tf_lb.yaml"
+rm -f "${K8S_YAML}"
+
+echo ""
+echo "Generating k8s cluster yaml config file with the following settings"
+echo " Server docker image: ${SERVER_DOCKER_IMAGE}"
+echo " Number of workers: ${NUM_WORKERS}"
+echo " Number of parameter servers: ${NUM_PARAMETER_SERVERS}"
+echo " GRPC port: ${GRPC_PORT}"
+echo ""
+
+${K8S_GEN_TF_YAML} \
+ --docker_image "${SERVER_DOCKER_IMAGE}" \
+ --num_workers "${NUM_WORKERS}" \
+ --num_parameter_servers "${NUM_PARAMETER_SERVERS}" \
+ --grpc_port "${GRPC_PORT}" \
+ --request_load_balancer=True \
+ > "${K8S_YAML}" || \
+ die "Generation of the yaml configuration file for k8s cluster FAILED"
+
+if [[ ! -f "${K8S_YAML}" ]]; then
+ die "FAILED to generate yaml file for TensorFlow k8s container cluster"
+else
+ echo "Generated yaml configuration file for k8s TensorFlow cluster: "\
+"${K8S_YAML}"
+fi
+
+# Create tf k8s container cluster
+"${KUBECTL_BIN}" create -f "${K8S_YAML}"
+
+# Wait for external IP of worker services to become available
+get_tf_worker_external_ip() {
+ echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker0" | \
+ awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
+}
+
+if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
+ echo "Waiting for external IP of tf-worker0 service to emerge..."
+ echo ""
+
+ COUNTER=0
+ while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${GCLOUD_OP_MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while waiting for external IP "\
+"of tf-worker0 service to emerge"
+ fi
+
+ SVC_EXTERN_IP=$(get_tf_worker_external_ip)
+
+ if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
+ break
+ fi
+ done
+
+ GRPC_SERVER_URL="grpc://${SVC_EXTERN_IP}:${GRPC_PORT}"
+ echo "GRPC URL of tf-worker0: ${GRPC_SERVER_URL}"
+
+else
+ echo "Waiting for tf pods to be all running..."
+ echo ""
+
+ COUNTER=0
+ while true; do
+ sleep 1
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${GCLOUD_OP_MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while waiting for all tf pods to "\
+"be running in local k8s TensorFlow cluster"
+ fi
+
+ PODS_STAT=$(are_all_pods_running "${KUBECTL_BIN}")
+
+ if [[ ${PODS_STAT} == "2" ]]; then
+ # Error has occurred
+      die "Error(s) occurred while trying to launch tf k8s cluster. "\
+"One possible cause is that the Docker image used to launch the cluster is "\
+"invalid: \"${SERVER_DOCKER_IMAGE}\""
+ fi
+
+ if [[ ${PODS_STAT} == "1" ]]; then
+ break
+ fi
+ done
+
+ # Determine the tf-worker0 docker container id
+ WORKER0_ID=$(docker ps | grep "k8s_tf-worker0" | awk '{print $1}')
+ echo "WORKER0 Docker container ID: ${WORKER0_ID}"
+
+fi
+
+
+echo "Cluster setup complete."
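Both waiting loops above share one shape: probe once per second until the condition holds or a step budget is exhausted. The original uses `bc -l` for the counter comparison; plain bash arithmetic is sufficient, as in this sketch (`poll_until` is our name for the pattern, not a function in the script):

```shell
# Run a probe command once per second until it succeeds, giving up after
# max_steps attempts. Returns 0 on success, 1 on timeout.
poll_until() {
  local max_steps="$1"
  shift
  local counter=0
  until "$@"; do
    counter=$((counter + 1))
    if [[ "${counter}" -gt "${max_steps}" ]]; then
      return 1
    fi
    sleep 1
  done
  return 0
}

# e.g. wait up to 360 s for the tf-worker0 external IP to appear:
#   poll_until 360 bash -c '[[ -n "$(get_tf_worker_external_ip)" ]]'
```

Factoring the loop out this way would also remove the duplicated timeout logic between the GKE and local-cluster branches.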
diff --git a/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh b/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh
new file mode 100755
index 0000000000..0f96b4b57a
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/delete_tf_cluster.sh
@@ -0,0 +1,87 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script checks for any existing TensorFlow worker services, replication
+# controllers and pods in the Kubernetes (k8s) container cluster and deletes
+# them if there are any.
+#
+# Usage: delete_tf_cluster.sh [max_steps]
+#
+# max_steps: Maximum number of polling steps for kubectl operations
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# Path to kubectl binary
+DEFAULT_KUBECTL_BIN=/var/gcloud/google-cloud-sdk/bin/kubectl
+KUBECTL_BIN=$(which kubectl)
+if [[ -z "${KUBECTL_BIN}" ]]; then
+ KUBECTL_BIN="${DEFAULT_KUBECTL_BIN}"
+fi
+if [[ ! -f "${KUBECTL_BIN}" ]]; then
+ die "kubectl binary cannot be found at: \"${KUBECTL_BIN}\""
+else
+ echo "Path to kubectl binary: ${KUBECTL_BIN}"
+fi
+
+MAX_STEPS=${1:-240}
+
+
+# Helper functions for kubectl workflow
+get_tf_svc_count() {
+ echo $("${KUBECTL_BIN}" get svc | grep "tf-" | wc -l)
+}
+
+get_tf_rc_count() {
+ echo $("${KUBECTL_BIN}" get rc | grep "tf-" | wc -l)
+}
+
+get_tf_pods_count() {
+ echo $("${KUBECTL_BIN}" get pods | grep "tf-" | wc -l)
+}
+
+
+# Delete all running services, replication-controllers and pods, in that order
+ITEMS_TO_DELETE="svc rc pods"
+for ITEM in ${ITEMS_TO_DELETE}; do
+ K8S_ITEM_COUNT=$(get_tf_${ITEM}_count)
+ if [[ ${K8S_ITEM_COUNT} != "0" ]]; then
+ echo "There are currently ${K8S_ITEM_COUNT} tf ${ITEM}(s) running. "
+ echo "Attempting to delete those..."
+
+ "${KUBECTL_BIN}" delete --all ${ITEM}
+
+ # Wait until all are deleted
+ # TODO(cais): Add time out
+ COUNTER=0
+ while true; do
+ sleep 1
+
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${MAX_STEPS}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while trying to delete all tf ${ITEM}"
+ fi
+
+ if [[ $(get_tf_${ITEM}_count) == "0" ]]; then
+ break
+ fi
+ done
+ fi
+
+done
diff --git a/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
new file mode 100755
index 0000000000..e0aad2b5c2
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/dist_mnist_test.sh
@@ -0,0 +1,137 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# This script invokes mnist_replica.py multiple times concurrently to test
+# TensorFlow's distributed runtime over a Kubernetes (k8s) cluster with the
+# grpc pods and services set up.
+#
+# Usage:
+# dist_mnist_test.sh <worker_grpc_url>
+#
+# worker_grpc_url is the GRPC URL of the worker that hosts the main worker
+# session, e.g., grpc://1.2.3.4:2222
+
+
+# Configurations
+TIMEOUT=120 # Timeout for MNIST replica sessions
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+if [[ $# != 1 ]]; then
+ die "Usage: $0 <WORKER_GRPC_URL>"
+fi
+WORKER_GRPC_URL=$1
+
+# Verify the validity of the GRPC URL
+if [[ -z $(echo "${WORKER_GRPC_URL}" | \
+ grep -E "^grpc://.+:[0-9]+") ]]; then
+ die "Invalid worker GRPC URL: \"${WORKER_GRPC_URL}\""
+fi
+
+# Current working directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+PY_DIR=$(dirname "${DIR}")/python
+
+MNIST_REPLICA="${PY_DIR}/mnist_replica.py"
+
+WKR_LOG_PREFIX="/tmp/worker"
+
+# First, download the data from a single process, to avoid race conditions
+# during data download
+timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
+ --download_only=True || \
+ die "Download-only step of MNIST replica FAILED"
+
+# Run a number of workers in parallel
+N_WORKERS=2
+INDICES=""
+IDX=0
+while true; do
+ timeout ${TIMEOUT} \
+ python "${MNIST_REPLICA}" \
+ --worker_grpc_url="${WORKER_GRPC_URL}" \
+      --worker_index=${IDX} > \
+      "${WKR_LOG_PREFIX}${IDX}.log" 2>&1 &
+ # TODO(cais): have each trainer process contact a different worker once
+ # supervisor and sync_replicas etc. are all working in OSS TensorFlow.
+
+ INDICES="${INDICES} ${IDX}"
+
+ ((IDX++))
+ if [[ $(echo "${IDX}==${N_WORKERS}" | bc -l) == "1" ]]; then
+ break
+ fi
+done
+
+# Function for getting final validation cross entropy from worker log files
+get_final_val_xent() {
+ echo $(cat $1 | grep "^After.*validation cross entropy = " | \
+ awk '{print $NF}')
+}
+
+# Poll until all final validation cross entropy values become available or
+# operation times out
+COUNTER=0
+while true; do
+ ((COUNTER++))
+ if [[ $(echo "${COUNTER}>${TIMEOUT}" | bc -l) == "1" ]]; then
+ die "Reached maximum polling steps while polling for final validation "\
+"cross entropies from all workers"
+ fi
+
+ N_AVAIL=0
+ VAL_XENTS=""
+ for N in ${INDICES}; do
+ VAL_XENT=$(get_final_val_xent "${WKR_LOG_PREFIX}${N}.log")
+ if [[ ! -z ${VAL_XENT} ]]; then
+ ((N_AVAIL++))
+ VAL_XENTS="${VAL_XENTS} ${VAL_XENT}"
+ fi
+ done
+
+ if [[ "${N_AVAIL}" == "2" ]]; then
+ # Print out the content of the log files
+ for M in ${INDICES}; do
+ echo "==================================================="
+ echo "=== Log file from worker ${M} ==="
+ cat "${WKR_LOG_PREFIX}${M}.log"
+ echo "==================================================="
+ echo ""
+ done
+
+ break
+ else
+ sleep 1
+ fi
+done
+
+# Sanity check on the validation entropies
+# TODO(cais): In addition to this basic sanity check, we could run the training
+# with 1 and 2 workers, each for a few times and use scipy.stats to do a t-test
+# to verify that the 2-worker training gives significantly lower final cross
+# entropy
+VAL_XENTS=(${VAL_XENTS})
+for N in ${INDICES}; do
+ echo "Final validation cross entropy from worker${N}: ${VAL_XENTS[N]}"
+ if [[ $(echo "${VAL_XENTS[N]}>0" | bc -l) != "1" ]]; then
+ die "Sanity checks on the final validation cross entropy values FAILED"
+ fi
+
+done
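The `get_final_val_xent` scrape above keys off the exact summary line that each trainer prints. A self-contained illustration against a fabricated log file (the path and log contents are placeholders):

```shell
# Extract the last field of the validation-summary line from a worker log
get_final_val_xent() {
  grep "^After.*validation cross entropy = " "$1" | awk '{print $NF}'
}

LOG=/tmp/worker_demo.log
printf 'Training ends @ 1458000000.0\nAfter 500 training step(s), validation cross entropy = 0.27\n' > "${LOG}"
get_final_val_xent "${LOG}"
# prints: 0.27
```

Because the scrape depends on the literal wording of the summary line, any change to the trainer's print format must be mirrored here.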
diff --git a/tensorflow/tools/dist_test/scripts/dist_test.sh b/tensorflow/tools/dist_test/scripts/dist_test.sh
new file mode 100755
index 0000000000..f8ade7eff8
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/dist_test.sh
@@ -0,0 +1,118 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Performs tests of TensorFlow's distributed runtime over a Kubernetes (k8s)
+# container cluster.
+#
+# This script tears down any existing TensorFlow cluster, consisting of
+# services, replication controllers and pods, before creating a new cluster.
+# The cluster contains a number of parameter server services and a number of
+# worker services. The parameter servers will hold parameters of the ML model,
+# e.g., weights and biases of the NN layers, while the workers will hold the
+# TensorFlow ops.
+#
+# Usage:
+# dist_test.sh [--setup-cluster-only]
+#
+# --setup-cluster-only lets the script only set up the k8s container network
+#
+# This script obeys values in the following environment variables:
+#   TF_DIST_GRPC_SERVER_URL:      If it is set to a valid grpc server url (e.g.,
+#                                 grpc://1.2.3.4:2222), the script will bypass
+#                                 the cluster setup and teardown processes and
+#                                 just use this URL.
+
+
+# Configurations
+NUM_WORKERS=2  # Number of worker containers
+NUM_PARAMETER_SERVERS=2 # Number of parameter servers
+
+# Helper functions
+die() {
+ echo $@
+ exit 1
+}
+
+# gcloud operation timeout (steps)
+GCLOUD_OP_MAX_STEPS=240
+
+GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}
+
+# Report gcloud / GKE parameters
+echo "GRPC_SERVER_URL: ${GRPC_SERVER_URL}"
+
+# Get current script directory
+DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+
+# Decide whether to set up and tear down a k8s cluster: a preset GRPC server
+# URL means an existing cluster will be used and both steps are skipped
+TEARDOWN_WHEN_DONE=1
+if [[ ! -z "${GRPC_SERVER_URL}" ]]; then
+ TEARDOWN_WHEN_DONE=0
+ # Verify the validity of the GRPC URL
+ if [[ -z $(echo "${GRPC_SERVER_URL}" | \
+ grep -E "^grpc://.+:[0-9]+") ]]; then
+ die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
+ else
+ echo "The preset GRPC_SERVER_URL appears to be valid: ${GRPC_SERVER_URL}"
+ echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
+ echo ""
+ fi
+else
+ TMP=$(mktemp)
+ "${DIR}/create_tf_cluster.sh" ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} 2>&1 | \
+ tee "${TMP}" || \
+ die "Creation of TensorFlow k8s cluster FAILED"
+
+ GRPC_SERVER_URL=$(cat ${TMP} | grep "GRPC URL of tf-worker0: .*" | \
+ awk '{print $NF}')
+ if [[ -z "${GRPC_SERVER_URL}" ]]; then
+ die "FAILED to determine GRPC server URL"
+ fi
+ rm -f ${TMP}
+
+ if [[ $1 == "--setup-cluster-only" ]]; then
+ echo "Skipping testing of distributed runtime due to "\
+"option flag --setup-cluster-only"
+ exit 0
+ fi
+fi
+
+# Invoke script to perform distributed MNIST training
+MNIST_DIST_TEST_BIN="${DIR}/dist_mnist_test.sh"
+if [[ ! -f "${MNIST_DIST_TEST_BIN}" ]]; then
+ die "FAILED to find distributed mnist client test script at "\
+"${MNIST_DIST_TEST_BIN}"
+fi
+
+echo "Performing distributed MNIST training through grpc session @ "\
+"${GRPC_SERVER_URL}..."
+
+"${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URL}"
+
+if [[ $? == "0" ]]; then
+ echo "MNIST-replica test PASSED"
+else
+ die "MNIST-replica test FAILED"
+fi
+
+# Tear down current k8s TensorFlow cluster
+if [[ "${TEARDOWN_WHEN_DONE}" == "1" ]]; then
+ echo "Tearing down k8s TensorFlow cluster..."
+ "${DIR}/delete_tf_cluster.sh" "${GCLOUD_OP_MAX_STEPS}" && \
+ echo "Cluster tear-down SUCCEEDED" || \
+ die "Cluster tear-down FAILED"
+fi
+echo "SUCCESS: Test of distributed TensorFlow runtime PASSED"
diff --git a/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
new file mode 100755
index 0000000000..e3fde2180a
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/k8s_tensorflow.py
@@ -0,0 +1,245 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generates YAML configuration files for distributed TensorFlow workers.
+
+The workers will be run in a Kubernetes (k8s) container cluster.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+
+# Note: It is intentional that we do not import tensorflow in this script. The
+# machine that launches a TensorFlow k8s cluster does not have to have the
+# Python package of TensorFlow installed on it.
+
+
+DEFAULT_DOCKER_IMAGE = 'tensorflow/tf_grpc_test_server'
+DEFAULT_PORT = 2222
+
+# TODO(cais): Consider adding resource requests/limits to the pods.
+WORKER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: tf-worker{worker_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-worker: "{worker_id}"
+ spec:
+ containers:
+ - name: tf-worker{worker_id}
+ image: {docker_image}
+ args:
+ - --cluster_spec={cluster_spec}
+ - --job_name=worker
+ - --task_id={worker_id}
+ ports:
+ - containerPort: {port}
+""")
+WORKER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ ports:
+ - port: {port}
+ targetPort: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+WORKER_LB_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-worker{worker_id}
+ labels:
+ tf-worker: "{worker_id}"
+spec:
+ type: LoadBalancer
+ ports:
+ - port: {port}
+ selector:
+ tf-worker: "{worker_id}"
+""")
+PARAM_SERVER_RC = (
+ """apiVersion: v1
+kind: ReplicationController
+metadata:
+ name: tf-ps{param_server_id}
+spec:
+ replicas: 1
+ template:
+ metadata:
+ labels:
+ tf-ps: "{param_server_id}"
+ spec:
+ containers:
+ - name: tf-ps{param_server_id}
+ image: {docker_image}
+ args:
+ - --cluster_spec={cluster_spec}
+ - --job_name=ps
+ - --task_id={param_server_id}
+ ports:
+ - containerPort: {port}
+""")
+PARAM_SERVER_SVC = (
+ """apiVersion: v1
+kind: Service
+metadata:
+ name: tf-ps{param_server_id}
+ labels:
+ tf-ps: "{param_server_id}"
+spec:
+ ports:
+ - port: {port}
+ selector:
+ tf-ps: "{param_server_id}"
+""")
+
+
+def main():
+  """Parse arguments and print the generated YAML configuration."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--num_workers',
+ type=int,
+ default=2,
+ help='How many worker pods to run')
+ parser.add_argument('--num_parameter_servers',
+ type=int,
+ default=1,
+                      help='How many parameter server pods to run')
+ parser.add_argument('--grpc_port',
+ type=int,
+ default=DEFAULT_PORT,
+ help='GRPC server port (Default: %d)' % DEFAULT_PORT)
+  parser.add_argument('--request_load_balancer',
+                      # argparse's type=bool would treat any non-empty string,
+                      # including 'False', as True; parse the value explicitly
+                      type=lambda s: s.lower() in ('true', '1'),
+                      default=False,
+                      help='To request worker0 to be exposed on a public IP '
+                      'address via an external load balancer, enabling you to '
+                      'run client processes from outside the cluster')
+ parser.add_argument('--docker_image',
+ type=str,
+ default=DEFAULT_DOCKER_IMAGE,
+ help='Override default docker image for the TensorFlow '
+ 'GRPC server')
+ args = parser.parse_args()
+
+ if args.num_workers <= 0:
+ sys.stderr.write('--num_workers must be greater than 0; received %d\n'
+ % args.num_workers)
+ sys.exit(1)
+ if args.num_parameter_servers <= 0:
+ sys.stderr.write(
+ '--num_parameter_servers must be greater than 0; received %d\n'
+ % args.num_parameter_servers)
+ sys.exit(1)
+
+ # Generate contents of yaml config
+ yaml_config = GenerateConfig(args.num_workers,
+ args.num_parameter_servers,
+ args.grpc_port,
+ args.request_load_balancer,
+ args.docker_image)
+ print(yaml_config) # pylint: disable=superfluous-parens
+
+
+def GenerateConfig(num_workers,
+ num_param_servers,
+ port,
+ request_load_balancer,
+ docker_image):
+ """Generate configuration strings."""
+ config = ''
+ for worker in range(num_workers):
+ config += WORKER_RC.format(
+ port=port,
+ worker_id=worker,
+ docker_image=docker_image,
+ cluster_spec=WorkerClusterSpec(num_workers,
+ num_param_servers,
+ port))
+ config += '---\n'
+ if worker == 0 and request_load_balancer:
+ config += WORKER_LB_SVC.format(port=port,
+ worker_id=worker)
+ else:
+ config += WORKER_SVC.format(port=port,
+ worker_id=worker)
+ config += '---\n'
+
+ for param_server in range(num_param_servers):
+ config += PARAM_SERVER_RC.format(
+ port=port,
+ param_server_id=param_server,
+ docker_image=docker_image,
+ cluster_spec=ParamServerClusterSpec(num_workers,
+ num_param_servers,
+ port))
+ config += '---\n'
+ config += PARAM_SERVER_SVC.format(port=port,
+ param_server_id=param_server)
+ config += '---\n'
+
+ return config
+
+
+def WorkerClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates worker cluster spec."""
+ return ClusterSpec(num_workers, num_param_servers, port)
+
+
+def ParamServerClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates parameter server spec."""
+ return ClusterSpec(num_workers, num_param_servers, port)
+
+
+def ClusterSpec(num_workers,
+ num_param_servers,
+ port):
+ """Generates general cluster spec."""
+ spec = 'worker|'
+ for worker in range(num_workers):
+ spec += 'tf-worker%d:%d' % (worker, port)
+ if worker != num_workers-1:
+ spec += ';'
+
+ spec += ',ps|'
+ for param_server in range(num_param_servers):
+ spec += 'tf-ps%d:%d' % (param_server, port)
+ if param_server != num_param_servers-1:
+ spec += ';'
+
+ return spec
+
+
+if __name__ == '__main__':
+ main()
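The `cluster_spec` string assembled by `ClusterSpec()` above has a compact wire format: comma-separated jobs, each of the form `name|host:port(;host:port)*`. A condensed re-statement using `str.join`, for illustration only:

```python
def make_cluster_spec(num_workers, num_ps, port):
    """Condensed equivalent of ClusterSpec() above."""
    workers = ';'.join('tf-worker%d:%d' % (i, port) for i in range(num_workers))
    ps = ';'.join('tf-ps%d:%d' % (i, port) for i in range(num_ps))
    return 'worker|%s,ps|%s' % (workers, ps)

print(make_cluster_spec(2, 1, 2222))
# worker|tf-worker0:2222;tf-worker1:2222,ps|tf-ps0:2222
```

The host names `tf-worker<i>` and `tf-ps<i>` match the k8s service names generated earlier, which is what lets pods resolve each other through cluster DNS.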
diff --git a/tensorflow/tools/dist_test/scripts/utils.sh b/tensorflow/tools/dist_test/scripts/utils.sh
new file mode 100644
index 0000000000..bc4485baf0
--- /dev/null
+++ b/tensorflow/tools/dist_test/scripts/utils.sh
@@ -0,0 +1,56 @@
+#!/usr/bin/env bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Utility functions for dist_test scripts
+
+
+# Print info and exit with code 1
+die() {
+ echo $@
+ exit 1
+}
+
+
+# Determine whether all k8s pods in a namespace are in the "Running" state
+are_all_pods_running() {
+ # Usage: are_all_pods_running <KUBECTL_BIN> [namespace]
+ KUBECTL_BIN=$1
+
+ if [[ -z "$2" ]]; then
+ NS_FLAG=""
+ else
+ NS_FLAG="--namespace=$2"
+ fi
+
+ sleep 1 # Wait for the status to settle
+  # ${NS_FLAG} is deliberately left unquoted: when empty it must expand to
+  # no argument at all rather than to an empty string
+  NPODS=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | wc -l)
+  NRUNNING=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | \
+      grep "Running" | wc -l)
+  NERR=$("${KUBECTL_BIN}" ${NS_FLAG} get pods | tail -n +2 | \
+      grep "Err" | wc -l)
+
+ if [[ ${NERR} != "0" ]]; then
+ # "2" signifies that error has occurred
+ echo "2"
+ elif [[ ${NPODS} == ${NRUNNING} ]]; then
+ # "1" signifies that all pods are in Running state
+ echo "1"
+ else
+ # "0" signifies that some pods have not entered Running state, but
+ # no error has occurred
+ echo "0"
+ fi
+}
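The tri-state return convention of `are_all_pods_running` ("0" pending, "1" all running, "2" error) can be restated in a few lines of Python for clarity; the function name and pod status strings below are illustrative:

```python
def pods_status(statuses):
    """Mirror are_all_pods_running: '2' if any pod reports an error,
    '1' if every pod is Running, otherwise '0' (still pending)."""
    if any('Err' in s for s in statuses):
        return '2'
    if all(s == 'Running' for s in statuses):
        return '1'
    return '0'

print(pods_status(['Running', 'Running']))       # 1
print(pods_status(['Running', 'Pending']))       # 0
print(pods_status(['Running', 'ErrImagePull']))  # 2
```

Callers such as create_tf_cluster.sh treat "2" as fatal and keep polling on "0", which is why the three states must stay distinct rather than collapsing to a boolean.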
diff --git a/tensorflow/tools/dist_test/server/Dockerfile b/tensorflow/tools/dist_test/server/Dockerfile
new file mode 100644
index 0000000000..bf384413f1
--- /dev/null
+++ b/tensorflow/tools/dist_test/server/Dockerfile
@@ -0,0 +1,59 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Test server for TensorFlow GRPC server
+#
+# To build the image, use ../build_server.sh
+
+FROM ubuntu:14.04
+
+MAINTAINER Shanqing Cai <cais@google.com>
+
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y \
+ bc \
+ curl \
+ dnsutils \
+ python-numpy \
+ python-pip \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
+ python get-pip.py && \
+ rm get-pip.py
+
+# Install TensorFlow CPU version.
+RUN pip --no-cache-dir install \
+ http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
+
+# Copy files, including the GRPC server binary at
+# server/grpc_tensorflow_server.py
+ADD . /var/tf-k8s
+
+# Download MNIST data for tests
+RUN mkdir -p /tmp/mnist-data
+RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
+RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
+RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
+RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
+ http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
+
+# Container entry point
+ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server.py"]
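The Dockerfile above sets the gRPC server script as the container entry point, so any arguments supplied at `docker run` (or in a k8s pod spec) are appended to it. A minimal sketch of the resulting command line, using hypothetical host names (`tf-ps0`, `tf-worker0`, etc.) purely for illustration:

```shell
# Hypothetical cluster: two parameter servers and two workers.
# Host names here are placeholders, not names used by this changeset.
CLUSTER_SPEC="ps|tf-ps0:2222;tf-ps1:2222,worker|tf-worker0:2222;tf-worker1:2222"

# The entry point baked into the image; flags arrive as run-time arguments.
CMD="/var/tf-k8s/server/grpc_tensorflow_server.py --cluster_spec=${CLUSTER_SPEC} --job_name=worker --task_id=0"
echo "${CMD}"
```

Each pod in the cluster would run the same image with a different `--job_name` / `--task_id` pair but an identical `--cluster_spec`.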
diff --git a/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
new file mode 100755
index 0000000000..b9742112de
--- /dev/null
+++ b/tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
@@ -0,0 +1,122 @@
+#!/usr/bin/python
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Python-based TensorFlow GRPC server.
+
+Takes input arguments cluster_spec, job_name and task_id, and starts a blocking
+TensorFlow GRPC server.
+
+Usage:
+ grpc_tensorflow_server.py --cluster_spec=SPEC --job_name=NAME --task_id=ID
+
+Where:
+ SPEC is <JOB>(,<JOB>)*
+ JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*
+ NAME is a valid job name ([a-z][0-9a-z]*)
+ HOST is a hostname or IP address
+ PORT is a port number
+"""
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string("cluster_spec", "",
+                           """Cluster spec: SPEC.
+    SPEC is <JOB>(,<JOB>)*,
+    JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,
+    NAME is a valid job name ([a-z][0-9a-z]*),
+    HOST is a hostname or IP address,
+    PORT is a port number.
+E.g., local|localhost:2222;localhost:2223,ps|ps0:2222;ps1:2222""")
+tf.app.flags.DEFINE_string("job_name", "", "Job name: e.g., local")
+tf.app.flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
+tf.app.flags.DEFINE_boolean("verbose", False, "Verbose mode")
+
+
+def parse_cluster_spec(cluster_spec, cluster):
+ """Parse content of cluster_spec string and inject info into cluster protobuf.
+
+ Args:
+ cluster_spec: cluster specification string, e.g.,
+ "local|localhost:2222;localhost:2223"
+ cluster: cluster protobuf.
+
+ Raises:
+ ValueError: if the cluster_spec string is invalid.
+ """
+
+ job_strings = cluster_spec.split(",")
+
+ for job_string in job_strings:
+ job_def = cluster.job.add()
+
+ if job_string.count("|") != 1:
+ raise ValueError("Not exactly one instance of '|' in cluster_spec")
+
+ job_name = job_string.split("|")[0]
+
+ if not job_name:
+ raise ValueError("Empty job_name in cluster_spec")
+
+ job_def.name = job_name
+
+ if FLAGS.verbose:
+ print("Added job named \"%s\"" % job_name)
+
+ job_tasks = job_string.split("|")[1].split(";")
+    for i, job_task in enumerate(job_tasks):
+      if not job_task:
+        raise ValueError("Empty job_task string at position %d" % i)
+
+      job_def.tasks[i] = job_task
+
+      if FLAGS.verbose:
+        print("  Added task \"%s\" to job \"%s\"" % (job_task, job_name))
+
+
+def main(unused_args):
+ # Create Protobuf ServerDef
+ server_def = tf.ServerDef(protocol="grpc")
+
+ # Cluster info
+ parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
+
+ # Job name
+ if not FLAGS.job_name:
+ raise ValueError("Empty job_name")
+ server_def.job_name = FLAGS.job_name
+
+ # Task index
+ if FLAGS.task_id < 0:
+ raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
+ server_def.task_index = FLAGS.task_id
+
+ # Create GrpcServer instance
+ server = tf.GrpcServer(server_def)
+
+ # join() is blocking, unlike start()
+ server.join()
+
+
+if __name__ == "__main__":
+ tf.app.run()
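The `cluster_spec` grammar documented in the script's docstring can be checked without TensorFlow at all. The sketch below re-implements the parsing logic of `parse_cluster_spec` as a plain function returning a `{job_name: {task_index: address}}` dict instead of filling a cluster protobuf; it is an illustration of the format, not part of this changeset:

```python
def parse_cluster_spec(cluster_spec):
    """Parse "name|host:port;host:port,..." into {job_name: {index: address}}."""
    cluster = {}
    for job_string in cluster_spec.split(","):
        # Each job entry must contain exactly one '|' separating name from tasks.
        if job_string.count("|") != 1:
            raise ValueError("Not exactly one instance of '|' in cluster_spec")
        job_name, _, task_string = job_string.partition("|")
        if not job_name:
            raise ValueError("Empty job_name in cluster_spec")
        tasks = {}
        for i, task in enumerate(task_string.split(";")):
            if not task:
                raise ValueError("Empty job_task string at position %d" % i)
            tasks[i] = task
        cluster[job_name] = tasks
    return cluster

spec = parse_cluster_spec("ps|ps0:2222;ps1:2222,worker|worker0:2222")
print(spec)
```

The real script performs the same splits but writes each address into `cluster.job.add().tasks[i]` on the `ServerDef`'s cluster protobuf.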