diff options
-rw-r--r-- | tensorflow/core/platform/cloud/curl_http_request.cc | 53 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/curl_http_request.h | 32 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/curl_http_request_test.cc | 33 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/file_block_cache.cc | 21 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/file_block_cache.h | 9 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/file_block_cache_test.cc | 181 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_dns_cache_test.cc | 4 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.cc | 28 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/gcs_file_system.h | 2 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/http_request.h | 14 | ||||
-rw-r--r-- | tensorflow/core/platform/cloud/http_request_fake.h | 21 |
11 files changed, 295 insertions, 103 deletions
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index c2533b4314..86943d34a6 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <algorithm> + #include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/lib/core/errors.h" @@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector<char>* out_buffer) { return Status::OK(); } +Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) { + CHECK(buffer != nullptr); + TF_RETURN_IF_ERROR(CheckInitialized()); + TF_RETURN_IF_ERROR(CheckNotSent()); + + direct_response_ = DirectResponseState{buffer, size, 0}; + + libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA, + reinterpret_cast<void*>(this)); + libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION, + &CurlHttpRequest::WriteCallbackDirect); + return Status::OK(); +} + +size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size, + size_t nmemb, void* userdata) { + CHECK(ptr != nullptr); + auto that = reinterpret_cast<CurlHttpRequest*>(userdata); + DirectResponseState* state = &that->direct_response_; + CHECK(state->buffer_ != nullptr); + CHECK(state->bytes_transferred_ <= state->buffer_size_); + + size_t curl_bytes_received = size * nmemb; + size_t user_buffer_bytes_available = + state->buffer_size_ - state->bytes_transferred_; + + // The HTTP server may send a response body that is longer than what we + // expected. We must not use CHECK() for this situation, because that would + // imply a code bug (in this client code) where none exists; the violation of + // expectations would have been caused by the server, not the client. So we + // report a log warning, if an HTTP server is misbehaving. + if (curl_bytes_received > user_buffer_bytes_available) { + LOG(WARNING) << "The HTTP response body that we received is longer than we " + "requested or expected. " + << "Total bytes requested: " << state->buffer_size_ + << " Bytes received (so far) in HTTP response body: " + << (state->bytes_transferred_ + curl_bytes_received); + } + + size_t bytes_to_copy = + std::min<size_t>(curl_bytes_received, user_buffer_bytes_available); + memcpy(&state->buffer_[state->bytes_transferred_], ptr, bytes_to_copy); + state->bytes_transferred_ += bytes_to_copy; + return bytes_to_copy; +} + +size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() { + CHECK(direct_response_.buffer_ != nullptr); + return direct_response_.bytes_transferred_; +} + Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity, uint32 total) { TF_RETURN_IF_ERROR(CheckInitialized()); diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h index e4c91dac8d..0686b692cb 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.h +++ b/tensorflow/core/platform/cloud/curl_http_request.h @@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest { /// read. Existing content of the vector will be cleared. Status SetResultBuffer(std::vector<char>* out_buffer) override; + /// \brief Specifies the buffer for receiving the response body, when the + /// caller knows the maximum size of the response body. + /// + /// This method allows the caller to receive the response body without an + /// additional intermediate buffer allocation and copy. This method should + /// be called before calling Send(). After Send() has succeeded, the caller + /// should use the GetResultBufferDirectBytesTransferred() method in order + /// to learn how many bytes were transferred. + /// + /// Using this method is mutually exclusive with using SetResultBuffer(). + Status SetResultBufferDirect(char* buffer, size_t size) override; + + /// \brief Returns the number of bytes (of the response body) that were + /// transferred, when using the SetResultBufferDirect() method. The returned + /// value will always be less than or equal to the 'size' parameter that + /// was passed to SetResultBufferDirect(). If the actual HTTP response body + /// was greater than 'size' bytes, then this transfer method will only copy + /// the first 'size' bytes, and the rest will be ignored. + size_t GetResultBufferDirectBytesTransferred() override; + /// \brief Returns the response headers of a completed request. /// /// If the header is not found, returns an empty string. @@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest { /// A write callback in the form which can be accepted by libcurl. static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb, void* userdata); + + /// Processes response body content received when using SetResultBufferDirect. + static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb, + void* userdata); /// A read callback in the form which can be accepted by libcurl. static size_t ReadCallback(void* ptr, size_t size, size_t nmemb, FILE* userdata); @@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest { size_t post_body_read_ = 0; std::vector<char>* response_buffer_ = nullptr; + + struct DirectResponseState { + char* buffer_; + size_t buffer_size_; + size_t bytes_transferred_; + }; + DirectResponseState direct_response_ = {}; + CURL* curl_ = nullptr; curl_slist* curl_headers_ = nullptr; curl_slist* resolve_list_ = nullptr; diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc index 2d3e46edaf..d108849c0f 100644 --- a/tensorflow/core/platform/cloud/curl_http_request_test.cc +++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc @@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) { EXPECT_EQ(200, http_request.GetResponseCode()); } +TEST(CurlHttpRequestTest, GetRequest_Direct) { + FakeLibCurl libcurl("get response", 200); + CurlHttpRequest http_request(&libcurl); + TF_EXPECT_OK(http_request.Init()); + + std::vector<char> scratch(100, 0); + + TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com")); + TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer")); + TF_EXPECT_OK(http_request.SetRange(100, 199)); + TF_EXPECT_OK( + http_request.SetResultBufferDirect(scratch.data(), scratch.capacity())); + TF_EXPECT_OK(http_request.Send()); + + string expected_response = "get response"; + size_t response_bytes_transferred = + http_request.GetResultBufferDirectBytesTransferred(); + EXPECT_EQ(response_bytes_transferred, expected_response.size()); + EXPECT_EQ( + "get response", + string(scratch.begin(), scratch.begin() + response_bytes_transferred)); + + // Check interactions with libcurl. + EXPECT_TRUE(libcurl.is_initialized_); + EXPECT_EQ("http://www.testuri.com", libcurl.url_); + EXPECT_EQ("100-199", libcurl.range_); + EXPECT_EQ("", libcurl.custom_request_); + EXPECT_EQ(1, libcurl.headers_->size()); + EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]); + EXPECT_FALSE(libcurl.is_post_); + EXPECT_EQ(200, http_request.GetResponseCode()); +} + TEST(CurlHttpRequestTest, GetRequest_Empty) { FakeLibCurl libcurl("", 200); CurlHttpRequest http_request(&libcurl); diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc index e1afc7b308..e6fa93890f 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.cc +++ b/tensorflow/core/platform/cloud/file_block_cache.cc @@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key, case FetchState::CREATED: block->state = FetchState::FETCHING; block->mu.unlock(); // Release the lock while making the API call. - status.Update( - block_fetcher_(key.first, key.second, block_size_, &block->data)); + block->data.clear(); + block->data.resize(block_size_, 0); + size_t bytes_transferred; + status.Update(block_fetcher_(key.first, key.second, block_size_, + block->data.data(), &bytes_transferred)); + block->data.resize(bytes_transferred, 0); block->mu.lock(); // Reacquire the lock immediately afterwards if (status.ok()) { downloaded_block = true; @@ -150,15 +154,15 @@ Status FileBlockCache::MaybeFetch(const Key& key, } Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, - std::vector<char>* out) { - out->clear(); + char* buffer, size_t* bytes_transferred) { + *bytes_transferred = 0; if (n == 0) { return Status::OK(); } if (block_size_ == 0 || max_bytes_ == 0) { // The cache is effectively disabled, so we pass the read through to the // fetcher without breaking it up into blocks. - return block_fetcher_(filename, offset, n, out); + return block_fetcher_(filename, offset, n, buffer, bytes_transferred); } // Calculate the block-aligned start and end of the read. size_t start = block_size_ * (offset / block_size_); @@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, if (finish < offset + n) { finish += block_size_; } + size_t total_bytes_transferred = 0; // Now iterate through the blocks, reading them one at a time. for (size_t pos = start; pos < finish; pos += block_size_) { Key key = std::make_pair(filename, pos); @@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, // The requested offset is at or beyond the end of the file. This can // happen if `offset` is not block-aligned, and the read returns the last // block in the file, which does not extend all the way out to `offset`. + *bytes_transferred = total_bytes_transferred; return errors::OutOfRange("EOF at offset ", offset, " in file ", filename, " at position ", pos, "with data size ", data.size()); @@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n, end -= (pos + data.size()) - (offset + n); } if (begin < end) { - out->insert(out->end(), begin, end); + size_t bytes_to_copy = end - begin; + memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy); + total_bytes_transferred += bytes_to_copy; } if (data.size() < block_size_) { // The block was a partial block and thus signals EOF at its upper bound. break; } } + *bytes_transferred = total_bytes_transferred; return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h index 36dbf9db83..74e792a625 100644 --- a/tensorflow/core/platform/cloud/file_block_cache.h +++ b/tensorflow/core/platform/cloud/file_block_cache.h @@ -43,8 +43,9 @@ class FileBlockCache { /// cache is constructed. The returned Status should be OK as long as the /// read from the remote filesystem succeeded (similar to the semantics of the /// read(2) system call). - typedef std::function<Status(const string&, size_t, size_t, - std::vector<char>*)> + typedef std::function<Status(const string& filename, size_t offset, + size_t buffer_size, char* buffer, + size_t* bytes_transferred)> BlockFetcher; FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness, @@ -83,8 +84,8 @@ class FileBlockCache { /// placed in `out`. /// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed /// in `out`). - Status Read(const string& filename, size_t offset, size_t n, - std::vector<char>* out); + Status Read(const string& filename, size_t offset, size_t n, char* buffer, + size_t* bytes_transferred); /// Remove all cached blocks for `filename`. void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_); diff --git a/tensorflow/core/platform/cloud/file_block_cache_test.cc b/tensorflow/core/platform/cloud/file_block_cache_test.cc index 081b32af64..12ad44011f 100644 --- a/tensorflow/core/platform/cloud/file_block_cache_test.cc +++ b/tensorflow/core/platform/cloud/file_block_cache_test.cc @@ -25,6 +25,18 @@ limitations under the License. namespace tensorflow { namespace { +Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset, + size_t n, std::vector<char>* out) { + out->clear(); + out->resize(n, 0); + size_t bytes_transferred = 0; + Status status = + cache->Read(filename, offset, n, out->data(), &bytes_transferred); + EXPECT_LE(bytes_transferred, n); + out->resize(bytes_transferred, n); + return status; +} + TEST(FileBlockCacheTest, PassThrough) { const string want_filename = "foo/bar"; const size_t want_offset = 42; @@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) { int calls = 0; auto fetcher = [&calls, want_filename, want_offset, want_n]( const string& got_filename, size_t got_offset, - size_t got_n, std::vector<char>* out) { + size_t got_n, char* buffer, size_t* bytes_transferred) { EXPECT_EQ(got_filename, want_filename); EXPECT_EQ(got_offset, want_offset); EXPECT_EQ(got_n, want_n); calls++; - out->resize(got_n, 'x'); + memset(buffer, 'x', got_n); + *bytes_transferred = got_n; return Status::OK(); }; // If block_size, max_bytes, or both are zero, the cache is a pass-through. @@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) { FileBlockCache cache2(0, 1, 0, fetcher); FileBlockCache cache3(0, 0, 0, fetcher); std::vector<char> out; - TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 1); - TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 2); - TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out)); + TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out)); EXPECT_EQ(calls, 3); } @@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) { } // The fetcher just fetches slices of the buffer. auto fetcher = [&buf](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { if (offset < buf.size()) { - if (offset + n > buf.size()) { - out->insert(out->end(), buf.begin() + offset, buf.end()); - } else { - out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n); - } + size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n); + memcpy(buffer, buf.data() + offset, bytes_to_copy); + *bytes_transferred = bytes_to_copy; + } else { + *bytes_transferred = 0; } return Status::OK(); }; @@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) { for (size_t offset = 0; offset < 10; offset++) { for (size_t n = block_size - 2; n <= block_size + 2; n++) { std::vector<char> got; - TF_EXPECT_OK(cache.Read("", offset, n, &got)); + TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got)); // Verify the size of the read. if (offset + n <= size) { // Expect a full read. @@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) { const size_t block_size = 16; std::set<size_t> calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, std::vector<char>* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset; calls.insert(offset); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const uint32 block_count = 256; FileBlockCache cache(block_size, block_count * block_size, 0, fetcher); std::vector<char> out; + out.resize(block_count, 0); // The cache has space for `block_count` blocks. The loop with i = 0 should // fill the cache, and the loop with i = 1 should be all cache hits. The // fetcher checks that it is called once and only once for each offset (to // fetch the corresponding block). for (int i = 0; i < 2; i++) { for (int j = 0; j < block_count; j++) { - TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out)); } } } @@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) { bool second_block = false; auto fetcher = [block_size, file_size, &first_block, &second_block]( const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); + size_t bytes_to_copy = 0; if (offset == 0) { // The first block (16 bytes) of the file. - out->resize(n, 'x'); + memset(buffer, 'x', n); + bytes_to_copy = n; first_block = true; } else if (offset == block_size) { // The second block (8 bytes) of the file. - out->resize(file_size - block_size, 'x'); + bytes_to_copy = file_size - block_size; + memset(buffer, 'x', bytes_to_copy); second_block = true; } + *bytes_transferred = bytes_to_copy; return Status::OK(); }; FileBlockCache cache(block_size, block_size, 0, fetcher); std::vector<char> out; // Reading the first 16 bytes should be fine. - TF_EXPECT_OK(cache.Read("", 0, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out)); EXPECT_TRUE(first_block); EXPECT_EQ(out.size(), block_size); // Reading at offset file_size + 4 will read the second block (since the read // at file_size + 4 = 28 will be aligned to an offset of 16) but will return // OutOfRange because the offset is past the end of the 24-byte file. - Status status = cache.Read("", file_size + 4, 4, &out); + Status status = ReadCache(&cache, "", file_size + 4, 4, &out); EXPECT_EQ(status.code(), error::OUT_OF_RANGE); EXPECT_TRUE(second_block); - EXPECT_EQ(out.size(), 0); // Reading the second full block will return 8 bytes, from a cache hit. second_block = false; - TF_EXPECT_OK(cache.Read("", block_size, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); EXPECT_FALSE(second_block); EXPECT_EQ(out.size(), file_size - block_size); } @@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) { const size_t block_size = 16; // This fetcher returns OK but only fills in one byte for any offset. auto fetcher = [block_size](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset % block_size, 0); - out->resize(1, 'x'); + EXPECT_GE(n, 1); + memset(buffer, 'x', 1); + *bytes_transferred = 1; return Status::OK(); }; FileBlockCache cache(block_size, 2 * block_size, 0, fetcher); std::vector<char> out; // Read the second block; this should yield an OK status and a single byte. - TF_EXPECT_OK(cache.Read("", block_size, block_size, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out)); EXPECT_EQ(out.size(), 1); // Now read the first block; this should yield an INTERNAL error because we // had already cached a partial block at a later position. - Status status = cache.Read("", 0, block_size, &out); + Status status = ReadCache(&cache, "", 0, block_size, &out); EXPECT_EQ(status.code(), error::INTERNAL); } @@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) { const size_t block_size = 16; std::list<size_t> calls; auto fetcher = [&calls, block_size](const string& filename, size_t offset, - size_t n, std::vector<char>* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_FALSE(calls.empty()) << "at offset = " << offset; if (!calls.empty()) { EXPECT_EQ(offset, calls.front()); calls.pop_front(); } - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const uint32 block_count = 2; @@ -216,38 +239,39 @@ TEST(FileBlockCacheTest, LRU) { // fetcher calls that the cache makes. calls.push_back(0); // Cache miss - drains an element from `calls`. - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); // Cache hit - does not drain an element from `calls`. - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); calls.push_back(block_size); // Cache miss followed by cache hit. - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); calls.push_back(2 * block_size); // Cache miss followed by cache hit. Causes eviction of LRU element. - TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); - TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); // LRU element was at offset 0. Cache miss. calls.push_back(0); - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); // Element at 2 * block_size is still in cache, and this read should update // its position in the LRU list so it doesn't get evicted by the next read. - TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out)); // Element at block_size was evicted. Reading this element will also cause // the LRU element (at 0) to be evicted. calls.push_back(block_size); - TF_EXPECT_OK(cache.Read("", block_size, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out)); // Element at 0 was evicted again. calls.push_back(0); - TF_EXPECT_OK(cache.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out)); } TEST(FileBlockCacheTest, MaxStaleness) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { calls++; - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; std::vector<char> out; @@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) { // expected. FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get()); // Execute the first read to load the block. - TF_EXPECT_OK(cache1.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); EXPECT_EQ(calls, 1); // Now advance the clock one second at a time and redo the read. The call // count should advance every 3 seconds (i.e. every time the staleness is // greater than 2). for (int i = 1; i <= 10; i++) { env->SetNowSeconds(i + 1); - TF_EXPECT_OK(cache1.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out)); EXPECT_EQ(calls, 1 + i / 3); } // Now create a cache with max staleness of 0, and verify that it also works @@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) { env->SetNowSeconds(0); FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get()); // Execute the first read to load the block. - TF_EXPECT_OK(cache2.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); EXPECT_EQ(calls, 1); // Advance the clock by a huge amount and verify that the cached block is // used to satisfy the read. env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun. - TF_EXPECT_OK(cache2.Read("", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out)); EXPECT_EQ(calls, 1); } TEST(FileBlockCacheTest, RemoveFile) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { calls++; char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x'; if (offset > 0) { // The first block is lower case and all subsequent blocks are upper case. c = toupper(c); } - out->clear(); - out->resize(n, c); + memset(buffer, c, n); + *bytes_transferred = n; return Status::OK(); }; // This cache has space for 4 blocks; we'll read from two files. @@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) { std::vector<char> A(n, 'A'); std::vector<char> B(n, 'B'); // Fill the cache. - TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); EXPECT_EQ(calls, 1); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); EXPECT_EQ(calls, 2); - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); EXPECT_EQ(calls, 3); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // All four blocks should be in the cache now. - TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // Remove the blocks from "a". cache.RemoveFile("a"); // Both blocks from "b" should still be there. - TF_EXPECT_OK(cache.Read("b", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out)); EXPECT_EQ(out, b); - TF_EXPECT_OK(cache.Read("b", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out)); EXPECT_EQ(out, B); EXPECT_EQ(calls, 4); // The blocks from "a" should not be there. - TF_EXPECT_OK(cache.Read("a", 0, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out)); EXPECT_EQ(out, a); EXPECT_EQ(calls, 5); - TF_EXPECT_OK(cache.Read("a", 8, n, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out)); EXPECT_EQ(out, A); EXPECT_EQ(calls, 6); } @@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) { TEST(FileBlockCacheTest, Prune) { int calls = 0; auto fetcher = [&calls](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { calls++; - out->clear(); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; std::vector<char> out; @@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) { FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get()); // Read three blocks into the cache, and advance the timestamp by one second // with each read. Start with a block of "a" at the current timestamp `now`. - TF_EXPECT_OK(cache.Read("a", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); // Now load a block of a different file "b" at timestamp `now` + 1 env->SetNowSeconds(now + 1); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); // Now load a different block of file "a" at timestamp `now` + 1. When the // first block of "a" expires, this block should also be removed because it // also belongs to file "a". - TF_EXPECT_OK(cache.Read("a", 8, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); // Ensure that all blocks are in the cache (i.e. reads are cache hits). EXPECT_EQ(cache.CacheSize(), 24); EXPECT_EQ(calls, 3); - TF_EXPECT_OK(cache.Read("a", 0, 1, &out)); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); - TF_EXPECT_OK(cache.Read("a", 8, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out)); EXPECT_EQ(calls, 3); // Advance the fake timestamp so that "a" becomes stale via its first block. env->SetNowSeconds(now + 2); @@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) { // There should be one block left in the cache, and it should be the first // block of "b". EXPECT_EQ(cache.CacheSize(), 8); - TF_EXPECT_OK(cache.Read("b", 0, 1, &out)); + TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out)); EXPECT_EQ(calls, 3); // Advance the fake time to `now` + 3, at which point "b" becomes stale. env->SetNowSeconds(now + 3); @@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) { const int callers = 4; BlockingCounter counter(callers); auto fetcher = [&counter](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { counter.DecrementCount(); if (!counter.WaitFor(std::chrono::seconds(10))) { // This avoids having the test time out, which is harder to debug. return errors::FailedPrecondition("desired concurrency not reached"); } - out->clear(); - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; return Status::OK(); }; const int block_size = 8; @@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) { threads.emplace_back( Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() { std::vector<char> out; - TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out)); + TF_EXPECT_OK( + ReadCache(&cache, "a", i * block_size, block_size, &out)); std::vector<char> x(block_size, 'x'); EXPECT_EQ(out, x); })); @@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) { Notification notification; auto fetcher = [&num_requests, ¬ification, block_size]( const string& filename, size_t offset, size_t n, - std::vector<char>* out) { + char* buffer, size_t* bytes_transferred) { EXPECT_EQ(n, block_size); EXPECT_EQ(offset, 0); num_requests++; - out->resize(n, 'x'); + memset(buffer, 'x', n); + *bytes_transferred = n; notification.Notify(); // Wait for other thread to issue read. Env::Default()->SleepForMicroseconds(100000); // 0.1 secs @@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) { std::unique_ptr<Thread> concurrent( Env::Default()->StartThread({}, "concurrent", [&cache] { std::vector<char> out; - TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out)); EXPECT_EQ(out.size(), block_size / 2); })); EXPECT_TRUE(WaitForNotificationWithTimeout(¬ification, 10000)) << "Timeout waiting for concurrent thread to start."; std::vector<char> out; - TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out)); + TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out)); EXPECT_EQ(out.size(), block_size / 2); EXPECT_EQ(1, num_requests); } - } // namespace } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc index 2c3819f1e2..c96d364228 100644 --- a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc +++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc @@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest { Status SetResultBuffer(std::vector<char>* out_buffer) override { return Status::OK(); } + Status SetResultBufferDirect(char* buffer, size_t size) override { + return Status::OK(); + } + size_t GetResultBufferDirectBytesTransferred() override { return 0; } string GetResponseHeader(const string& name) const override { return ""; } uint64 GetResponseCode() const override { return 0; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index a183fe6fa8..ec66ab01d6 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile { Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { *result = StringPiece(); - std::vector<char> out; - TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out)); - std::memcpy(scratch, out.data(), std::min(out.size(), n)); - *result = StringPiece(scratch, std::min(out.size(), n)); - if (result->size() < n) { + size_t bytes_transferred; + TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch, + &bytes_transferred)); + *result = StringPiece(scratch, bytes_transferred); + if (bytes_transferred < n) { // This is not an error per se. The RandomAccessFile interface expects // that Read returns OutOfRange if fewer bytes were read than requested. return errors::OutOfRange("EOF reached, ", result->size(), @@ -721,15 +721,17 @@ std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache( std::unique_ptr<FileBlockCache> file_block_cache( new FileBlockCache(block_size, max_bytes, max_staleness, [this](const string& filename, size_t offset, size_t n, - std::vector<char>* out) { - return LoadBufferFromGCS(filename, offset, n, out); + char* buffer, size_t* bytes_transferred) { + return LoadBufferFromGCS(filename, offset, n, buffer, + bytes_transferred); })); return file_block_cache; } // A helper function to actually read the data from GCS. Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, - size_t n, std::vector<char>* out) { + size_t n, char* buffer, + size_t* bytes_transferred) { string bucket, object; TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object)); @@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset, request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket, "/", request->EscapeString(object)))); TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(out)); + TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n)); TF_RETURN_IF_ERROR( request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read)); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://", bucket, "/", object); + size_t bytes_read = request->GetResultBufferDirectBytesTransferred(); + *bytes_transferred = bytes_read; VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ " - << offset << " of size: " << out->size(); + << offset << " of size: " << bytes_read; - if (out->size() < block_size()) { + if (bytes_read < block_size()) { // Check stat cache to see if we encountered an interrupted read. FileStatistics stat; if (stat_cache_->Lookup(filename, &stat)) { - if (offset + out->size() < stat.length) { + if (offset + bytes_read < stat.length) { return errors::Internal(strings::Printf( "File contents are inconsistent for file: %s @ %lu.", filename.c_str(), offset)); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index f4190b3f1e..731f97a4aa 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem { /// Loads file contents from GCS for a given filename, offset, and length. Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n, - std::vector<char>* out); + char* buffer, size_t* bytes_transferred); std::unique_ptr<AuthProvider> auth_provider_; std::unique_ptr<HttpRequest::Factory> http_request_factory_; diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h index 95a436c622..6b13ac475e 100644 --- a/tensorflow/core/platform/cloud/http_request.h +++ b/tensorflow/core/platform/cloud/http_request.h @@ -101,6 +101,20 @@ class HttpRequest { /// read. Existing content of the vector will be cleared. virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0; + /// \brief Specifies the buffer for receiving the response body. + /// + /// This method should be used when a caller knows the upper bound of the + /// size of the response data. The caller provides a pre-allocated buffer + /// and its size. After the Send() method is called, the + /// GetResultBufferDirectBytesTransferred() method may be used to learn to the + /// number of bytes that were transferred using this method. + virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0; + + /// \brief Returns the number of bytes transferred, when using + /// SetResultBufferDirect(). This method may only be used when using + /// SetResultBufferDirect(). + virtual size_t GetResultBufferDirectBytesTransferred() = 0; + /// \brief Returns the response headers of a completed request. /// /// If the header is not found, returns an empty string. diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h index f65c15dac7..72292dca53 100644 --- a/tensorflow/core/platform/cloud/http_request_fake.h +++ b/tensorflow/core/platform/cloud/http_request_fake.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ #define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_ +#include <algorithm> #include <fstream> #include <string> #include <vector> @@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest { buffer_ = buffer; return Status::OK(); } + Status SetResultBufferDirect(char* buffer, size_t size) override { + direct_result_buffer_ = buffer; + direct_result_buffer_size_ = size; + return Status::OK(); + } + size_t GetResultBufferDirectBytesTransferred() override { + return direct_result_bytes_transferred_; + } Status Send() override { EXPECT_EQ(expected_request_, actual_request()) << "Unexpected HTTP request."; if (buffer_) { - buffer_->insert(buffer_->begin(), response_.c_str(), - response_.c_str() + response_.size()); + buffer_->insert(buffer_->begin(), response_.data(), + response_.data() + response_.size()); + } else if (direct_result_buffer_ != nullptr) { + size_t bytes_to_copy = + std::min<size_t>(direct_result_buffer_size_, response_.size()); + memcpy(direct_result_buffer_, response_.data(), bytes_to_copy); + direct_result_bytes_transferred_ += bytes_to_copy; } return response_status_; } @@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest { } std::vector<char>* buffer_ = nullptr; + char* direct_result_buffer_ = nullptr; + size_t direct_result_buffer_size_ = 0; + size_t direct_result_bytes_transferred_ = 0; string expected_request_; string actual_uri_; string actual_request_; |