aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexey Surkov <surkov@google.com>2016-06-23 09:29:53 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-23 10:33:05 -0700
commit8bf25a491b60d223bba11233de9e62f4b0db17e8 (patch)
tree6504e4deb37836abd3d4b22b28ee277768be4032
parent7683b6820df9179ef987c062f646f267ff5f523c (diff)
Add a read-ahead cache to the GCS implementation of RandomAccessFile.
In some cases TensorFlow reads the data via RandomAccessFile in really small chunks, which doesn't work very efficiently with HTTP requests. Adding a read-ahead cache significantly boosts the performance. Change: 125691397
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc72
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h8
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc100
3 files changed, 150 insertions, 30 deletions
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 6025f0cdfb..b9c9a04397 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/platform/cloud/gcs_file_system.h"
#include <stdio.h>
#include <unistd.h>
+#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
@@ -80,19 +81,58 @@ Status ParseGcsPath(const string& fname, string* bucket, string* object) {
return Status::OK();
}
-/// GCS-based implementation of a random access file.
+/// A GCS-based implementation of a random access file with a read-ahead buffer.
class GcsRandomAccessFile : public RandomAccessFile {
public:
GcsRandomAccessFile(const string& bucket, const string& object,
AuthProvider* auth_provider,
- HttpRequest::Factory* http_request_factory)
+ HttpRequest::Factory* http_request_factory,
+ size_t read_ahead_bytes)
: bucket_(bucket),
object_(object),
auth_provider_(auth_provider),
- http_request_factory_(std::move(http_request_factory)) {}
+ http_request_factory_(std::move(http_request_factory)),
+ read_ahead_bytes_(read_ahead_bytes) {}
+ /// The implementation of reads with a read-ahead buffer.
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
+ if (offset >= buffer_start_offset_ &&
+ offset + n <= buffer_start_offset_ + buffer_content_size_) {
+ // If the requested range is fully in the buffer, just return it.
+ std::memcpy(scratch, buffer_.get() + offset - buffer_start_offset_, n);
+ *result = StringPiece(scratch, n);
+ return Status::OK();
+ }
+
+ // Update the buffer content based on the new requested range.
+ auto buffer_size = n + read_ahead_bytes_;
+ buffer_.reset(new char[buffer_size]);
+ buffer_start_offset_ = offset;
+ buffer_content_size_ = 0;
+ StringPiece buffer_content;
+ TF_RETURN_IF_ERROR(
+ ReadFromGCS(offset, buffer_size, &buffer_content, buffer_.get()));
+ buffer_content_size_ = buffer_content.size();
+
+ // Set the results.
+ *result = StringPiece(scratch, std::min(buffer_content_size_, n));
+ std::memcpy(scratch, buffer_.get(), result->size());
+
+ if (result->size() < n) {
+ // This is not an error per se. The RandomAccessFile interface expects
+ // that Read returns OutOfRange if fewer bytes were read than requested.
+ return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(),
+ " bytes were read out of ", n,
+ " bytes requested."));
+ }
+ return Status::OK();
+ }
+
+ private:
+ /// A helper function to actually read the data from GCS.
+ Status ReadFromGCS(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const {
string auth_token;
TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_, &auth_token));
@@ -105,22 +145,21 @@ class GcsRandomAccessFile : public RandomAccessFile {
TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
TF_RETURN_IF_ERROR(request->SetResultBuffer(scratch, n, result));
TF_RETURN_IF_ERROR(request->Send());
-
- if (result->size() < n) {
- // This is not an error per se. The RandomAccessFile interface expects
- // that Read returns OutOfRange if fewer bytes were read than requested.
- return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(),
- " bytes were read out of ", n,
- " bytes requested."));
- }
return Status::OK();
}
- private:
string bucket_;
string object_;
AuthProvider* auth_provider_;
HttpRequest::Factory* http_request_factory_;
+ const size_t read_ahead_bytes_;
+
+ // The buffer-related members need to be mutable, because they are modified
+ // by the const Read() method.
+ mutable std::unique_ptr<char[]> buffer_;
+ // The original file offset of the first byte in the buffer.
+ mutable size_t buffer_start_offset_ = 0;
+ mutable size_t buffer_content_size_ = 0;
};
/// \brief GCS-based implementation of a writeable file.
@@ -233,16 +272,19 @@ GcsFileSystem::GcsFileSystem()
GcsFileSystem::GcsFileSystem(
std::unique_ptr<AuthProvider> auth_provider,
- std::unique_ptr<HttpRequest::Factory> http_request_factory)
+ std::unique_ptr<HttpRequest::Factory> http_request_factory,
+ size_t read_ahead_bytes)
: auth_provider_(std::move(auth_provider)),
- http_request_factory_(std::move(http_request_factory)) {}
+ http_request_factory_(std::move(http_request_factory)),
+ read_ahead_bytes_(read_ahead_bytes) {}
Status GcsFileSystem::NewRandomAccessFile(
const string& fname, std::unique_ptr<RandomAccessFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object));
result->reset(new GcsRandomAccessFile(bucket, object, auth_provider_.get(),
- http_request_factory_.get()));
+ http_request_factory_.get(),
+ read_ahead_bytes_));
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 572b1253f8..13785f2db8 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -30,7 +30,8 @@ class GcsFileSystem : public FileSystem {
public:
GcsFileSystem();
GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider,
- std::unique_ptr<HttpRequest::Factory> http_request_factory);
+ std::unique_ptr<HttpRequest::Factory> http_request_factory,
+ size_t read_ahead_bytes);
Status NewRandomAccessFile(
const string& filename,
@@ -63,6 +64,11 @@ class GcsFileSystem : public FileSystem {
private:
std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
+
+ // The number of bytes to read ahead for buffering purposes in the
+ // RandomAccessFile implementation. Defaults to 256 MiB.
+ const size_t read_ahead_bytes_ = 256 * 1024 * 1024;
+
TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem);
};
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 286f157528..9c195ba144 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -49,7 +49,7 @@ class FakeAuthProvider : public AuthProvider {
}
};
-TEST(GcsFileSystemTest, NewRandomAccessFile) {
+TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead) {
std::vector<HttpRequest*> requests(
{new FakeHttpRequest(
"Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
@@ -63,7 +63,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile) {
"6789")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -82,6 +83,65 @@ TEST(GcsFileSystemTest, NewRandomAccessFile) {
EXPECT_EQ("6789", result);
}
+TEST(GcsFileSystemTest, NewRandomAccessFile_WithReadAhead) {
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 0-8\n",
+ "01234567"),
+ new FakeHttpRequest(
+ "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 6-15\n",
+ "6789abcd"),
+ new FakeHttpRequest(
+ "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 6-20\n",
+ "6789abcd"),
+ new FakeHttpRequest(
+ "Uri: https://bucket.storage.googleapis.com/random_access.txt\n"
+ "Auth Token: fake_token\n"
+ "Range: 15-29\n",
+ "")});
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 5 /* read ahead bytes */);
+
+ std::unique_ptr<RandomAccessFile> file;
+ TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
+
+ char scratch[100];
+ StringPiece result;
+
+ // Read the first chunk. 4 + 5 = 9 bytes will be requested for the cache.
+ TF_EXPECT_OK(file->Read(0, 4, &result, scratch));
+ EXPECT_EQ("0123", result);
+
+ // The second chunk will be fully loaded from the cache, no requests are made.
+ TF_EXPECT_OK(file->Read(4, 4, &result, scratch));
+ EXPECT_EQ("4567", result);
+
+ // The chunk is only partially cached -- the request will be made to
+ // reload the cache. 5 + 5 = 10 bytes will be requested.
+ TF_EXPECT_OK(file->Read(6, 5, &result, scratch));
+ EXPECT_EQ("6789a", result);
+
+ // The range can only be partially satisfied. An attempt to fill the cache
+ // with 10 + 5 = 15 bytes will be made.
+ EXPECT_EQ(errors::Code::OUT_OF_RANGE,
+ file->Read(6, 10, &result, scratch).code());
+ EXPECT_EQ("6789abcd", result);
+
+ // The range cannot be satisfied. An attempt to fill the cache
+ // with 10 + 5 = 15 bytes will be made.
+ EXPECT_EQ(errors::Code::OUT_OF_RANGE,
+ file->Read(15, 10, &result, scratch).code());
+ EXPECT_TRUE(result.empty());
+}
+
TEST(GcsFileSystemTest, NewWritableFile) {
std::vector<HttpRequest*> requests({new FakeHttpRequest(
"Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?"
@@ -91,7 +151,8 @@ TEST(GcsFileSystemTest, NewWritableFile) {
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -116,7 +177,8 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewAppendableFile("gs://bucket/path/appendable.txt", &file));
@@ -142,7 +204,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
content)});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -166,7 +229,8 @@ TEST(GcsFileSystemTest, FileExists) {
"", errors::NotFound("404"))});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
EXPECT_TRUE(fs.FileExists("gs://bucket/path/file1.txt"));
EXPECT_FALSE(fs.FileExists("gs://bucket/path/file2.txt"));
@@ -176,7 +240,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
auto requests = CreateGetThreeChildrenRequest();
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -188,7 +253,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
auto requests = CreateGetThreeChildrenRequest();
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -204,7 +270,8 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
"{}")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -221,7 +288,8 @@ TEST(GcsFileSystemTest, DeleteFile) {
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
TF_EXPECT_OK(fs.DeleteFile("gs://bucket/path/file1.txt"));
}
@@ -234,7 +302,8 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
"{}")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -248,7 +317,8 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
" { \"name\": \"path/file1.txt\" }]}")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
EXPECT_FALSE(fs.DeleteDir("gs://bucket/path/").ok());
}
@@ -261,7 +331,8 @@ TEST(GcsFileSystemTest, GetFileSize) {
strings::StrCat("{\"size\": \"1010\"}"))});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -284,7 +355,8 @@ TEST(GcsFileSystemTest, RenameFile) {
"")});
GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)));
+ new FakeHttpRequestFactory(&requests)),
+ 0 /* read ahead bytes */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));