diff options
author | Shanqing Cai <cais@google.com> | 2016-07-10 06:13:10 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-07-10 07:17:28 -0700 |
commit | 213fb1efea948483c5bade7635164780f2135dd0 (patch) | |
tree | cab68f839fb1a4a6fd2279f7bb447e0f43aaf894 /tensorflow/tools/gcs_test | |
parent | afc071cd4f96b34a7d5e68ce7b4bf09d96279803 (diff) |
End-to-end smoke test for GCS access from OSS TensorFlow
Change: 127027579
Diffstat (limited to 'tensorflow/tools/gcs_test')
-rw-r--r-- | tensorflow/tools/gcs_test/Dockerfile | 22 | ||||
-rwxr-xr-x | tensorflow/tools/gcs_test/gcs_smoke.sh | 96 | ||||
-rw-r--r-- | tensorflow/tools/gcs_test/python/gcs_smoke.py | 112 |
3 files changed, 230 insertions, 0 deletions
diff --git a/tensorflow/tools/gcs_test/Dockerfile b/tensorflow/tools/gcs_test/Dockerfile new file mode 100644 index 0000000000..43e7da7743 --- /dev/null +++ b/tensorflow/tools/gcs_test/Dockerfile @@ -0,0 +1,22 @@ +FROM ubuntu:14.04 + +MAINTAINER Shanqing Cai <cais@google.com> + +RUN apt-get update +RUN apt-get install -y --no-install-recommends \ + curl \ + python \ + python-numpy \ + python-pip + +# Install Google Cloud SDK +RUN curl -O https://dl.google.com/dl/cloudsdk/channels/rapid/install_google_cloud_sdk.bash +RUN chmod +x install_google_cloud_sdk.bash +RUN ./install_google_cloud_sdk.bash --disable-prompts --install-dir=/var/gcloud + +# Install nightly TensorFlow pip +RUN pip install \ + http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.9.0-py2-none-any.whl + +# Copy test files +COPY python/gcs_smoke.py / diff --git a/tensorflow/tools/gcs_test/gcs_smoke.sh b/tensorflow/tools/gcs_test/gcs_smoke.sh new file mode 100755 index 0000000000..df59781ea8 --- /dev/null +++ b/tensorflow/tools/gcs_test/gcs_smoke.sh @@ -0,0 +1,96 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Driver script for TensorFlow-GCS smoke test. +# +# Usage: +# gcs_smoke.sh <GCLOUD_JSON_KEY_PATH> <GCS_BUCKET_URL> +# +# Input arguments: +# GCLOUD_KEY_JSON_PATH: Path to the Google Cloud JSON key file. +# See https://cloud.google.com/storage/docs/authentication for details. +# +# GCS_BUCKET_URL: URL to the GCS bucket for testing. +# E.g., gs://my-gcs-bucket/test-directory + +# Configurations +DOCKER_IMG="tensorflow-gcs-test" + +print_usage() { + echo "Usage: gcs_smoke.sh <GCLOUD_JSON_KEY_PATH> <GCS_BUCKET_URL>" + echo "" +} + + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +source "${SCRIPT_DIR}/../ci_build/builds/builds_common.sh" + +# Check input arguments +GCLOUD_JSON_KEY_PATH=$1 +GCS_BUCKET_URL=$2 +if [[ -z "${GCLOUD_JSON_KEY_PATH}" ]]; then + print_usage + die "ERROR: Command-line argument GCLOUD_JSON_KEY_PATH is not supplied" +fi +if [[ -z "${GCS_BUCKET_URL}" ]]; then + print_usage + die "ERROR: Command-line argument GCS_BUCKET_URL is not supplied" +fi + +if [[ ! -f "${GCLOUD_JSON_KEY_PATH}" ]]; then + die "ERROR: Path to Google Cloud JSON key file is invalid: \""\ +"${GCLOUD_JSON_KEY_PATH}\"" +fi + +DOCKERFILE="${SCRIPT_DIR}/Dockerfile" +if [[ ! -f "${DOCKERFILE}" ]]; then + die "ERROR: Cannot find Dockerfile at expected path ${DOCKERFILE}" +fi + +# Build the docker image for testing +docker build --no-cache \ + -f "${DOCKERFILE}" -t "${DOCKER_IMG}" "${SCRIPT_DIR}" || \ + die "FAIL: Failed to build docker image for testing" + +# Run the docker image with the GCS key file mapped and the gcloud-required +# environment variables set. +LOG_FILE="/tmp/tf-gcs-test.log" +rm -rf ${LOG_FILE} + +docker run --rm \ + -v ${GCLOUD_JSON_KEY_PATH}:/gcloud-key.json \ + -e "GOOGLE_APPLICATION_CREDENTIALS=/gcloud-key.json" \ + "${DOCKER_IMG}" \ + python /gcs_smoke.py --gcs_bucket_url="${GCS_BUCKET_URL}" \ + 2>&1 > "${LOG_FILE}" + +if [[ $? != "0" ]]; then + cat ${LOG_FILE} + die "FAIL: End-to-end test of GCS access from TensorFlow failed." +fi + +cat ${LOG_FILE} +echo "" + +# Clean up the newly created tfrecord file in GCS bucket +NEW_TFREC_URL=$(grep "Using input path" "${LOG_FILE}" | \ + awk '{print $NF}') +if [[ -z ${NEW_TFREC_URL} ]]; then + die "FAIL: Unable to determine the URL to the new tfrecord file in GCS" +fi +gsutil rm "${NEW_TFREC_URL}" && \ + echo "Cleaned up new tfrecord file in GCS: ${NEW_TFREC_URL}" || \ + die "FAIL: Unable to clean up new tfrecord file in GCS: ${NEW_TFREC_URL}"
\ No newline at end of file diff --git a/tensorflow/tools/gcs_test/python/gcs_smoke.py b/tensorflow/tools/gcs_test/python/gcs_smoke.py new file mode 100644 index 0000000000..90d32dc149 --- /dev/null +++ b/tensorflow/tools/gcs_test/python/gcs_smoke.py @@ -0,0 +1,112 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Smoke test for reading records from GCS to TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.core.example import example_pb2 + +flags = tf.app.flags +flags.DEFINE_string("gcs_bucket_url", "", + "The URL to the GCS bucket in which the temporary " + "tfrecord file is to be written and read, e.g., " + "gs://my-gcs-bucket/test-directory") +flags.DEFINE_integer("num_examples", 10, "Number of examples to generate") + +FLAGS = flags.FLAGS + + +def create_examples(num_examples, input_mean): + """Create ExampleProto's containg data.""" + ids = np.arange(num_examples).reshape([num_examples, 1]) + inputs = np.random.randn(num_examples, 1) + input_mean + target = inputs - input_mean + examples = [] + for row in range(num_examples): + ex = example_pb2.Example() + ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0])) + ex.features.feature["target"].float_list.value.append(target[row, 0]) + ex.features.feature["inputs"].float_list.value.append(inputs[row, 0]) + examples.append(ex) + return examples + + +if __name__ == "__main__": + # Sanity check on the GCS bucket URL. + if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"): + print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url) + sys.exit(1) + + # Generate random tfrecord path name. + input_path = FLAGS.gcs_bucket_url + "/" + input_path += "".join(random.choice("0123456789ABCDEF") for i in range(8)) + input_path += ".tfrecord" + print("Using input path: %s" % input_path) + + # Verify that writing to the records file in GCS works. + print("\n=== Testing writing and reading of GCS record file... ===") + example_data = create_examples(FLAGS.num_examples, 5) + with tf.python_io.TFRecordWriter(input_path) as hf: + for e in example_data: + hf.write(e.SerializeToString()) + + print("Data written to: %s" % input_path) + + # Verify that reading from the tfrecord file works and that + # tf_record_iterator works. + record_iter = tf.python_io.tf_record_iterator(input_path) + read_count = 0 + for r in record_iter: + read_count += 1 + print("Read %d records using tf_record_iterator" % read_count) + + if read_count != FLAGS.num_examples: + print("FAIL: The number of records read from tf_record_iterator (%d) " + "differs from the expected number (%d)" % (read_count, + FLAGS.num_examples)) + sys.exit(1) + + # Verify that running the read op in a session works. + print("\n=== Testing TFRecordReader.read op in a session... ===") + with tf.Graph().as_default() as g: + filename_queue = tf.train.string_input_producer([input_path], num_epochs=1) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + + with tf.Session() as sess: + sess.run(tf.initialize_all_variables()) + sess.run(tf.initialize_local_variables()) + tf.train.start_queue_runners() + index = 0 + for _ in range(FLAGS.num_examples): + print("Read record: %d" % index) + sess.run(serialized_example) + index += 1 + + # Reading one more record should trigger an exception. + try: + sess.run(serialized_example) + print("FAIL: Failed to catch the expected OutOfRangeError while " + "reading one more record than is available") + sys.exit(1) + except tf.python.framework.errors.OutOfRangeError: + print("Successfully caught the expected OutOfRangeError while " + "reading one more record than is available") |