diff options
author | 2016-06-23 09:29:53 -0800 | |
---|---|---|
committer | 2016-06-23 10:33:05 -0700 | |
commit | 8bf25a491b60d223bba11233de9e62f4b0db17e8 (patch) | |
tree | 6504e4deb37836abd3d4b22b28ee277768be4032 | |
parent | 7683b6820df9179ef987c062f646f267ff5f523c (diff) |
Add a read-ahead cache to the GCS implementation of RandomAccessFile.
In some cases TensorFlow reads the data via RandomAccessFile in really small
chunks, which doesn't work very efficiently with HTTP requests. Adding a
read-ahead cache significantly boosts the performance.
Change: 125691397
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.cc | 72 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.h | 8 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system_test.cc | 100 |
3 files changed, 150 insertions, 30 deletions
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 6025f0cdfb..b9c9a04397 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/platform/cloud/gcs_file_system.h" #include <stdio.h> #include <unistd.h> +#include <algorithm> #include <cstdio> #include <cstdlib> #include <cstring> @@ -80,19 +81,58 @@ Status ParseGcsPath(const string& fname, string* bucket, string* object) { return Status::OK(); } -/// GCS-based implementation of a random access file. +/// A GCS-based implementation of a random access file with a read-ahead buffer. class GcsRandomAccessFile : public RandomAccessFile { public: GcsRandomAccessFile(const string& bucket, const string& object, AuthProvider* auth_provider, - HttpRequest::Factory* http_request_factory) + HttpRequest::Factory* http_request_factory, + size_t read_ahead_bytes) : bucket_(bucket), object_(object), auth_provider_(auth_provider), - http_request_factory_(std::move(http_request_factory)) {} + http_request_factory_(std::move(http_request_factory)), + read_ahead_bytes_(read_ahead_bytes) {} + /// The implementation of reads with a read-ahead buffer. Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { + if (offset >= buffer_start_offset_ && + offset + n <= buffer_start_offset_ + buffer_content_size_) { + // If the requested range is fully in the buffer, just return it. + std::memcpy(scratch, buffer_.get() + offset - buffer_start_offset_, n); + *result = StringPiece(scratch, n); + return Status::OK(); + } + + // Update the buffer content based on the new requested range. + auto buffer_size = n + read_ahead_bytes_; + buffer_.reset(new char[buffer_size]); + buffer_start_offset_ = offset; + buffer_content_size_ = 0; + StringPiece buffer_content; + TF_RETURN_IF_ERROR( + ReadFromGCS(offset, buffer_size, &buffer_content, buffer_.get())); + buffer_content_size_ = buffer_content.size(); + + // Set the results. + *result = StringPiece(scratch, std::min(buffer_content_size_, n)); + std::memcpy(scratch, buffer_.get(), result->size()); + + if (result->size() < n) { + // This is not an error per se. The RandomAccessFile interface expects + // that Read returns OutOfRange if fewer bytes were read than requested. + return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(), + " bytes were read out of ", n, + " bytes requested.")); + } + return Status::OK(); + } + + private: + /// A helper function to actually read the data from GCS. + Status ReadFromGCS(uint64 offset, size_t n, StringPiece* result, + char* scratch) const { string auth_token; TF_RETURN_IF_ERROR(AuthProvider::GetToken(auth_provider_, &auth_token)); @@ -105,22 +145,21 @@ class GcsRandomAccessFile : public RandomAccessFile { TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); TF_RETURN_IF_ERROR(request->SetResultBuffer(scratch, n, result)); TF_RETURN_IF_ERROR(request->Send()); - - if (result->size() < n) { - // This is not an error per se. The RandomAccessFile interface expects - // that Read returns OutOfRange if fewer bytes were read than requested. - return errors::OutOfRange(strings::StrCat("EOF reached, ", result->size(), - " bytes were read out of ", n, - " bytes requested.")); - } return Status::OK(); } - private: string bucket_; string object_; AuthProvider* auth_provider_; HttpRequest::Factory* http_request_factory_; + const size_t read_ahead_bytes_; + + // The buffer-related members need to be mutable, because they are modified + // by the const Read() method. + mutable std::unique_ptr<char[]> buffer_; + // The original file offset of the first byte in the buffer. + mutable size_t buffer_start_offset_ = 0; + mutable size_t buffer_content_size_ = 0; }; /// \brief GCS-based implementation of a writeable file. @@ -233,16 +272,19 @@ GcsFileSystem::GcsFileSystem() GcsFileSystem::GcsFileSystem( std::unique_ptr<AuthProvider> auth_provider, - std::unique_ptr<HttpRequest::Factory> http_request_factory) + std::unique_ptr<HttpRequest::Factory> http_request_factory, + size_t read_ahead_bytes) : auth_provider_(std::move(auth_provider)), - http_request_factory_(std::move(http_request_factory)) {} + http_request_factory_(std::move(http_request_factory)), + read_ahead_bytes_(read_ahead_bytes) {} Status GcsFileSystem::NewRandomAccessFile( const string& fname, std::unique_ptr<RandomAccessFile>* result) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(fname, &bucket, &object)); result->reset(new GcsRandomAccessFile(bucket, object, auth_provider_.get(), - http_request_factory_.get())); + http_request_factory_.get(), + read_ahead_bytes_)); return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 572b1253f8..13785f2db8 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -30,7 +30,8 @@ class GcsFileSystem : public FileSystem { public: GcsFileSystem(); GcsFileSystem(std::unique_ptr<AuthProvider> auth_provider, - std::unique_ptr<HttpRequest::Factory> http_request_factory); + std::unique_ptr<HttpRequest::Factory> http_request_factory, + size_t read_ahead_bytes); Status NewRandomAccessFile( const string& filename, @@ -63,6 +64,11 @@ class GcsFileSystem : public FileSystem { private: std::unique_ptr<AuthProvider> auth_provider_; std::unique_ptr<HttpRequest::Factory> http_request_factory_; + + // The number of bytes to read ahead for buffering purposes in the + // RandomAccessFile implementation. Defaults to 256Mb. + const size_t read_ahead_bytes_ = 256 * 1024 * 1024; + TF_DISALLOW_COPY_AND_ASSIGN(GcsFileSystem); }; diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 286f157528..9c195ba144 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -49,7 +49,7 @@ class FakeAuthProvider : public AuthProvider { } }; -TEST(GcsFileSystemTest, NewRandomAccessFile) { +TEST(GcsFileSystemTest, NewRandomAccessFile_NoReadAhead) { std::vector<HttpRequest*> requests( {new FakeHttpRequest( "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" @@ -63,7 +63,8 @@ TEST(GcsFileSystemTest, NewRandomAccessFile) { "6789")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::unique_ptr<RandomAccessFile> file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -82,6 +83,65 @@ TEST(GcsFileSystemTest, NewRandomAccessFile) { EXPECT_EQ("6789", result); } +TEST(GcsFileSystemTest, NewRandomAccessFile_WithReadAhead) { + std::vector<HttpRequest*> requests( + {new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 0-8\n", + "01234567"), + new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 6-15\n", + "6789abcd"), + new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 6-20\n", + "6789abcd"), + new FakeHttpRequest( + "Uri: https://bucket.storage.googleapis.com/random_access.txt\n" + "Auth Token: fake_token\n" + "Range: 15-29\n", + "")}); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + 5 /* read ahead bytes */); + + std::unique_ptr<RandomAccessFile> file; + TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); + + char scratch[100]; + StringPiece result; + + // Read the first chunk. The cache will be updated with 4 + 5 = 9 bytes. + TF_EXPECT_OK(file->Read(0, 4, &result, scratch)); + EXPECT_EQ("0123", result); + + // The second chunk will be fully loaded from the cache, no requests are made. + TF_EXPECT_OK(file->Read(4, 4, &result, scratch)); + EXPECT_EQ("4567", result); + + // The chunk is only partially cached -- the request will be made to + // reload the cache. 5 + 5 = 10 bytes will be requested. + TF_EXPECT_OK(file->Read(6, 5, &result, scratch)); + EXPECT_EQ("6789a", result); + + // The range can only be partially satisfied. An attempt to fill the cache + // with 10 + 5 = 15 bytes will be made. + EXPECT_EQ(errors::Code::OUT_OF_RANGE, + file->Read(6, 10, &result, scratch).code()); + EXPECT_EQ("6789abcd", result); + + // The range cannot be satisfied. An attempt to fill the cache + // with 10 + 5 = 15 bytes will be made. + EXPECT_EQ(errors::Code::OUT_OF_RANGE, + file->Read(15, 10, &result, scratch).code()); + EXPECT_TRUE(result.empty()); +} + TEST(GcsFileSystemTest, NewWritableFile) { std::vector<HttpRequest*> requests({new FakeHttpRequest( "Uri: https://www.googleapis.com/upload/storage/v1/b/bucket/o?" @@ -91,7 +151,8 @@ TEST(GcsFileSystemTest, NewWritableFile) { "")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::unique_ptr<WritableFile> file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -116,7 +177,8 @@ TEST(GcsFileSystemTest, NewAppendableFile) { "")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::unique_ptr<WritableFile> file; TF_EXPECT_OK(fs.NewAppendableFile("gs://bucket/path/appendable.txt", &file)); @@ -142,7 +204,8 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { content)}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::unique_ptr<ReadOnlyMemoryRegion> region; TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile( @@ -166,7 +229,8 @@ TEST(GcsFileSystemTest, FileExists) { "", errors::NotFound("404"))}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); EXPECT_TRUE(fs.FileExists("gs://bucket/path/file1.txt")); EXPECT_FALSE(fs.FileExists("gs://bucket/path/file2.txt")); @@ -176,7 +240,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { auto requests = CreateGetThreeChildrenRequest(); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -188,7 +253,8 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { auto requests = CreateGetThreeChildrenRequest(); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children)); @@ -204,7 +270,8 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { "{}")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -221,7 +288,8 @@ TEST(GcsFileSystemTest, DeleteFile) { "")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); TF_EXPECT_OK(fs.DeleteFile("gs://bucket/path/file1.txt")); } @@ -234,7 +302,8 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { "{}")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); } @@ -248,7 +317,8 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { " { \"name\": \"path/file1.txt\" }]}")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); EXPECT_FALSE(fs.DeleteDir("gs://bucket/path/").ok()); } @@ -261,7 +331,8 @@ TEST(GcsFileSystemTest, GetFileSize) { strings::StrCat("{\"size\": \"1010\"}"))}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); uint64 size; TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size)); @@ -284,7 +355,8 @@ TEST(GcsFileSystemTest, RenameFile) { "")}); GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests))); + new FakeHttpRequestFactory(&requests)), + 0 /* read ahead bytes */); TF_EXPECT_OK( fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt")); |