aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/gcs_test
diff options
context:
space:
mode:
authorGravatar Shanqing Cai <cais@google.com>2016-07-10 06:13:10 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-10 07:17:28 -0700
commit213fb1efea948483c5bade7635164780f2135dd0 (patch)
treecab68f839fb1a4a6fd2279f7bb447e0f43aaf894 /tensorflow/tools/gcs_test
parentafc071cd4f96b34a7d5e68ce7b4bf09d96279803 (diff)
End-to-end smoke test for GCS access from OSS TensorFlow
Change: 127027579
Diffstat (limited to 'tensorflow/tools/gcs_test')
-rw-r--r--tensorflow/tools/gcs_test/Dockerfile22
-rwxr-xr-xtensorflow/tools/gcs_test/gcs_smoke.sh96
-rw-r--r--tensorflow/tools/gcs_test/python/gcs_smoke.py112
3 files changed, 230 insertions, 0 deletions
diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile
new file mode 100644
index 0000000000..43e7da7743
--- /dev/null
+++ b/tensorflow/tools/gcs_test/Dockerfile
@@ -0,0 +1,22 @@
FROM ubuntu:14.04

MAINTAINER Shanqing Cai <cais@google.com>

# Combine `apt-get update` with the install in one RUN layer: a standalone
# "RUN apt-get update" layer can be served from the Docker build cache long
# after the mirrors have moved on, making the subsequent install hit stale
# package indices. Clearing /var/lib/apt/lists afterwards keeps the layer
# small.
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
        python \
        python-numpy \
        python-pip && \
    rm -rf /var/lib/apt/lists/*

# Install Google Cloud SDK (non-interactive), then remove the installer
# script so it does not linger in the image.
RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/install_google_cloud_sdk.bash && \
    chmod +x install_google_cloud_sdk.bash && \
    ./install_google_cloud_sdk.bash --disable-prompts --install-dir=/var/gcloud && \
    rm -f install_google_cloud_sdk.bash

# Install nightly TensorFlow pip
RUN pip install \
    http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl

# Copy test files
COPY python/gcs_smoke.py /
diff --git a/tensorflow/tools/gcs_test/gcs_smoke.sh b/tensorflow/tools/gcs_test/gcs_smoke.sh
new file mode 100755
index 0000000000..df59781ea8
--- /dev/null
+++ b/tensorflow/tools/gcs_test/gcs_smoke.sh
@@ -0,0 +1,96 @@
#!/usr/bin/env bash
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Driver script for TensorFlow-GCS smoke test: builds a docker image with a
# nightly TensorFlow pip and runs /gcs_smoke.py inside it against a real GCS
# bucket, then cleans up the tfrecord file the test created.
#
# Usage:
#   gcs_smoke.sh <GCLOUD_JSON_KEY_PATH> <GCS_BUCKET_URL>
#
# Input arguments:
#   GCLOUD_JSON_KEY_PATH: Path to the Google Cloud JSON key file.
#     See https://cloud.google.com/storage/docs/authentication for details.
#
#   GCS_BUCKET_URL: URL to the GCS bucket for testing.
#     E.g., gs://my-gcs-bucket/test-directory

# Configurations
DOCKER_IMG="tensorflow-gcs-test"

print_usage() {
  echo "Usage: gcs_smoke.sh <GCLOUD_JSON_KEY_PATH> <GCS_BUCKET_URL>"
  echo ""
}


SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Provides die() and other CI helpers.
source "${SCRIPT_DIR}/../ci_build/builds/builds_common.sh"

# Check input arguments
GCLOUD_JSON_KEY_PATH=$1
GCS_BUCKET_URL=$2
if [[ -z "${GCLOUD_JSON_KEY_PATH}" ]]; then
  print_usage
  die "ERROR: Command-line argument GCLOUD_JSON_KEY_PATH is not supplied"
fi
if [[ -z "${GCS_BUCKET_URL}" ]]; then
  print_usage
  die "ERROR: Command-line argument GCS_BUCKET_URL is not supplied"
fi

if [[ ! -f "${GCLOUD_JSON_KEY_PATH}" ]]; then
  die "ERROR: Path to Google Cloud JSON key file is invalid: \""\
"${GCLOUD_JSON_KEY_PATH}\""
fi

DOCKERFILE="${SCRIPT_DIR}/Dockerfile"
if [[ ! -f "${DOCKERFILE}" ]]; then
  die "ERROR: Cannot find Dockerfile at expected path ${DOCKERFILE}"
fi

# Build the docker image for testing
docker build --no-cache \
    -f "${DOCKERFILE}" -t "${DOCKER_IMG}" "${SCRIPT_DIR}" || \
    die "FAIL: Failed to build docker image for testing"

# Run the docker image with the GCS key file mapped and the gcloud-required
# environment variables set.
LOG_FILE="/tmp/tf-gcs-test.log"
rm -f "${LOG_FILE}"

# BUG FIX: the redirections must be ordered "> file 2>&1". The original
# "2>&1 > file" duplicated stderr onto the terminal *before* stdout was
# redirected, so error output never reached ${LOG_FILE}, which this script
# later cats and greps.
docker run --rm \
    -v "${GCLOUD_JSON_KEY_PATH}":/gcloud-key.json \
    -e "GOOGLE_APPLICATION_CREDENTIALS=/gcloud-key.json" \
    "${DOCKER_IMG}" \
    python /gcs_smoke.py --gcs_bucket_url="${GCS_BUCKET_URL}" \
    > "${LOG_FILE}" 2>&1

if [[ $? != "0" ]]; then
  cat "${LOG_FILE}"
  die "FAIL: End-to-end test of GCS access from TensorFlow failed."
fi

cat "${LOG_FILE}"
echo ""

# Clean up the newly created tfrecord file in GCS bucket. The Python test
# logs "Using input path: <url>"; the URL is the last whitespace-delimited
# field of that line.
NEW_TFREC_URL=$(grep "Using input path" "${LOG_FILE}" | \
                awk '{print $NF}')
if [[ -z "${NEW_TFREC_URL}" ]]; then
  die "FAIL: Unable to determine the URL to the new tfrecord file in GCS"
fi
gsutil rm "${NEW_TFREC_URL}" && \
    echo "Cleaned up new tfrecord file in GCS: ${NEW_TFREC_URL}" || \
    die "FAIL: Unable to clean up new tfrecord file in GCS: ${NEW_TFREC_URL}"
diff --git a/tensorflow/tools/gcs_test/python/gcs_smoke.py b/tensorflow/tools/gcs_test/python/gcs_smoke.py
new file mode 100644
index 0000000000..90d32dc149
--- /dev/null
+++ b/tensorflow/tools/gcs_test/python/gcs_smoke.py
@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Smoke test for reading records from GCS to TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+import sys
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.core.example import example_pb2
+
+flags = tf.app.flags
+flags.DEFINE_string("gcs_bucket_url", "",
+ "The URL to the GCS bucket in which the temporary "
+ "tfrecord file is to be written and read, e.g., "
+ "gs://my-gcs-bucket/test-directory")
+flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
+
+FLAGS = flags.FLAGS
+
+
def create_examples(num_examples, input_mean):
  """Create ExampleProto's containg data.

  Builds `num_examples` Examples, each holding three single-element
  features: a string "id" (the row index), a random float "inputs" drawn
  from N(input_mean, 1), and "target" = inputs - input_mean.
  """
  identifiers = np.arange(num_examples).reshape([num_examples, 1])
  feature_vals = np.random.randn(num_examples, 1) + input_mean
  target_vals = feature_vals - input_mean

  protos = []
  for example_id, feature_val, target_val in zip(
      identifiers[:, 0], feature_vals[:, 0], target_vals[:, 0]):
    proto = example_pb2.Example()
    feature_map = proto.features.feature
    feature_map["id"].bytes_list.value.append(str(example_id))
    feature_map["target"].float_list.value.append(target_val)
    feature_map["inputs"].float_list.value.append(feature_val)
    protos.append(proto)
  return protos
+
+
if __name__ == "__main__":
  # Sanity check on the GCS bucket URL.
  if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
    print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
    sys.exit(1)

  # Generate a random tfrecord path name so concurrent test runs against the
  # same bucket do not clobber each other's files.
  input_path = FLAGS.gcs_bucket_url + "/"
  input_path += "".join(random.choice("0123456789ABCDEF") for i in range(8))
  input_path += ".tfrecord"
  print("Using input path: %s" % input_path)

  # Verify that writing to the records file in GCS works.
  print("\n=== Testing writing and reading of GCS record file... ===")
  example_data = create_examples(FLAGS.num_examples, 5)
  with tf.python_io.TFRecordWriter(input_path) as hf:
    for e in example_data:
      hf.write(e.SerializeToString())

  print("Data written to: %s" % input_path)

  # Verify that reading from the tfrecord file works and that
  # tf_record_iterator works.
  record_iter = tf.python_io.tf_record_iterator(input_path)
  read_count = 0
  for r in record_iter:
    read_count += 1
  print("Read %d records using tf_record_iterator" % read_count)

  if read_count != FLAGS.num_examples:
    print("FAIL: The number of records read from tf_record_iterator (%d) "
          "differs from the expected number (%d)" % (read_count,
                                                     FLAGS.num_examples))
    sys.exit(1)

  # Verify that running the read op in a session works.
  print("\n=== Testing TFRecordReader.read op in a session... ===")
  with tf.Graph().as_default() as g:
    filename_queue = tf.train.string_input_producer([input_path], num_epochs=1)
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    with tf.Session() as sess:
      sess.run(tf.initialize_all_variables())
      sess.run(tf.initialize_local_variables())
      tf.train.start_queue_runners()
      index = 0
      for _ in range(FLAGS.num_examples):
        print("Read record: %d" % index)
        sess.run(serialized_example)
        index += 1

      # Reading one more record should trigger an exception.
      try:
        sess.run(serialized_example)
        print("FAIL: Failed to catch the expected OutOfRangeError while "
              "reading one more record than is available")
        sys.exit(1)
      except tf.errors.OutOfRangeError:
        # BUG FIX: the original caught tf.python.framework.errors.OutOfRangeError,
        # but `tf.python` is not part of the public pip-package API surface, so
        # evaluating that except clause raised AttributeError and crashed the
        # test at the exact moment it should have succeeded. tf.errors is the
        # public, stable home of OutOfRangeError.
        print("Successfully caught the expected OutOfRangeError while "
              "reading one more record than is available")