aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-11 21:03:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-11 22:11:51 -0700
commited65e69560a8e2d58f7571fc2dcac269b20c260d (patch)
tree93dfda689ef0c34617f74932c352a4fa44b14224
parentd03631a27a0a3bc3eb23faa630b60c8e9826e1be (diff)
File system implementation for Google Cloud Storage.
This code implements a file system for file paths starting with gs:// using the HTTP API to Google Cloud Storage. No authentication is implemented yet, so only GCS objects with public access can be used. Change: 122126085
-rwxr-xr-xconfigure31
-rw-r--r--jsoncpp.BUILD34
-rw-r--r--tensorflow/BUILD1
-rw-r--r--tensorflow/core/BUILD5
-rw-r--r--tensorflow/core/platform/cloud/BUILD79
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc516
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h73
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc363
-rw-r--r--tensorflow/core/platform/cloud/http_request.cc410
-rw-r--r--tensorflow/core/platform/cloud/http_request.h156
-rw-r--r--tensorflow/core/platform/cloud/http_request_test.cc323
-rw-r--r--tensorflow/core/platform/default/build_config.bzl7
-rwxr-xr-xtensorflow/tools/ci_build/install/install_deb_packages.sh1
-rw-r--r--tensorflow/workspace.bzl12
14 files changed, 2009 insertions, 2 deletions
diff --git a/configure b/configure
index 3f9a53b573..5d6397da57 100755
--- a/configure
+++ b/configure
@@ -35,6 +35,37 @@ while true; do
# Retry
done
+while [ "$TF_NEED_GCP" == "" ]; do
+ read -p "Do you wish to build TensorFlow with "\
+"Google Cloud Platform support? [y/N] " INPUT
+ case $INPUT in
+ [Yy]* ) echo "Google Cloud Platform support will be enabled for "\
+"TensorFlow"; TF_NEED_GCP=1;;
+ [Nn]* ) echo "No Google Cloud Platform support will be enabled for "\
+"TensorFlow"; TF_NEED_GCP=0;;
+ "" ) echo "No Google Cloud Platform support will be enabled for "\
+"TensorFlow"; TF_NEED_GCP=0;;
+ * ) echo "Invalid selection: " $INPUT;;
+ esac
+done
+
+if [ "$TF_NEED_GCP" == "1" ]; then
+
+ ## Verify that libcurl header files are available.
+ # Only check Linux, since on MacOS the header files are installed with XCode.
+ if [[ $(uname -a) =~ Linux ]] && [[ ! -f "/usr/include/curl/curl.h" ]]; then
+ echo "ERROR: It appears that the development version of libcurl is not "\
+"available. Please install the libcurl3-dev package."
+ exit 1
+ fi
+
+ # Update Bazel build configuration.
+ perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = True,s" tensorflow/core/platform/default/build_config.bzl
+else
+ # Update Bazel build configuration.
+ perl -pi -e "s,WITH_GCP_SUPPORT = (False|True),WITH_GCP_SUPPORT = False,s" tensorflow/core/platform/default/build_config.bzl
+fi
+
## Find swig path
if [ -z "$SWIG_PATH" ]; then
SWIG_PATH=`type -p swig 2> /dev/null`
diff --git a/jsoncpp.BUILD b/jsoncpp.BUILD
new file mode 100644
index 0000000000..2bb2e19a67
--- /dev/null
+++ b/jsoncpp.BUILD
@@ -0,0 +1,34 @@
+licenses(["notice"]) # MIT
+
+JSON_HEADERS = [
+ "include/json/assertions.h",
+ "include/json/autolink.h",
+ "include/json/config.h",
+ "include/json/features.h",
+ "include/json/forwards.h",
+ "include/json/json.h",
+ "src/lib_json/json_batchallocator.h",
+ "include/json/reader.h",
+ "include/json/value.h",
+ "include/json/writer.h",
+]
+
+JSON_SOURCES = [
+ "src/lib_json/json_reader.cpp",
+ "src/lib_json/json_value.cpp",
+ "src/lib_json/json_writer.cpp",
+ "src/lib_json/json_tool.h",
+]
+
+INLINE_SOURCES = [
+ "src/lib_json/json_valueiterator.inl",
+]
+
+cc_library(
+ name = "jsoncpp",
+ srcs = JSON_SOURCES,
+ hdrs = JSON_HEADERS,
+ includes = ["include"],
+ textual_hdrs = INLINE_SOURCES,
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 040be96ffa..d50f987009 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -91,6 +91,7 @@ filegroup(
"//tensorflow/core/distributed_runtime/rpc:all_files",
"//tensorflow/core/kernels:all_files",
"//tensorflow/core/ops/compat:all_files",
+ "//tensorflow/core/platform/cloud:all_files",
"//tensorflow/core/platform/default/build_config:all_files",
"//tensorflow/core/util/ctc:all_files",
"//tensorflow/examples/android:all_files",
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 73f4e91fa8..044d732f30 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -67,6 +67,7 @@ load(
"tf_proto_library",
"tf_proto_library_cc",
"tf_additional_lib_srcs",
+ "tf_additional_lib_deps",
"tf_additional_stream_executor_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
@@ -995,9 +996,9 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
- "//tensorflow/core/kernels:required",
"//third_party/eigen3",
- ],
+ "//tensorflow/core/kernels:required",
+ ] + tf_additional_lib_deps(),
alwayslink = 1,
)
diff --git a/tensorflow/core/platform/cloud/BUILD b/tensorflow/core/platform/cloud/BUILD
new file mode 100644
index 0000000000..95da137386
--- /dev/null
+++ b/tensorflow/core/platform/cloud/BUILD
@@ -0,0 +1,79 @@
+# Description:
+# Cloud file system implementation.
+
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cc_test",
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+cc_library(
+ name = "gcs_file_system",
+ srcs = [
+ "gcs_file_system.cc",
+ ],
+ hdrs = [
+ "gcs_file_system.h",
+ ],
+ linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
+ visibility = ["//visibility:public"],
+ deps = [
+ "@jsoncpp_git//:jsoncpp",
+ ":http_request",
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_internal",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "http_request",
+ srcs = [
+ "http_request.cc",
+ ],
+ hdrs = [
+ "http_request.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_cc_test(
+ name = "gcs_file_system_test",
+ size = "small",
+ deps = [
+ ":gcs_file_system",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "http_request_test",
+ size = "small",
+ deps = [
+ ":http_request",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
new file mode 100644
index 0000000000..ba9418f06f
--- /dev/null
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -0,0 +1,516 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/gcs_file_system.h"
+#include <stdio.h>
+#include <unistd.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <vector>
+#include "include/json/json.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace tensorflow {
+
+namespace {
+
+constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
+constexpr char kGcsUploadUriBase[] =
+ "https://www.googleapis.com/upload/storage/v1/";
+constexpr char kStorageHost[] = "storage.googleapis.com";
+constexpr size_t kBufferSize = 1024 * 1024; // In bytes.
+
+/// Creates an empty, uniquely named file under /tmp via mkstemp() and stores
+/// its path in `*filename`.
+///
+/// Returns an Internal error when `filename` is null or the file cannot be
+/// created.
+Status GetTmpFilename(string* filename) {
+  if (filename == nullptr) {
+    return errors::Internal("'filename' cannot be nullptr.");
+  }
+  // mkstemp() rewrites the trailing XXXXXX in place, so the template must be
+  // a mutable array rather than a string literal.
+  char path_template[] = "/tmp/gcs_filesystem_XXXXXX";
+  const int descriptor = mkstemp(path_template);
+  if (descriptor < 0) {
+    return errors::Internal("Failed to create a temporary file.");
+  }
+  // Only the generated name is handed back; the descriptor is not needed.
+  close(descriptor);
+  filename->assign(path_template);
+  return Status::OK();
+}
+
+/// Auth provider that always hands out an empty token, which is only
+/// sufficient for publicly readable objects.
+class EmptyAuthProvider : public AuthProvider {
+ public:
+  /// Clears `*token` and reports success.
+  Status GetToken(string* token) const override {
+    token->clear();
+    return Status::OK();
+  }
+};
+
+/// Asks `provider` for a bearer token and stores it in `*token`.
+///
+/// A null `provider` is reported as an Internal error instead of being
+/// dereferenced.
+Status GetAuthToken(const AuthProvider* provider, string* token) {
+  if (provider != nullptr) {
+    return provider->GetToken(token);
+  }
+  return errors::Internal("Auth provider is required.");
+}
+
+/// \brief Breaks a "gs://" path into its bucket and object components.
+///
+/// "gs://bucket-name/path/to/file.txt" yields the bucket "bucket-name" and the
+/// object "path/to/file.txt". Returns InvalidArgument when `fname` does not
+/// have the form gs://<bucket>/<object>, and Internal when either output
+/// pointer is null.
+Status ParseGcsPath(const string& fname, string* bucket, string* object) {
+  if (bucket == nullptr || object == nullptr) {
+    return errors::Internal("bucket and object cannot be null.");
+  }
+  // The capture is restarted right after "gs://" and runs through the first
+  // '/', so it holds "<bucket>/"; the unscanned remainder is the object name.
+  StringPiece bucket_with_slash;
+  StringPiece object_name;
+  strings::Scanner scanner(fname);
+  const bool parsed = scanner.OneLiteral("gs://")
+                          .RestartCapture()
+                          .ScanEscapedUntil('/')
+                          .OneLiteral("/")
+                          .GetResult(&object_name, &bucket_with_slash);
+  if (!parsed) {
+    return errors::InvalidArgument("Couldn't parse GCS path: " + fname);
+  }
+  // Drop the trailing slash that was captured together with the bucket.
+  bucket->assign(bucket_with_slash.data(), bucket_with_slash.size() - 1);
+  object->assign(object_name.data(), object_name.size());
+  return Status::OK();
+}
+
+/// GCS-based implementation of a random access file.
+///
+/// Each Read() issues one ranged HTTP request against
+/// https://<bucket>.storage.googleapis.com/<object>. The auth provider and the
+/// request factory are stored as raw pointers and are not owned by this class.
+class GcsRandomAccessFile : public RandomAccessFile {
+ public:
+  GcsRandomAccessFile(const string& bucket, const string& object,
+                      AuthProvider* auth_provider,
+                      HttpRequest::Factory* http_request_factory)
+      : bucket_(bucket),
+        object_(object),
+        auth_provider_(auth_provider),
+        http_request_factory_(http_request_factory) {}
+
+  /// Requests bytes [offset, offset + n - 1] of the object, using `scratch`
+  /// (at least `n` bytes) as the result buffer and `*result` as the view
+  /// handed to HttpRequest::SetResultBuffer.
+  ///
+  /// Returns OutOfRange when fewer than `n` bytes came back, which is how the
+  /// RandomAccessFile interface signals EOF.
+  Status Read(uint64 offset, size_t n, StringPiece* result,
+              char* scratch) const override {
+    if (n == 0) {
+      // `offset + n - 1` would wrap around to UINT64_MAX at offset 0, or
+      // describe an inverted range otherwise. A zero-byte read needs no
+      // request at all.
+      *result = StringPiece(scratch, 0);
+      return Status::OK();
+    }
+
+    string auth_token;
+    TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
+
+    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+    TF_RETURN_IF_ERROR(request->Init());
+    TF_RETURN_IF_ERROR(request->SetUri(
+        strings::StrCat("https://", bucket_, ".", kStorageHost, "/", object_)));
+    TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+    // The end of the range is inclusive.
+    TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
+    TF_RETURN_IF_ERROR(request->SetResultBuffer(scratch, n, result));
+    TF_RETURN_IF_ERROR(request->Send());
+
+    if (result->size() < n) {
+      // This is not an error per se. The RandomAccessFile interface expects
+      // that Read returns OutOfRange if fewer bytes were read than requested.
+      return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(),
+                                                " bytes were read out of ", n,
+                                                " bytes requested."));
+    }
+    return Status::OK();
+  }
+
+ private:
+  string bucket_;
+  string object_;
+  AuthProvider* auth_provider_;                 // Not owned.
+  HttpRequest::Factory* http_request_factory_;  // Not owned.
+};
+
+/// \brief GCS-based implementation of a writable file.
+///
+/// Since GCS objects are immutable, this implementation writes to a local
+/// tmp file and uploads the whole file to GCS on every flush/sync/close.
+/// The auth provider and the request factory are stored as raw pointers and
+/// are not owned by this class.
+class GcsWritableFile : public WritableFile {
+ public:
+  /// Starts from an empty temporary file. If the temporary file cannot be
+  /// created the stream stays closed and Append/Sync/Flush return
+  /// FailedPrecondition.
+  GcsWritableFile(const string& bucket, const string& object,
+                  AuthProvider* auth_provider,
+                  HttpRequest::Factory* http_request_factory)
+      : bucket_(bucket),
+        object_(object),
+        auth_provider_(auth_provider),
+        http_request_factory_(http_request_factory) {
+    if (GetTmpFilename(&tmp_content_filename_).ok()) {
+      outfile_.open(tmp_content_filename_,
+                    std::ofstream::binary | std::ofstream::app);
+    }
+  }
+
+  /// \brief Constructs the writable file in append mode.
+  ///
+  /// tmp_content_filename should contain a path of an existing temporary file
+  /// with the content to be appended. The class takes ownership of the
+  /// specified tmp file and deletes it on close.
+  GcsWritableFile(const string& bucket, const string& object,
+                  AuthProvider* auth_provider,
+                  const string& tmp_content_filename,
+                  HttpRequest::Factory* http_request_factory)
+      : bucket_(bucket),
+        object_(object),
+        auth_provider_(auth_provider),
+        tmp_content_filename_(tmp_content_filename),
+        http_request_factory_(http_request_factory) {
+    outfile_.open(tmp_content_filename_,
+                  std::ofstream::binary | std::ofstream::app);
+  }
+
+  /// Uploads and cleans up if the caller did not call Close(); the status of
+  /// that final attempt is discarded.
+  ~GcsWritableFile() { Close(); }
+
+  /// Appends `data` to the local temporary file. Nothing is sent to GCS until
+  /// Flush(), Sync() or Close().
+  Status Append(const StringPiece& data) override {
+    TF_RETURN_IF_ERROR(CheckWritable());
+    outfile_ << data;
+    if (!outfile_.good()) {
+      // Without this check a failed local write (e.g. a full disk) would be
+      // reported as success and a truncated file uploaded later.
+      return errors::Internal(
+          "Could not append to the internal temporary file.");
+    }
+    return Status::OK();
+  }
+
+  /// Uploads the content and, on success, closes and deletes the temporary
+  /// file. If the upload fails the file is left open. Calling Close() on an
+  /// already closed file is a no-op.
+  Status Close() override {
+    if (outfile_.is_open()) {
+      TF_RETURN_IF_ERROR(Sync());
+      outfile_.close();
+      std::remove(tmp_content_filename_.c_str());
+    }
+    return Status::OK();
+  }
+
+  Status Flush() override { return Sync(); }
+
+  /// Copies the current version of the file to GCS.
+  Status Sync() override {
+    TF_RETURN_IF_ERROR(CheckWritable());
+    outfile_.flush();
+    if (!outfile_.good()) {
+      return errors::Internal(
+          "Could not write to the internal temporary file.");
+    }
+    string auth_token;
+    TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_, &auth_token));
+
+    std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+    TF_RETURN_IF_ERROR(request->Init());
+    TF_RETURN_IF_ERROR(
+        request->SetUri(strings::StrCat(kGcsUploadUriBase, "b/", bucket_,
+                                        "/o?uploadType=media&name=", object_)));
+    TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+    TF_RETURN_IF_ERROR(request->SetPostRequest(tmp_content_filename_));
+    TF_RETURN_IF_ERROR(request->Send());
+    return Status::OK();
+  }
+
+ private:
+  /// Returns FailedPrecondition unless the temporary file is currently open.
+  Status CheckWritable() const {
+    if (!outfile_.is_open()) {
+      return errors::FailedPrecondition(
+          "The underlying tmp file is not writable.");
+    }
+    return Status::OK();
+  }
+
+  string bucket_;
+  string object_;
+  AuthProvider* auth_provider_;  // Not owned.
+  string tmp_content_filename_;
+  std::ofstream outfile_;
+  HttpRequest::Factory* http_request_factory_;  // Not owned.
+};
+
+/// Read-only memory region backed by a heap buffer that this object owns.
+class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
+ public:
+  /// Takes ownership of `data`; `length` is the number of bytes reported by
+  /// length().
+  GcsReadOnlyMemoryRegion(std::unique_ptr<char[]> data, uint64 length)
+      : data_(std::move(data)), length_(length) {}
+
+  const void* data() override {
+    char* const bytes = data_.get();
+    return reinterpret_cast<void*>(bytes);
+  }
+
+  uint64 length() override { return length_; }
+
+ private:
+  std::unique_ptr<char[]> data_;
+  uint64 length_;
+};
+} // namespace
+
+// Default configuration: no authentication (public objects only) and the
+// default HttpRequest factory.
+GcsFileSystem::GcsFileSystem()
+    : GcsFileSystem(
+          std::unique_ptr<AuthProvider>(new EmptyAuthProvider()),
+          std::unique_ptr<HttpRequest::Factory>(new HttpRequest::Factory())) {}
+
+// Takes ownership of both collaborators.
+GcsFileSystem::GcsFileSystem(
+    std::unique_ptr<AuthProvider> auth_provider,
+    std::unique_ptr<HttpRequest::Factory> http_request_factory)
+    : auth_provider_(std::move(auth_provider)),
+      http_request_factory_(std::move(http_request_factory)) {}
+
+// Creates a reader for the object named by `fname`. Only the path is
+// validated here; no request is sent until the first Read(). The caller owns
+// `*result`.
+Status GcsFileSystem::NewRandomAccessFile(const string& fname,
+                                          RandomAccessFile** result) {
+  string bucket_name;
+  string object_name;
+  const Status parse_status = ParseGcsPath(fname, &bucket_name, &object_name);
+  if (!parse_status.ok()) {
+    return parse_status;
+  }
+  AuthProvider* const auth = auth_provider_.get();
+  HttpRequest::Factory* const factory = http_request_factory_.get();
+  *result = new GcsRandomAccessFile(bucket_name, object_name, auth, factory);
+  return Status::OK();
+}
+
+// Creates a writer that starts from empty content and uploads to the object
+// named by `fname` on flush/sync/close. The caller owns `*result`.
+Status GcsFileSystem::NewWritableFile(const string& fname,
+                                      WritableFile** result) {
+  string bucket_name;
+  string object_name;
+  const Status parse_status = ParseGcsPath(fname, &bucket_name, &object_name);
+  if (!parse_status.ok()) {
+    return parse_status;
+  }
+  AuthProvider* const auth = auth_provider_.get();
+  HttpRequest::Factory* const factory = http_request_factory_.get();
+  *result = new GcsWritableFile(bucket_name, object_name, auth, factory);
+  return Status::OK();
+}
+
+// Reads the file from GCS in chunks and stores it in a tmp file,
+// which is then passed to GcsWritableFile.
+//
+// On any failure the tmp file is removed before returning, so nothing is
+// left behind in /tmp. On success the returned GcsWritableFile is constructed
+// with the tmp file's path, and the caller owns `*result`.
+Status GcsFileSystem::NewAppendableFile(const string& fname,
+                                        WritableFile** result) {
+  // Validate the path before anything is downloaded or created on disk.
+  string bucket, object;
+  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));
+
+  RandomAccessFile* reader_ptr;
+  TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &reader_ptr));
+  std::unique_ptr<RandomAccessFile> reader(reader_ptr);
+  std::unique_ptr<char[]> buffer(new char[kBufferSize]);
+
+  string old_content_filename;
+  TF_RETURN_IF_ERROR(GetTmpFilename(&old_content_filename));
+  std::ofstream old_content(old_content_filename, std::ofstream::binary);
+
+  Status status;
+  if (!old_content.is_open()) {
+    status = errors::Internal("Could not open the internal temporary file.");
+  }
+
+  // Read the file from GCS in chunks and save it to the tmp file.
+  uint64 offset = 0;
+  StringPiece read_chunk;
+  while (status.ok()) {
+    status = reader->Read(offset, kBufferSize, &read_chunk, buffer.get());
+    const bool reached_eof = status.code() == error::OUT_OF_RANGE;
+    if (!status.ok() && !reached_eof) {
+      break;  // A real error; reported after cleanup below.
+    }
+    old_content << read_chunk;
+    if (reached_eof) {
+      // Expected, this means the last (possibly partial) chunk was read.
+      status = Status::OK();
+      break;
+    }
+    offset += kBufferSize;
+  }
+  old_content.close();
+  if (status.ok() && old_content.fail()) {
+    status =
+        errors::Internal("Could not write to the internal temporary file.");
+  }
+  if (!status.ok()) {
+    // Nobody took ownership of the tmp file, so delete it here.
+    std::remove(old_content_filename.c_str());
+    return status;
+  }
+
+  // Create a writable file and pass the old content to it.
+  *result =
+      new GcsWritableFile(bucket, object, auth_provider_.get(),
+                          old_content_filename, http_request_factory_.get());
+  return Status::OK();
+}
+
+// Downloads the whole object into a heap buffer and wraps that buffer in a
+// ReadOnlyMemoryRegion. The object's size is queried first so the buffer can
+// be allocated with the exact length. The caller owns `*result`.
+Status GcsFileSystem::NewReadOnlyMemoryRegionFromFile(
+    const string& fname, ReadOnlyMemoryRegion** result) {
+  uint64 file_size;
+  TF_RETURN_IF_ERROR(GetFileSize(fname, &file_size));
+  std::unique_ptr<char[]> contents(new char[file_size]);
+
+  RandomAccessFile* raw_reader;
+  TF_RETURN_IF_ERROR(NewRandomAccessFile(fname, &raw_reader));
+  std::unique_ptr<RandomAccessFile> reader(raw_reader);
+
+  // The view is not used afterwards; the bytes are kept through `contents`.
+  StringPiece unused_view;
+  TF_RETURN_IF_ERROR(
+      reader->Read(0, file_size, &unused_view, contents.get()));
+
+  *result = new GcsReadOnlyMemoryRegion(std::move(contents), file_size);
+  return Status::OK();
+}
+
+// Returns true iff a metadata request (?fields=size) for the object succeeds.
+//
+// Failures while preparing the request (bad path, no token, request set-up)
+// are logged and reported as "does not exist". A failed Send() also yields
+// false, but is not logged.
+bool GcsFileSystem::FileExists(const string& fname) {
+  string bucket, object_prefix;
+  if (!ParseGcsPath(fname, &bucket, &object_prefix).ok()) {
+    LOG(ERROR) << "Could not parse GCS file name " << fname;
+    return false;
+  }
+
+  string auth_token;
+  if (!GetAuthToken(auth_provider_.get(), &auth_token).ok()) {
+    LOG(ERROR) << "Could not get an auth token.";
+    return false;
+  }
+
+  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+  if (!request->Init().ok()) {
+    LOG(ERROR) << "Could not initialize the HTTP request.";
+    return false;
+  }
+  // Do not send a request whose URI or auth header failed to be set.
+  if (!request
+           ->SetUri(strings::StrCat(kGcsUriBase, "b/", bucket, "/o/",
+                                    object_prefix, "?fields=size"))
+           .ok()) {
+    LOG(ERROR) << "Could not set the URI of the HTTP request.";
+    return false;
+  }
+  if (!request->AddAuthBearerHeader(auth_token).ok()) {
+    LOG(ERROR) << "Could not add the auth header to the HTTP request.";
+    return false;
+  }
+  return request->Send().ok();
+}
+
+// Lists the objects whose names start with `dirname` (a trailing '/' is
+// appended when missing) and appends each one to `*result` as a full
+// "gs://<bucket>/<name>" path.
+//
+// The whole listing has to fit into a single kBufferSize response buffer; see
+// the TODO below.
+Status GcsFileSystem::GetChildren(const string& dirname,
+                                  std::vector<string>* result) {
+  if (result == nullptr) {
+    return errors::InvalidArgument("'result' cannot be null");
+  }
+  string dir_with_slash = dirname;
+  const bool has_trailing_slash = dirname.empty() || dirname.back() == '/';
+  if (!has_trailing_slash) {
+    dir_with_slash += "/";
+  }
+  string bucket, object_prefix;
+  TF_RETURN_IF_ERROR(ParseGcsPath(dir_with_slash, &bucket, &object_prefix));
+
+  string auth_token;
+  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
+
+  std::unique_ptr<char[]> response_buffer(new char[kBufferSize]);
+  StringPiece response;
+  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+  TF_RETURN_IF_ERROR(request->Init());
+  const string list_uri = strings::StrCat(
+      kGcsUriBase, "b/", bucket, "/o?prefix=", object_prefix, "&fields=items");
+  TF_RETURN_IF_ERROR(request->SetUri(list_uri));
+  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+  // TODO(surkov): Implement pagination using maxResults and pageToken
+  // instead, so that all items can be read regardless of their count.
+  // Currently one item takes about 1KB in the response, so with a 1MB
+  // buffer size this will read fewer than 1000 objects.
+  TF_RETURN_IF_ERROR(request->SetResultBuffer(response_buffer.get(),
+                                              kBufferSize, &response));
+  TF_RETURN_IF_ERROR(request->Send());
+
+  std::stringstream json_text;
+  json_text << response;
+  Json::Value root;
+  Json::Reader json_reader;
+  if (!json_reader.parse(json_text.str(), root)) {
+    return errors::Internal("Couldn't parse JSON response from GCS.");
+  }
+
+  const auto items = root.get("items", Json::Value::null);
+  if (items == Json::Value::null) {
+    // No "items" member at all: nothing matched the prefix.
+    return Status::OK();
+  }
+  if (!items.isArray()) {
+    return errors::Internal("Expected an array 'items' in the GCS response.");
+  }
+  for (size_t index = 0; index < items.size(); ++index) {
+    const auto entry = items.get(index, Json::Value::null);
+    if (!entry.isObject()) {
+      return errors::Internal(
+          "Unexpected JSON format: 'items' should be a list of objects.");
+    }
+    const auto name = entry.get("name", Json::Value::null);
+    const bool name_is_usable = name != Json::Value::null && name.isString();
+    if (!name_is_usable) {
+      return errors::Internal(
+          "Unexpected JSON format: 'items.name' is missing or not a string.");
+    }
+    result->push_back(
+        strings::StrCat("gs://", bucket, "/", name.asString().c_str()));
+  }
+  return Status::OK();
+}
+
+// Sends a delete request for the object named by `fname` and returns the
+// status of that request.
+Status GcsFileSystem::DeleteFile(const string& fname) {
+  string bucket_name;
+  string object_name;
+  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket_name, &object_name));
+
+  string auth_token;
+  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
+
+  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+  TF_RETURN_IF_ERROR(request->Init());
+  const string object_uri =
+      strings::StrCat(kGcsUriBase, "b/", bucket_name, "/o/", object_name);
+  TF_RETURN_IF_ERROR(request->SetUri(object_uri));
+  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+  TF_RETURN_IF_ERROR(request->SetDeleteRequest());
+  return request->Send();
+}
+
+// Directories are not entities in GCS, so there is nothing to create and the
+// call always succeeds.
+Status GcsFileSystem::CreateDir(const string& dirname) {
+  return Status::OK();
+}
+
+// Succeeds only when no object exists under the `dirname` prefix. Nothing is
+// actually deleted, because directories are not entities in GCS; a non-empty
+// "directory" is rejected with InvalidArgument.
+Status GcsFileSystem::DeleteDir(const string& dirname) {
+  string prefix = dirname;
+  const bool needs_slash = !dirname.empty() && dirname.back() != '/';
+  if (needs_slash) {
+    prefix += "/";
+  }
+  std::vector<string> children;
+  TF_RETURN_IF_ERROR(GetChildren(prefix, &children));
+  if (children.empty()) {
+    return Status::OK();
+  }
+  return errors::InvalidArgument("Cannot delete a non-empty directory.");
+}
+
+// Fetches the object's metadata restricted to the "size" field and stores
+// the size in bytes in `*file_size`.
+//
+// The "size" member is accepted either as a JSON number or as a decimal
+// string; anything else, or a response that is not valid JSON, is reported
+// as an Internal error.
+Status GcsFileSystem::GetFileSize(const string& fname, uint64* file_size) {
+  string bucket, object_prefix;
+  TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object_prefix));
+
+  string auth_token;
+  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
+
+  std::unique_ptr<char[]> scratch(new char[kBufferSize]);
+  StringPiece response_piece;
+
+  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+  TF_RETURN_IF_ERROR(request->Init());
+  TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat(
+      kGcsUriBase, "b/", bucket, "/o/", object_prefix, "?fields=size")));
+  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+  TF_RETURN_IF_ERROR(
+      request->SetResultBuffer(scratch.get(), kBufferSize, &response_piece));
+  TF_RETURN_IF_ERROR(request->Send());
+  std::stringstream response_stream;
+  response_stream << response_piece;
+
+  Json::Value root;
+  Json::Reader reader;
+  if (!reader.parse(response_stream.str(), root)) {
+    return errors::Internal("Couldn't parse JSON response from GCS.");
+  }
+  const auto size = root.get("size", Json::Value::null);
+  if (size == Json::Value::null) {
+    return errors::Internal("'size' was expected in the JSON response.");
+  }
+  if (size.isNumeric()) {
+    *file_size = size.asUInt64();
+  } else if (size.isString()) {
+    if (!strings::safe_strtou64(size.asString().c_str(), file_size)) {
+      return errors::Internal("'size' couldn't be parsed as a number.");
+    }
+  } else {
+    return errors::Internal("'size' is not a number in the JSON response.");
+  }
+  return Status::OK();
+}
+
+// Renames by sending a single "rewriteTo" POST request that copies `src` to
+// `target`, and then deleting `src`. The source is only deleted if the
+// rewrite request succeeded.
+Status GcsFileSystem::RenameFile(const string& src, const string& target) {
+  string src_bucket, src_object;
+  TF_RETURN_IF_ERROR(ParseGcsPath(src, &src_bucket, &src_object));
+  string target_bucket, target_object;
+  TF_RETURN_IF_ERROR(ParseGcsPath(target, &target_bucket, &target_object));
+
+  string auth_token;
+  TF_RETURN_IF_ERROR(GetAuthToken(auth_provider_.get(), &auth_token));
+
+  std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
+  TF_RETURN_IF_ERROR(request->Init());
+  const string rewrite_uri =
+      strings::StrCat(kGcsUriBase, "b/", src_bucket, "/o/", src_object,
+                      "/rewriteTo/b/", target_bucket, "/o/", target_object);
+  TF_RETURN_IF_ERROR(request->SetUri(rewrite_uri));
+  TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token));
+  TF_RETURN_IF_ERROR(request->SetPostRequest());
+  TF_RETURN_IF_ERROR(request->Send());
+
+  return DeleteFile(src);
+}
+
+REGISTER_FILE_SYSTEM("gs", GcsFileSystem);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
new file mode 100644
index 0000000000..47c22173de
--- /dev/null
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -0,0 +1,73 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/file_system.h"
+
+namespace tensorflow {
+
+/// Interface for a provider of HTTP auth bearer tokens.
+class AuthProvider {
+ public:
+  virtual ~AuthProvider() {}
+
+  /// Stores the bearer token to send with a request in `*t`.
+  virtual Status GetToken(string* t) const = 0;
+};
+
+/// Google Cloud Storage implementation of a file system.
+///
+/// Paths have the form gs://<bucket>/<object>.
+class GcsFileSystem : public FileSystem {
+ public:
+  /// Uses an auth provider that returns an empty token and the default
+  /// HTTP request factory.
+  GcsFileSystem();
+
+  /// Takes ownership of the given auth provider and request factory.
+  GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider,
+                std::unique_ptr<HttpRequest::Factory> http_request_factory);
+
+  /// Creates a reader for `fname`; the caller owns `*result`.
+  Status NewRandomAccessFile(const string& fname,
+                             RandomAccessFile** result) override;
+
+  /// Creates a writer that starts from empty content; the caller owns
+  /// `*result`.
+  Status NewWritableFile(const string& fname, WritableFile** result) override;
+
+  /// Creates a writer pre-populated with the current content of `fname`;
+  /// the caller owns `*result`.
+  Status NewAppendableFile(const string& fname, WritableFile** result) override;
+
+  /// Reads the whole object into memory; the caller owns `*result`.
+  Status NewReadOnlyMemoryRegionFromFile(
+      const string& fname, ReadOnlyMemoryRegion** result) override;
+
+  /// Returns true iff a metadata request for `fname` succeeds.
+  bool FileExists(const string& fname) override;
+
+  /// Appends the full gs:// paths of the objects under the prefix `dir` to
+  /// `*result`.
+  Status GetChildren(const string& dir, std::vector<string>* result) override;
+
+  Status DeleteFile(const string& fname) override;
+
+  /// Always succeeds: directories are not entities in GCS.
+  Status CreateDir(const string& dirname) override;
+
+  /// Succeeds only if no objects exist under the prefix `dirname`.
+  Status DeleteDir(const string& dirname) override;
+
+  /// Stores the size of the object in bytes in `*file_size`.
+  Status GetFileSize(const string& fname, uint64* file_size) override;
+
+  /// Copies `src` to `target` and then deletes `src`.
+  Status RenameFile(const string& src, const string& target) override;
+
+ private:
+  std::unique_ptr<AuthProvider> auth_provider_;
+  std::unique_ptr<HttpRequest::Factory> http_request_factory_;
+  TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_FILE_SYSTEM_H_
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
new file mode 100644
index 0000000000..151aebd87c
--- /dev/null
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -0,0 +1,363 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/gcs_file_system.h"
+#include <fstream>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+/// Test double for HttpRequest.
+///
+/// Nothing is sent over the network. Each setter appends one line of text to
+/// a transcript; Send() compares that transcript with the expected one and
+/// then returns the canned response and status supplied at construction.
+class FakeHttpRequest : public HttpRequest {
+ public:
+  /// Expects `request`; replies with `response` and an OK status.
+  FakeHttpRequest(const string& request, const string& response)
+      : FakeHttpRequest(request, response, Status::OK()) {}
+
+  /// Expects `request`; replies with `response` and `response_status`.
+  FakeHttpRequest(const string& request, const string& response,
+                  Status response_status)
+      : expected_request_(request),
+        response_(response),
+        response_status_(response_status) {}
+
+  Status Init() override { return Status::OK(); }
+
+  Status SetUri(const string& uri) override {
+    AppendLine(strings::StrCat("Uri: ", uri));
+    return Status::OK();
+  }
+
+  Status SetRange(uint64 start, uint64 end) override {
+    AppendLine(strings::StrCat("Range: ", start, "-", end));
+    return Status::OK();
+  }
+
+  Status AddAuthBearerHeader(const string& auth_token) override {
+    AppendLine(strings::StrCat("Auth Token: ", auth_token));
+    return Status::OK();
+  }
+
+  Status SetDeleteRequest() override {
+    AppendLine("Delete: yes");
+    return Status::OK();
+  }
+
+  /// Records the whole content of the body file in the transcript.
+  Status SetPostRequest(const string& body_filepath) override {
+    std::ifstream body_stream(body_filepath);
+    const string body((std::istreambuf_iterator<char>(body_stream)),
+                      std::istreambuf_iterator<char>());
+    AppendLine(strings::StrCat("Post body: ", body));
+    return Status::OK();
+  }
+
+  Status SetPostRequest() override {
+    AppendLine("Post: yes");
+    return Status::OK();
+  }
+
+  /// Remembers where Send() should place the canned response.
+  Status SetResultBuffer(char* scratch, size_t size,
+                         StringPiece* result) override {
+    scratch_ = scratch;
+    size_ = size;
+    result_ = result;
+    return Status::OK();
+  }
+
+  /// Verifies the transcript, copies at most `size_` bytes of the canned
+  /// response into the caller's buffer (if one was set) and returns the
+  /// canned status.
+  Status Send() override {
+    EXPECT_EQ(expected_request_, actual_request_) << "Unexpected HTTP request.";
+    if (scratch_ != nullptr && result_ != nullptr) {
+      const size_t bytes_to_copy = std::min(response_.size(), size_);
+      memcpy(scratch_, response_.data(), bytes_to_copy);
+      *result_ = StringPiece(scratch_, bytes_to_copy);
+    }
+    return response_status_;
+  }
+
+ private:
+  // Adds one newline-terminated line to the transcript.
+  void AppendLine(const string& line) { actual_request_ += line + "\n"; }
+
+  char* scratch_ = nullptr;
+  size_t size_ = 0;
+  StringPiece* result_ = nullptr;
+  string expected_request_;
+  string actual_request_;
+  string response_;
+  Status response_status_;
+};
+
+/// HttpRequest factory for tests: hands out pre-built requests in order.
+///
+/// The factory only borrows the vector, which must outlive it. On destruction
+/// it checks that every prepared request has been handed out.
+class FakeHttpRequestFactory : public HttpRequest::Factory {
+ public:
+  explicit FakeHttpRequestFactory(const std::vector<HttpRequest*>* requests)
+      : requests_(requests) {}
+
+  ~FakeHttpRequestFactory() override {
+    EXPECT_EQ(current_index_, requests_->size())
+        << "Not all expected requests were made.";
+  }
+
+  /// Returns the next prepared request. If the code under test asks for more
+  /// requests than were prepared, records a test failure and returns nullptr
+  /// instead of reading past the end of the vector.
+  HttpRequest* Create() override {
+    if (current_index_ >= requests_->size()) {
+      ADD_FAILURE() << "Too many calls of HttpRequest factory.";
+      return nullptr;
+    }
+    return (*requests_)[current_index_++];
+  }
+
+ private:
+  const std::vector<HttpRequest*>* requests_;
+  // Unsigned so that comparisons with vector::size() are well-defined.
+  size_t current_index_ = 0;
+};
+
+/// AuthProvider stub that always succeeds and yields the fixed token
+/// "fake_token".
+class FakeAuthProvider : public AuthProvider {
+ public:
+  Status GetToken(string* token) const override {
+    token->assign("fake_token");
+    return Status::OK();
+  }
+};
+
+// Reads a file with two ranged requests. The second response is shorter than
+// the requested range, which Read() reports as OUT_OF_RANGE while still
+// returning the partial data.
+TEST(GcsFileSystemTest, NewRandomAccessFile) {
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+           "Auth Token: fake_token\n"
+           "Range: 0-5\n",
+           "012345"),
+       new FakeHttpRequest(
+           "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+           "Auth Token: fake_token\n"
+           "Range: 6-11\n",
+           "6789")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  // Initialised, and the open is asserted, so that a failure stops the test
+  // instead of wrapping and dereferencing an indeterminate pointer.
+  RandomAccessFile* file_ptr = nullptr;
+  TF_ASSERT_OK(
+      fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file_ptr));
+  std::unique_ptr<RandomAccessFile> file(file_ptr);
+
+  char scratch[6];
+  StringPiece result;
+
+  // Read the first chunk.
+  TF_EXPECT_OK(file->Read(0, sizeof(scratch), &result, scratch));
+  EXPECT_EQ("012345", result);
+
+  // Read the second chunk.
+  EXPECT_EQ(
+      errors::Code::OUT_OF_RANGE,
+      file->Read(sizeof(scratch), sizeof(scratch), &result, scratch).code());
+  EXPECT_EQ("6789", result);
+}
+
+// Appended chunks are expected to be uploaded as one POST body when the file
+// is closed.
+TEST(GcsFileSystemTest, NewWritableFile) {
+  std::vector<HttpRequest*> requests({new FakeHttpRequest(
+      "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?"
+      "uploadType=media&name=path/writeable.txt\n"
+      "Auth Token: fake_token\n"
+      "Post body: content1,content2\n",
+      "")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  // Initialised, and the open is asserted, so that a failure stops the test
+  // instead of wrapping and dereferencing an indeterminate pointer.
+  WritableFile* file_ptr = nullptr;
+  TF_ASSERT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file_ptr));
+  std::unique_ptr<WritableFile> file(file_ptr);
+
+  TF_EXPECT_OK(file->Append("content1,"));
+  TF_EXPECT_OK(file->Append("content2"));
+  TF_EXPECT_OK(file->Close());
+}
+
+// Opening for append is expected to download the existing content first; the
+// upload on Close() then carries the old content followed by the new data.
+TEST(GcsFileSystemTest, NewAppendableFile) {
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://bucket.storage.googleapis.com/path/appendable.txt\n"
+           "Auth Token: fake_token\n"
+           "Range: 0-1048575\n",
+           "content1,"),
+       new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?"
+           "uploadType=media&name=path/appendable.txt\n"
+           "Auth Token: fake_token\n"
+           "Post body: content1,content2\n",
+           "")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  // Initialised, and the open is asserted, so that a failure stops the test
+  // instead of wrapping and dereferencing an indeterminate pointer.
+  WritableFile* file_ptr = nullptr;
+  TF_ASSERT_OK(
+      fs.NewAppendableFile("gs://bucket/path/appendable.txt", &file_ptr));
+  std::unique_ptr<WritableFile> file(file_ptr);
+
+  TF_EXPECT_OK(file->Append("content2"));
+  TF_EXPECT_OK(file->Close());
+}
+
+// The region is expected to be built from a size lookup followed by one
+// ranged download that covers the whole object.
+TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
+  const string content = "file content";
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(
+           "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+           "random_access.txt?fields=size\n"
+           "Auth Token: fake_token\n",
+           strings::StrCat("{\"size\": \"", content.size(), "\"}")),
+       new FakeHttpRequest(
+           strings::StrCat(
+               "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+               "Auth Token: fake_token\n"
+               "Range: 0-",
+               content.size() - 1, "\n"),
+           content)});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  // Initialised, and the call is asserted, so that a failure stops the test
+  // instead of wrapping and dereferencing an indeterminate pointer.
+  ReadOnlyMemoryRegion* region_ptr = nullptr;
+  TF_ASSERT_OK(fs.NewReadOnlyMemoryRegionFromFile(
+      "gs://bucket/random_access.txt", &region_ptr));
+  std::unique_ptr<ReadOnlyMemoryRegion> region(region_ptr);
+
+  EXPECT_EQ(content, StringPiece(reinterpret_cast<const char*>(region->data()),
+                                 region->length()));
+}
+
+// FileExists() is true when the size lookup succeeds and false when the
+// request comes back as NotFound.
+TEST(GcsFileSystemTest, FileExists) {
+  const string present_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+      "path/file1.txt?fields=size\n"
+      "Auth Token: fake_token\n";
+  const string absent_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+      "path/file2.txt?fields=size\n"
+      "Auth Token: fake_token\n";
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(present_request, "{\"size\": \"100\"}"),
+       new FakeHttpRequest(absent_request, "", errors::NotFound("404"))});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  EXPECT_TRUE(fs.FileExists("gs://bucket/path/file1.txt"));
+  EXPECT_FALSE(fs.FileExists("gs://bucket/path/file2.txt"));
+}
+
+// Every listed object, including the one in a nested "subpath", is expected
+// to come back as a full gs:// path in listing order.
+TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
+  std::vector<HttpRequest*> requests({new FakeHttpRequest(
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+      "prefix=path/&fields=items\n"
+      "Auth Token: fake_token\n",
+      "{\"items\": [ "
+      " { \"name\": \"path/file1.txt\" },"
+      " { \"name\": \"path/subpath/file2.txt\" },"
+      " { \"name\": \"path/file3.txt\" }]}")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  std::vector<string> children;
+  TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
+
+  // Must be fatal: the element checks below index children[0..2].
+  ASSERT_EQ(3, children.size());
+  EXPECT_EQ("gs://bucket/path/file1.txt", children[0]);
+  EXPECT_EQ("gs://bucket/path/subpath/file2.txt", children[1]);
+  EXPECT_EQ("gs://bucket/path/file3.txt", children[2]);
+}
+
+// A listing response without an "items" field yields no children.
+TEST(GcsFileSystemTest, GetChildren_Empty) {
+  const string list_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+      "prefix=path/&fields=items\n"
+      "Auth Token: fake_token\n";
+  std::vector<HttpRequest*> requests({new FakeHttpRequest(list_request, "{}")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  std::vector<string> listed;
+  TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &listed));
+
+  EXPECT_EQ(0, listed.size());
+}
+
+// Deleting a file issues one authenticated DELETE for the object.
+TEST(GcsFileSystemTest, DeleteFile) {
+  const string delete_request =
+      "Uri: https://www.googleapis.com/storage/v1/b"
+      "/bucket/o/path/file1.txt\n"
+      "Auth Token: fake_token\n"
+      "Delete: yes\n";
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(delete_request, "")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  TF_EXPECT_OK(fs.DeleteFile("gs://bucket/path/file1.txt"));
+}
+
+// Deleting a directory whose listing is empty succeeds after the listing
+// request alone.
+TEST(GcsFileSystemTest, DeleteDir_Empty) {
+  const string list_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+      "prefix=path/&fields=items\n"
+      "Auth Token: fake_token\n";
+  std::vector<HttpRequest*> requests({new FakeHttpRequest(list_request, "{}")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
+}
+
+// Deleting a directory that still contains an object must fail.
+TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
+  const string list_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+      "prefix=path/&fields=items\n"
+      "Auth Token: fake_token\n";
+  const string list_response =
+      "{\"items\": [ "
+      " { \"name\": \"path/file1.txt\" }]}";
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(list_request, list_response)});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  const Status delete_status = fs.DeleteDir("gs://bucket/path/");
+  EXPECT_FALSE(delete_status.ok());
+}
+
+// The size is taken from the "size" field of the metadata response.
+TEST(GcsFileSystemTest, GetFileSize) {
+  std::vector<HttpRequest*> requests({new FakeHttpRequest(
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+      "file.txt?fields=size\n"
+      "Auth Token: fake_token\n",
+      "{\"size\": \"1010\"}")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  // Initialised so that a failing call is compared against a defined value.
+  uint64 size = 0;
+  TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
+  EXPECT_EQ(1010, size);
+}
+
+// A rename is expected to be a rewrite of the source to the target followed
+// by a delete of the source.
+TEST(GcsFileSystemTest, RenameFile) {
+  const string rewrite_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt"
+      "/rewriteTo/b/bucket/o/dst.txt\n"
+      "Auth Token: fake_token\n"
+      "Post: yes\n";
+  const string delete_request =
+      "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/src.txt\n"
+      "Auth Token: fake_token\n"
+      "Delete: yes\n";
+  std::vector<HttpRequest*> requests(
+      {new FakeHttpRequest(rewrite_request, ""),
+       new FakeHttpRequest(delete_request, "")});
+  GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+                   std::unique_ptr<HttpRequest::Factory>(
+                       new FakeHttpRequestFactory(&requests)));
+
+  TF_EXPECT_OK(fs.RenameFile("gs://bucket/src.txt", "gs://bucket/dst.txt"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/http_request.cc b/tensorflow/core/platform/cloud/http_request.cc
new file mode 100644
index 0000000000..38f132e723
--- /dev/null
+++ b/tensorflow/core/platform/cloud/http_request.cc
@@ -0,0 +1,410 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include <dlfcn.h>
+#include <stdio.h>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Library names tried, in this order, when libcurl is not already present in
+// the current binary. Windows is not currently supported.
+constexpr char kCurlLibLinux[] = "libcurl.so.3";
+constexpr char kCurlLibMac[] = "/usr/lib/libcurl.3.dylib";
+
+// Directory passed to libcurl as CURLOPT_CAPATH (CA certificate directory).
+constexpr char kCertsPath[] = "/etc/ssl/certs";
+
+// Value passed to libcurl as CURLOPT_VERBOSE.
+// Set to 1 to enable verbose debug output from curl.
+constexpr uint64 kVerboseOutput = 0;
+
+/// An implementation that dynamically loads libcurl and forwards calls to it.
+///
+/// MaybeLoadDll() must succeed before any forwarding method is called; every
+/// forwarder CHECK-fails otherwise. The library handle is released in the
+/// destructor.
+class LibCurlProxy : public LibCurl {
+ public:
+  ~LibCurlProxy() override {
+    if (dll_handle_) {
+      dlclose(dll_handle_);
+    }
+  }
+
+  /// Loads libcurl on first use and runs its global initialization.
+  ///
+  /// Returns OK if the library is already loaded. Returns FailedPrecondition
+  /// if no usable libcurl can be found or if curl_global_init fails; in both
+  /// cases the proxy stays unloaded, so a later call retries from scratch.
+  Status MaybeLoadDll() override {
+    if (dll_handle_) {
+      return Status::OK();
+    }
+    // This may have been linked statically; if curl_easy_init is in the
+    // current binary, no need to search for a dynamic version.
+    dll_handle_ = load_dll(nullptr);
+    if (!dll_handle_) {
+      dll_handle_ = load_dll(kCurlLibLinux);
+    }
+    if (!dll_handle_) {
+      dll_handle_ = load_dll(kCurlLibMac);
+    }
+    if (!dll_handle_) {
+      return errors::FailedPrecondition(strings::StrCat(
+          "Could not initialize the libcurl library. Please make sure that "
+          "libcurl is installed in the OS or statically linked to the "
+          "TensorFlow binary."));
+    }
+    // No other libcurl function may be used if global initialization fails.
+    if (curl_global_init_(CURL_GLOBAL_ALL) != CURLE_OK) {
+      dlclose(dll_handle_);
+      dll_handle_ = nullptr;
+      clear_bindings();
+      return errors::FailedPrecondition(
+          "The global initialization of the libcurl library failed.");
+    }
+    return Status::OK();
+  }
+
+  CURL* curl_easy_init() override {
+    CHECK(dll_handle_);
+    return curl_easy_init_();
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            uint64 param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            const char* param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            void* param) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(void*, size_t, size_t,
+                                            FILE*)) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(const void*, size_t, size_t,
+                                            void*)) override {
+    CHECK(dll_handle_);
+    return curl_easy_setopt_(curl, option, param);
+  }
+
+  CURLcode curl_easy_perform(CURL* curl) override {
+    CHECK(dll_handle_);
+    return curl_easy_perform_(curl);
+  }
+
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             uint64* value) override {
+    CHECK(dll_handle_);
+    return curl_easy_getinfo_(curl, info, value);
+  }
+
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             double* value) override {
+    CHECK(dll_handle_);
+    return curl_easy_getinfo_(curl, info, value);
+  }
+
+  void curl_easy_cleanup(CURL* curl) override {
+    CHECK(dll_handle_);
+    return curl_easy_cleanup_(curl);
+  }
+
+  curl_slist* curl_slist_append(curl_slist* list, const char* str) override {
+    CHECK(dll_handle_);
+    return curl_slist_append_(list, str);
+  }
+
+  void curl_slist_free_all(curl_slist* list) override {
+    CHECK(dll_handle_);
+    return curl_slist_free_all_(list);
+  }
+
+ private:
+  // Loads the dynamic library and binds the required methods.
+  // Returns the library handle in case of success or nullptr otherwise.
+  // 'name' can be nullptr, which refers to the current binary.
+  //
+  // Success requires every forwarded symbol to resolve. A library that only
+  // exports part of the API is rejected, because a missing binding would be
+  // called through a null function pointer later on.
+  void* load_dll(const char* name) {
+    void* handle = dlopen(name, RTLD_NOW | RTLD_LOCAL | RTLD_NODELETE);
+    if (!handle) {
+      return nullptr;
+    }
+
+#define BIND_CURL_FUNC(function) \
+  *reinterpret_cast<void**>(&(function##_)) = dlsym(handle, #function)
+
+    BIND_CURL_FUNC(curl_global_init);
+    BIND_CURL_FUNC(curl_easy_init);
+    BIND_CURL_FUNC(curl_easy_setopt);
+    BIND_CURL_FUNC(curl_easy_perform);
+    BIND_CURL_FUNC(curl_easy_getinfo);
+    BIND_CURL_FUNC(curl_slist_append);
+    BIND_CURL_FUNC(curl_slist_free_all);
+    BIND_CURL_FUNC(curl_easy_cleanup);
+
+#undef BIND_CURL_FUNC
+
+    const bool all_bound =
+        curl_global_init_ != nullptr && curl_easy_init_ != nullptr &&
+        curl_easy_setopt_ != nullptr && curl_easy_perform_ != nullptr &&
+        curl_easy_getinfo_ != nullptr && curl_slist_append_ != nullptr &&
+        curl_slist_free_all_ != nullptr && curl_easy_cleanup_ != nullptr;
+    if (!all_bound) {
+      // Do not keep pointers into a library that is about to be closed.
+      clear_bindings();
+      dlerror();  // Clear dlerror before attempting to open libraries.
+      dlclose(handle);
+      return nullptr;
+    }
+    return handle;
+  }
+
+  // Resets every bound function pointer to nullptr.
+  void clear_bindings() {
+    curl_global_init_ = nullptr;
+    curl_easy_init_ = nullptr;
+    curl_easy_setopt_ = nullptr;
+    curl_easy_perform_ = nullptr;
+    curl_easy_getinfo_ = nullptr;
+    curl_easy_cleanup_ = nullptr;
+    curl_slist_append_ = nullptr;
+    curl_slist_free_all_ = nullptr;
+  }
+
+  void* dll_handle_ = nullptr;
+  CURLcode (*curl_global_init_)(int64) = nullptr;
+  CURL* (*curl_easy_init_)(void) = nullptr;
+  CURLcode (*curl_easy_setopt_)(CURL*, CURLoption, ...) = nullptr;
+  CURLcode (*curl_easy_perform_)(CURL* curl) = nullptr;
+  CURLcode (*curl_easy_getinfo_)(CURL* curl, CURLINFO info, ...) = nullptr;
+  void (*curl_easy_cleanup_)(CURL* curl) = nullptr;
+  curl_slist* (*curl_slist_append_)(curl_slist* list,
+                                    const char* str) = nullptr;
+  void (*curl_slist_free_all_)(curl_slist* list) = nullptr;
+};
+} // namespace
+
+// Default construction wires the request to the dynamically loaded libcurl
+// through LibCurlProxy.
+HttpRequest::HttpRequest()
+    : HttpRequest(std::unique_ptr<LibCurl>(new LibCurlProxy())) {}
+
+// Takes ownership of `libcurl` and allocates the fallback response buffer
+// that Init() installs as the default result buffer.
+HttpRequest::HttpRequest(std::unique_ptr<LibCurl> libcurl)
+    : libcurl_(std::move(libcurl)) {
+  default_response_buffer_.reset(new char[CURL_MAX_WRITE_SIZE]);
+}
+
+// Releases, in this order, the header list, the POST body file and the curl
+// session; each only if it was actually created.
+HttpRequest::~HttpRequest() {
+  if (curl_headers_ != nullptr) {
+    libcurl_->curl_slist_free_all(curl_headers_);
+  }
+  if (post_body_ != nullptr) {
+    fclose(post_body_);
+  }
+  if (curl_ != nullptr) {
+    libcurl_->curl_easy_cleanup(curl_);
+  }
+}
+
+// Loads libcurl if necessary, opens a curl session and applies the default
+// options. Returns Internal if no libcurl proxy was supplied or the session
+// cannot be created, and propagates any error from loading the library.
+Status HttpRequest::Init() {
+  if (libcurl_ == nullptr) {
+    return errors::Internal("libcurl proxy cannot be nullptr.");
+  }
+  TF_RETURN_IF_ERROR(libcurl_->MaybeLoadDll());
+
+  curl_ = libcurl_->curl_easy_init();
+  if (curl_ == nullptr) {
+    return errors::Internal("Couldn't initialize a curl session.");
+  }
+
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_VERBOSE, kVerboseOutput);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_CAPATH, kCertsPath);
+
+  // If response buffer is not set, libcurl will print results to stdout,
+  // so we always set it. SetResultBuffer() itself requires the initialized
+  // flag, hence the flag is raised first and lowered again on failure.
+  is_initialized_ = true;
+  const Status buffer_status =
+      SetResultBuffer(default_response_buffer_.get(), CURL_MAX_WRITE_SIZE,
+                      &default_response_string_piece_);
+  is_initialized_ = buffer_status.ok();
+  return buffer_status;
+}
+
+// Sets the request URI. Fails with FailedPrecondition if the object is not
+// initialized or the request has already been sent.
+Status HttpRequest::SetUri(const string& uri) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  const char* const raw_uri = uri.c_str();
+  is_uri_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_URL, raw_uri);
+  return Status::OK();
+}
+
+// Requests the inclusive byte range [start, end], passed to libcurl in the
+// form "<start>-<end>".
+Status HttpRequest::SetRange(uint64 start, uint64 end) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  const string byte_range = strings::StrCat(start, "-", end);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_RANGE, byte_range.c_str());
+  return Status::OK();
+}
+
+// Adds an "Authorization: Bearer <token>" header. An empty token is accepted
+// and adds no header.
+Status HttpRequest::AddAuthBearerHeader(const string& auth_token) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  if (auth_token.empty()) {
+    return Status::OK();
+  }
+  const string header = strings::StrCat("Authorization: Bearer ", auth_token);
+  curl_headers_ = libcurl_->curl_slist_append(curl_headers_, header.c_str());
+  return Status::OK();
+}
+
+// Turns the request into a DELETE. The HTTP method can be chosen only once.
+Status HttpRequest::SetDeleteRequest() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  static const char kDeleteMethod[] = "DELETE";
+  is_method_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_CUSTOMREQUEST, kDeleteMethod);
+  return Status::OK();
+}
+
+// Turns the request into a POST whose body is streamed from `body_filepath`.
+//
+// The file is opened here and stays open until the request is destroyed. Its
+// size is sent as the Content-Length header. Returns InvalidArgument if the
+// file cannot be opened or its size cannot be determined. In that case the
+// HTTP method is left unset, so the caller may choose a method again.
+Status HttpRequest::SetPostRequest(const string& body_filepath) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  if (post_body_) {
+    fclose(post_body_);
+    post_body_ = nullptr;
+  }
+  // Binary mode: the body must be sent byte for byte.
+  post_body_ = fopen(body_filepath.c_str(), "rb");
+  if (!post_body_) {
+    return errors::InvalidArgument("Couldn't open the specified file: " +
+                                   body_filepath);
+  }
+
+  // Measure the file by seeking to its end, then rewind for the upload.
+  long size = -1;
+  if (fseek(post_body_, 0, SEEK_END) == 0) {
+    size = ftell(post_body_);
+  }
+  if (size < 0 || fseek(post_body_, 0, SEEK_SET) != 0) {
+    fclose(post_body_);
+    post_body_ = nullptr;
+    return errors::InvalidArgument(
+        "Couldn't determine the size of the specified file: " + body_filepath);
+  }
+
+  // Mark the method as chosen only once nothing can fail any more.
+  is_method_set_ = true;
+  curl_headers_ = libcurl_->curl_slist_append(
+      curl_headers_, strings::StrCat("Content-Length: ", size).c_str());
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_READDATA,
+                             reinterpret_cast<void*>(post_body_));
+  return Status::OK();
+}
+
+// Turns the request into a POST without a body; an explicit
+// "Content-Length: 0" header is added.
+Status HttpRequest::SetPostRequest() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  TF_RETURN_IF_ERROR(CheckMethodNotSet());
+  static const char kEmptyBodyHeader[] = "Content-Length: 0";
+  is_method_set_ = true;
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_POST, 1);
+  curl_headers_ = libcurl_->curl_slist_append(curl_headers_, kEmptyBodyHeader);
+  return Status::OK();
+}
+
+// Registers the caller's buffer for the response body.
+//
+// `scratch` receives up to `size` bytes and `*result` is pointed at the
+// received data by Send(). Returns InvalidArgument for a null `scratch`, a
+// null `result` or a zero `size`.
+Status HttpRequest::SetResultBuffer(char* scratch, size_t size,
+                                    StringPiece* result) {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  if (scratch == nullptr) {
+    return errors::InvalidArgument("scratch cannot be null");
+  }
+  if (result == nullptr) {
+    return errors::InvalidArgument("result cannot be null");
+  }
+  if (size == 0) {
+    return errors::InvalidArgument("buffer size should be positive");
+  }
+
+  response_buffer_written_ = 0;
+  response_string_piece_ = result;
+  response_buffer_size_ = size;
+  response_buffer_ = scratch;
+
+  // libcurl hands `this` back to WriteCallback as its last argument.
+  void* const callback_context = reinterpret_cast<void*>(this);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA, callback_context);
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
+                             &HttpRequest::WriteCallback);
+  return Status::OK();
+}
+
+// Copies a chunk of the response body into the registered result buffer.
+//
+// `this_object` is the HttpRequest registered as CURLOPT_WRITEDATA. At most
+// the remaining buffer space is copied, and the number of bytes actually
+// stored is returned.
+size_t HttpRequest::WriteCallback(const void* ptr, size_t size, size_t nmemb,
+                                  void* this_object) {
+  CHECK(ptr);
+  HttpRequest* const request = reinterpret_cast<HttpRequest*>(this_object);
+  CHECK(request->response_buffer_);
+  CHECK(request->response_buffer_size_ >= request->response_buffer_written_);
+
+  const size_t incoming = size * nmemb;
+  const size_t remaining =
+      request->response_buffer_size_ - request->response_buffer_written_;
+  const size_t accepted = std::min(incoming, remaining);
+
+  char* const destination =
+      request->response_buffer_ + request->response_buffer_written_;
+  memcpy(destination, ptr, accepted);
+  request->response_buffer_written_ += accepted;
+  return accepted;
+}
+
+// Performs the request and translates the outcome into a Status.
+//
+// Returns FailedPrecondition if the object is not initialized, was already
+// sent, or has no URI. Returns Internal for a libcurl transport error or an
+// unexpected HTTP code, PermissionDenied for 401 and NotFound for 404. For
+// 200/204/206 the result StringPiece is pointed at the received bytes; for
+// 416 it is set to empty and OK is returned. The object cannot be re-used
+// afterwards.
+Status HttpRequest::Send() {
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+  is_sent_ = true;
+  if (!is_uri_set_) {
+    return errors::FailedPrecondition("URI has not been set.");
+  }
+  if (curl_headers_) {
+    libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTPHEADER, curl_headers_);
+  }
+
+  // Zero-filled so that the message is a valid (possibly empty) string even
+  // if libcurl fails without writing anything into the buffer.
+  char error_buffer[CURL_ERROR_SIZE] = {0};
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER, error_buffer);
+
+  const auto curl_result = libcurl_->curl_easy_perform(curl_);
+
+  // The buffer lives on this stack frame; unregister it so the curl handle
+  // does not keep a dangling pointer until curl_easy_cleanup().
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_ERRORBUFFER,
+                             static_cast<const char*>(nullptr));
+
+  double written_size = 0;
+  libcurl_->curl_easy_getinfo(curl_, CURLINFO_SIZE_DOWNLOAD, &written_size);
+
+  uint64 response_code = 0;
+  libcurl_->curl_easy_getinfo(curl_, CURLINFO_RESPONSE_CODE, &response_code);
+
+  if (curl_result != CURLE_OK) {
+    return errors::Internal(string("curl error: ") + error_buffer);
+  }
+  switch (response_code) {
+    case 200:  // OK
+    case 204:  // No Content
+    case 206:  // Partial Content
+      if (response_buffer_ && response_string_piece_) {
+        // Never expose more bytes than the result buffer can hold.
+        const size_t result_size = std::min(static_cast<size_t>(written_size),
+                                            response_buffer_size_);
+        *response_string_piece_ = StringPiece(response_buffer_, result_size);
+      }
+      return Status::OK();
+    case 401:
+      return errors::PermissionDenied(
+          "Not authorized to access the given HTTP resource.");
+    case 404:
+      return errors::NotFound("The requested URL was not found.");
+    case 416:  // Requested Range Not Satisfiable
+      if (response_string_piece_) {
+        *response_string_piece_ = StringPiece();
+      }
+      return Status::OK();
+    default:
+      return errors::Internal(
+          strings::StrCat("Unexpected HTTP response code ", response_code));
+  }
+}
+
+// OK once Init() has succeeded; FailedPrecondition before that.
+Status HttpRequest::CheckInitialized() const {
+  return is_initialized_
+             ? Status::OK()
+             : errors::FailedPrecondition(
+                   "The object has not been initialized.");
+}
+
+// OK while no HTTP method has been chosen; FailedPrecondition afterwards.
+Status HttpRequest::CheckMethodNotSet() const {
+  return is_method_set_
+             ? errors::FailedPrecondition("HTTP method has been already set.")
+             : Status::OK();
+}
+
+// OK until Send() has been called; FailedPrecondition afterwards.
+Status HttpRequest::CheckNotSent() const {
+  return is_sent_
+             ? errors::FailedPrecondition("The request has already been sent.")
+             : Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
new file mode 100644
index 0000000000..19aed67e6a
--- /dev/null
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -0,0 +1,156 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+#include <curl/curl.h>
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+class LibCurl; // libcurl interface as a class, for dependency injection.
+
+/// \brief A basic HTTP client based on the libcurl library.
+///
+/// The usage pattern for the class reflects the one of the libcurl library:
+/// create a request object, initialize it, set request parameters and call
+/// Send(). Every setter returns an error until Init() has succeeded.
+///
+/// For example:
+/// HttpRequest request;
+/// request.Init();
+/// request.SetUri("http://www.google.com");
+/// request.SetResultBuffer(scratch, 1000, &result);
+/// request.Send();
+class HttpRequest {
+ public:
+ /// Creates HttpRequest objects; virtual so that tests can substitute fakes.
+ class Factory {
+ public:
+ virtual ~Factory() {}
+ /// Returns a newly allocated request.
+ virtual HttpRequest* Create() { return new HttpRequest(); }
+ };
+
+ HttpRequest();
+ /// Takes ownership of the given libcurl wrapper.
+ explicit HttpRequest(std::unique_ptr<LibCurl> libcurl);
+ virtual ~HttpRequest();
+
+ /// Must be called, and must succeed, before any other method is used.
+ virtual Status Init();
+
+ /// Sets the request URI.
+ virtual Status SetUri(const string& uri);
+
+ /// \brief Sets the Range header.
+ ///
+ /// Used for random seeks, for example "0-999" returns the first 1000 bytes
+ /// (note that the right border is included).
+ virtual Status SetRange(uint64 start, uint64 end);
+
+ /// Sets the 'Authorization' header to the value of 'Bearer ' + auth_token.
+ virtual Status AddAuthBearerHeader(const string& auth_token);
+
+ /// Makes the request a DELETE request.
+ virtual Status SetDeleteRequest();
+
+ /// \brief Makes the request a POST request.
+ ///
+ /// The request body will be taken from the specified file.
+ virtual Status SetPostRequest(const string& body_filepath);
+
+ /// Makes the request a POST request.
+ virtual Status SetPostRequest();
+
+ /// \brief Specifies the buffer for receiving the response body.
+ ///
+ /// The interface is made similar to RandomAccessFile::Read.
+ virtual Status SetResultBuffer(char* scratch, size_t size,
+ StringPiece* result);
+
+ /// \brief Sends the formed request.
+ ///
+ /// If the result buffer was defined, the response will be written there.
+ /// The object is not designed to be re-used after Send() is executed.
+ virtual Status Send();
+
+ private:
+ /// A callback in the form which can be accepted by libcurl.
+ static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
+ void* userdata);
+ // Precondition checks shared by the public methods; each returns
+ // FailedPrecondition when its condition is violated.
+ Status CheckInitialized() const;
+ Status CheckMethodNotSet() const;
+ Status CheckNotSent() const;
+
+ std::unique_ptr<LibCurl> libcurl_;
+ // File holding the POST body, if any; closed in the destructor.
+ FILE* post_body_ = nullptr;
+ // Destination for the response body, its capacity, and the bytes stored.
+ char* response_buffer_ = nullptr;
+ size_t response_buffer_size_ = 0;
+ size_t response_buffer_written_ = 0;
+ StringPiece* response_string_piece_ = nullptr;
+ CURL* curl_ = nullptr;
+ curl_slist* curl_headers_ = nullptr;
+
+ // Fallback result buffer installed by Init().
+ std::unique_ptr<char[]> default_response_buffer_;
+ StringPiece default_response_string_piece_;
+
+ // Members to enforce the usage flow.
+ bool is_initialized_ = false;
+ bool is_uri_set_ = false;
+ bool is_method_set_ = false;
+ bool is_sent_ = false;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(HttpRequest);
+};
+
+/// \brief A proxy to the libcurl C interface as a dependency injection measure.
+///
+/// This class is meant as a very thin wrapper for the libcurl C library.
+/// Each method carries the name of the libcurl function it stands in for.
+/// curl_easy_setopt and curl_easy_getinfo are variadic in the C API, so they
+/// appear here as one overload per argument type in use.
+class LibCurl {
+ public:
+ virtual ~LibCurl() {}
+ /// Lazy initialization of the dynamic libcurl library.
+ virtual Status MaybeLoadDll() = 0;
+
+ virtual CURL* curl_easy_init() = 0;
+ // Integer-valued options.
+ virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ uint64 param) = 0;
+ // String-valued options.
+ virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ const char* param) = 0;
+ // Pointer-valued options.
+ virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ void* param) = 0;
+ // Callback taking a FILE* as its last argument.
+ virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ size_t (*param)(void*, size_t, size_t,
+ FILE*)) = 0;
+ // Callback taking a void* as its last argument.
+ virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ size_t (*param)(const void*, size_t, size_t,
+ void*)) = 0;
+ virtual CURLcode curl_easy_perform(CURL* curl) = 0;
+ virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+ uint64* value) = 0;
+ virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+ double* value) = 0;
+ virtual void curl_easy_cleanup(CURL* curl) = 0;
+ virtual curl_slist* curl_slist_append(curl_slist* list, const char* str) = 0;
+ virtual void curl_slist_free_all(curl_slist* list) = 0;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
diff --git a/tensorflow/core/platform/cloud/http_request_test.cc b/tensorflow/core/platform/cloud/http_request_test.cc
new file mode 100644
index 0000000000..247514c9da
--- /dev/null
+++ b/tensorflow/core/platform/cloud/http_request_test.cc
@@ -0,0 +1,323 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/http_request.h"
+#include <fstream>
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+// A fake proxy that pretends to be libcurl.
+//
+// Records the options a request sets on it and, on curl_easy_perform(),
+// replays the canned `response_content` / `response_code` given at
+// construction, so tests can inspect every interaction without a network.
+class FakeLibCurl : public LibCurl {
+ public:
+  // `response_content` is the body handed to the write callback and
+  // `response_code` the value reported for CURLINFO_RESPONSE_CODE.
+  FakeLibCurl(const string& response_content, uint64 response_code)
+      : response_content(response_content), response_code(response_code) {}
+  // Nothing to load for the fake.
+  Status MaybeLoadDll() override { return Status::OK(); }
+  // Marks the fake as initialized and returns a dummy handle.
+  CURL* curl_easy_init() override {
+    is_initialized = true;
+    // The result just needs to be non-null.
+    return reinterpret_cast<CURL*>(this);
+  }
+  // Integer options: only CURLOPT_POST is recorded.
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            uint64 param) override {
+    if (option == CURLOPT_POST) is_post = (param != 0);
+    return CURLE_OK;
+  }
+  // String options are routed through the pointer overload below.
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            const char* param) override {
+    return curl_easy_setopt(curl, option,
+                            reinterpret_cast<void*>(const_cast<char*>(param)));
+  }
+  // Pointer options: each recognized option is stored in its matching field.
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            void* param) override {
+    switch (option) {
+      case CURLOPT_URL:
+        url = reinterpret_cast<char*>(param);
+        break;
+      case CURLOPT_RANGE:
+        range = reinterpret_cast<char*>(param);
+        break;
+      case CURLOPT_CUSTOMREQUEST:
+        custom_request = reinterpret_cast<char*>(param);
+        break;
+      case CURLOPT_HTTPHEADER:
+        // Lists built by curl_slist_append() below are really string vectors.
+        headers = reinterpret_cast<std::vector<string>*>(param);
+        break;
+      case CURLOPT_ERRORBUFFER:
+        error_buffer = reinterpret_cast<char*>(param);
+        break;
+      case CURLOPT_WRITEDATA:
+        // Opaque cookie for the write callback; stored without a cast.
+        write_data = param;
+        break;
+      case CURLOPT_READDATA:
+        read_data = reinterpret_cast<FILE*>(param);
+        break;
+      default:
+        break;
+    }
+    return CURLE_OK;
+  }
+  // Read callback: the request is expected to hand over plain fread().
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(void*, size_t, size_t,
+                                            FILE*)) override {
+    EXPECT_EQ(param, &fread) << "Expected the standard fread() function.";
+    return CURLE_OK;
+  }
+  // Write callback: only CURLOPT_WRITEFUNCTION is recorded.
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            size_t (*param)(const void*, size_t, size_t,
+                                            void*)) override {
+    if (option == CURLOPT_WRITEFUNCTION) write_callback = param;
+    return CURLE_OK;
+  }
+  // Simulates the transfer: captures the upload body (if any), then delivers
+  // the canned response to the registered write callback (if any).
+  CURLcode curl_easy_perform(CURL* curl) override {
+    if (read_data) {
+      // Drain the upload stream in fixed-size chunks.
+      char buffer[100];
+      size_t bytes_read;
+      posted_content.clear();
+      while ((bytes_read = fread(buffer, 1, sizeof(buffer), read_data)) > 0) {
+        posted_content.append(buffer, bytes_read);
+      }
+    }
+    if (write_data) {
+      // A destination without a callback is a setup bug; fail, don't crash.
+      if (!write_callback) return CURLE_WRITE_ERROR;
+      write_callback(response_content.c_str(), 1, response_content.size(),
+                     write_data);
+    }
+    return CURLE_OK;
+  }
+  // Reports the canned HTTP status for CURLINFO_RESPONSE_CODE.
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             uint64* value) override {
+    if (info == CURLINFO_RESPONSE_CODE) *value = response_code;
+    return CURLE_OK;
+  }
+  // Reports the canned body length for CURLINFO_SIZE_DOWNLOAD.
+  CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
+                             double* value) override {
+    if (info == CURLINFO_SIZE_DOWNLOAD) *value = response_content.size();
+    return CURLE_OK;
+  }
+  // Only records that cleanup was requested.
+  void curl_easy_cleanup(CURL* curl) override { is_cleaned_up = true; }
+  // Header lists are std::vector<string> objects disguised as curl_slist.
+  curl_slist* curl_slist_append(curl_slist* list, const char* str) override {
+    std::vector<string>* v = list ? reinterpret_cast<std::vector<string>*>(list)
+                                  : new std::vector<string>();
+    v->push_back(str);
+    return reinterpret_cast<curl_slist*>(v);
+  }
+  // Deletes the vector allocated by curl_slist_append().
+  void curl_slist_free_all(curl_slist* list) override {
+    delete reinterpret_cast<std::vector<string>*>(list);
+  }
+
+  // Variables defining the behavior of this fake.
+  string response_content;
+  uint64 response_code;
+
+  // Internal variables to store the libcurl state.
+  string url;
+  string range;
+  string custom_request;
+  char* error_buffer = nullptr;
+  bool is_initialized = false;
+  bool is_cleaned_up = false;
+  // Aliases the list passed with CURLOPT_HTTPHEADER; not owned here.
+  std::vector<string>* headers = nullptr;
+  FILE* read_data = nullptr;
+  bool is_post = false;
+  void* write_data = nullptr;
+  size_t (*write_callback)(const void* ptr, size_t size, size_t nmemb,
+                           void* userdata) = nullptr;
+  // Outcome of performing the request.
+  string posted_content;
+};
+
+TEST(HttpRequestTest, GetRequest) {
+  // The request owns the fake; the raw pointer is kept only for inspection.
+  FakeLibCurl* fake_curl = new FakeLibCurl("get response", 200);
+  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
+  TF_EXPECT_OK(request.Init());
+
+  char scratch[100] = "random original scratch content";
+  StringPiece result = "random original string piece";
+  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
+  TF_EXPECT_OK(request.SetRange(100, 199));
+  TF_EXPECT_OK(request.SetResultBuffer(scratch, 100, &result));
+  TF_EXPECT_OK(request.Send());
+
+  EXPECT_EQ("get response", result);
+
+  // Verify what the request told libcurl.
+  EXPECT_TRUE(fake_curl->is_initialized);
+  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
+  EXPECT_EQ("100-199", fake_curl->range);
+  EXPECT_EQ("", fake_curl->custom_request);
+  EXPECT_EQ(1, fake_curl->headers->size());
+  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
+  EXPECT_FALSE(fake_curl->is_post);
+}
+
+TEST(HttpRequestTest, PostRequest_WithBody) {
+  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
+  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
+  TF_EXPECT_OK(request.Init());
+
+  // Stage the 17-byte request body in a temporary file.
+  const string body_path = io::JoinPath(testing::TmpDir(), "content");
+  std::ofstream body_file(body_path, std::ofstream::binary);
+  body_file << "post body content";
+  body_file.close();
+
+  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
+  TF_EXPECT_OK(request.SetPostRequest(body_path));
+  TF_EXPECT_OK(request.Send());
+
+  // Verify what the request told libcurl.
+  EXPECT_TRUE(fake_curl->is_initialized);
+  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
+  EXPECT_EQ("", fake_curl->custom_request);
+  EXPECT_EQ(2, fake_curl->headers->size());
+  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
+  EXPECT_EQ("Content-Length: 17", (*fake_curl->headers)[1]);
+  EXPECT_TRUE(fake_curl->is_post);
+  EXPECT_EQ("post body content", fake_curl->posted_content);
+  std::remove(body_path.c_str());  // Clean up the temporary body file.
+}
+
+TEST(HttpRequestTest, PostRequest_WithoutBody) {
+  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
+  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
+  TF_EXPECT_OK(request.Init());
+
+  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
+  TF_EXPECT_OK(request.SetPostRequest());  // POST with no body file.
+  TF_EXPECT_OK(request.Send());
+
+  // Verify what the request told libcurl.
+  EXPECT_TRUE(fake_curl->is_initialized);
+  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
+  EXPECT_EQ("", fake_curl->custom_request);
+  EXPECT_EQ(2, fake_curl->headers->size());
+  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
+  EXPECT_EQ("Content-Length: 0", (*fake_curl->headers)[1]);
+  EXPECT_TRUE(fake_curl->is_post);
+  EXPECT_EQ("", fake_curl->posted_content);
+}
+
+TEST(HttpRequestTest, DeleteRequest) {
+  FakeLibCurl* fake_curl = new FakeLibCurl("", 200);
+  HttpRequest request((std::unique_ptr<LibCurl>(fake_curl)));
+  TF_EXPECT_OK(request.Init());
+
+  TF_EXPECT_OK(request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(request.AddAuthBearerHeader("fake-bearer"));
+  TF_EXPECT_OK(request.SetDeleteRequest());
+  TF_EXPECT_OK(request.Send());
+
+  // Verify what the request told libcurl.
+  EXPECT_TRUE(fake_curl->is_initialized);
+  EXPECT_EQ("http://www.testuri.com", fake_curl->url);
+  EXPECT_EQ("DELETE", fake_curl->custom_request);
+  EXPECT_EQ(1, fake_curl->headers->size());
+  EXPECT_EQ("Authorization: Bearer fake-bearer", (*fake_curl->headers)[0]);
+  EXPECT_FALSE(fake_curl->is_post);
+}
+
+TEST(HttpRequestTest, WrongSequenceOfCalls_NoUri) {
+  HttpRequest request(std::unique_ptr<LibCurl>(new FakeLibCurl("", 200)));
+  TF_EXPECT_OK(request.Init());
+  // Send() is called without a preceding SetUri().
+  const Status status = request.Send();
+  ASSERT_TRUE(errors::IsFailedPrecondition(status));
+  const StringPiece message(status.error_message());
+  EXPECT_TRUE(message.contains("URI has not been set"));
+}
+
+TEST(HttpRequestTest, WrongSequenceOfCalls_TwoSends) {
+  HttpRequest request(std::unique_ptr<LibCurl>(new FakeLibCurl("", 200)));
+  TF_EXPECT_OK(request.Init());
+
+  // Only the status of the second Send() is examined.
+  request.SetUri("http://www.google.com");
+  request.Send();
+  const Status status = request.Send();
+  ASSERT_TRUE(errors::IsFailedPrecondition(status));
+  const StringPiece message(status.error_message());
+  EXPECT_TRUE(message.contains("The request has already been sent"));
+}
+
+TEST(HttpRequestTest, WrongSequenceOfCalls_ReusingAfterSend) {
+  HttpRequest request(std::unique_ptr<LibCurl>(new FakeLibCurl("", 200)));
+  TF_EXPECT_OK(request.Init());
+
+  // Reconfiguring the request after Send() is the call under test.
+  request.SetUri("http://www.google.com");
+  request.Send();
+  const Status status = request.SetUri("http://mail.google.com");
+  ASSERT_TRUE(errors::IsFailedPrecondition(status));
+  const StringPiece message(status.error_message());
+  EXPECT_TRUE(message.contains("The request has already been sent"));
+}
+
+TEST(HttpRequestTest, WrongSequenceOfCalls_SettingMethodTwice) {
+  HttpRequest request(std::unique_ptr<LibCurl>(new FakeLibCurl("", 200)));
+  TF_EXPECT_OK(request.Init());
+
+  // A second method selection on the same request is the call under test.
+  request.SetDeleteRequest();
+  const Status status = request.SetPostRequest();
+  ASSERT_TRUE(errors::IsFailedPrecondition(status));
+  const StringPiece message(status.error_message());
+  EXPECT_TRUE(message.contains("HTTP method has been already set"));
+}
+
+TEST(HttpRequestTest, WrongSequenceOfCalls_NotInitialized) {
+  HttpRequest request(std::unique_ptr<LibCurl>(new FakeLibCurl("", 200)));
+
+  // Init() is deliberately skipped before the first call.
+  const Status status = request.SetPostRequest();
+  ASSERT_TRUE(errors::IsFailedPrecondition(status));
+  const StringPiece message(status.error_message());
+  EXPECT_TRUE(message.contains("The object has not been initialized"));
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 6b3d85ded4..c20db73079 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -3,6 +3,9 @@
load("//google/protobuf:protobuf.bzl", "cc_proto_library")
load("//google/protobuf:protobuf.bzl", "py_proto_library")
+# configure may change the following line to True
+WITH_GCP_SUPPORT = False
+
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
tf_deps = []
@@ -91,3 +94,7 @@ def tf_additional_test_srcs():
def tf_kernel_tests_linkstatic():
return 0
+
+def tf_additional_lib_deps():
+  gcs_deps = ["//tensorflow/core/platform/cloud:gcs_file_system"]
+  return gcs_deps if WITH_GCP_SUPPORT else []  # GCS target only with GCP on
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 596f5b86e3..e3a841468b 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -32,6 +32,7 @@ apt-get install -y \
gfortran \
libatlas-base-dev \
libblas-dev \
+ libcurl4-openssl-dev \
liblapack-dev \
libtool \
openjdk-8-jdk \
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 81672eca6e..8e89b217f7 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -96,3 +96,15 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
name = "grpc_lib",
actual = "@grpc//:grpc++_unsecure",
)
+
+ native.new_git_repository(
+ name = "jsoncpp_git",
+ remote = "https://github.com/open-source-parsers/jsoncpp.git",
+ commit = "11086dd6a7eba04289944367ca82cea71299ed70",
+ build_file = path_prefix + "jsoncpp.BUILD",
+ )
+
+ native.bind(
+ name = "jsoncpp",
+ actual = "@jsoncpp_git//:jsoncpp",
+ )