about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/core/platform/cloud/curl_http_request.cc 53
-rw-r--r-- tensorflow/core/platform/cloud/curl_http_request.h 32
-rw-r--r-- tensorflow/core/platform/cloud/curl_http_request_test.cc 33
-rw-r--r-- tensorflow/core/platform/cloud/file_block_cache.cc 21
-rw-r--r-- tensorflow/core/platform/cloud/file_block_cache.h 9
-rw-r--r-- tensorflow/core/platform/cloud/file_block_cache_test.cc 181
-rw-r--r-- tensorflow/core/platform/cloud/gcs_dns_cache_test.cc 4
-rw-r--r-- tensorflow/core/platform/cloud/gcs_file_system.cc 28
-rw-r--r-- tensorflow/core/platform/cloud/gcs_file_system.h 2
-rw-r--r-- tensorflow/core/platform/cloud/http_request.h 14
-rw-r--r-- tensorflow/core/platform/cloud/http_request_fake.h 21
11 files changed, 295 insertions, 103 deletions
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index c2533b4314..86943d34a6 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <algorithm>
+
#include "tensorflow/core/platform/cloud/curl_http_request.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -327,6 +329,57 @@ Status CurlHttpRequest::SetResultBuffer(std::vector<char>* out_buffer) {
return Status::OK();
}
+// Registers a caller-owned buffer of `size` bytes as the destination for the
+// response body and installs WriteCallbackDirect as libcurl's write callback.
+// CHECK-fails on a null `buffer`; otherwise returns the error produced by
+// CheckInitialized() or CheckNotSent() if either of them fails.
+Status CurlHttpRequest::SetResultBufferDirect(char* buffer, size_t size) {
+  CHECK(buffer != nullptr);
+  TF_RETURN_IF_ERROR(CheckInitialized());
+  TF_RETURN_IF_ERROR(CheckNotSent());
+
+  // Begin a fresh transfer record: nothing has been copied into `buffer` yet.
+  direct_response_.buffer_ = buffer;
+  direct_response_.buffer_size_ = size;
+  direct_response_.bytes_transferred_ = 0;
+
+  // `this` is handed back to WriteCallbackDirect as its `userdata` argument.
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEDATA,
+                             static_cast<void*>(this));
+  libcurl_->curl_easy_setopt(curl_, CURLOPT_WRITEFUNCTION,
+                             &CurlHttpRequest::WriteCallbackDirect);
+  return Status::OK();
+}
+
+// libcurl write callback used together with SetResultBufferDirect().
+//
+// Appends the `size * nmemb` bytes at `ptr` to the caller-supplied buffer
+// recorded in `direct_response_`, never writing past `buffer_size_`. `userdata`
+// is the CurlHttpRequest that registered the callback. Returns the number of
+// bytes actually copied, which is smaller than `size * nmemb` once the buffer
+// is full.
+//
+// NOTE(review): libcurl documents that a write callback returning fewer bytes
+// than it was given aborts the transfer with CURLE_WRITE_ERROR -- confirm that
+// Send() handles an over-long response the way callers expect.
+size_t CurlHttpRequest::WriteCallbackDirect(const void* ptr, size_t size,
+                                            size_t nmemb, void* userdata) {
+  CHECK(ptr != nullptr);
+  auto* request = reinterpret_cast<CurlHttpRequest*>(userdata);
+  DirectResponseState& response = request->direct_response_;
+  CHECK(response.buffer_ != nullptr);
+  CHECK(response.bytes_transferred_ <= response.buffer_size_);
+
+  const size_t incoming = size * nmemb;
+  const size_t remaining = response.buffer_size_ - response.bytes_transferred_;
+  const bool overflows = incoming > remaining;
+
+  // A body longer than the buffer is the server's doing, not a bug in this
+  // client, so it is reported with a warning rather than a CHECK failure.
+  if (overflows) {
+    LOG(WARNING) << "The HTTP response body that we received is longer than we "
+                    "requested or expected. "
+                 << "Total bytes requested: " << response.buffer_size_
+                 << " Bytes received (so far) in HTTP response body: "
+                 << (response.bytes_transferred_ + incoming);
+  }
+
+  // Copy only what still fits; any excess bytes are dropped.
+  const size_t accepted = overflows ? remaining : incoming;
+  memcpy(response.buffer_ + response.bytes_transferred_, ptr, accepted);
+  response.bytes_transferred_ += accepted;
+  return accepted;
+}
+
+// Reports how many response-body bytes have been copied so far into the
+// direct result buffer. CHECK-fails when no direct buffer is recorded.
+size_t CurlHttpRequest::GetResultBufferDirectBytesTransferred() {
+  const DirectResponseState& response = direct_response_;
+  CHECK(response.buffer_ != nullptr);
+  return response.bytes_transferred_;
+}
+
Status CurlHttpRequest::SetTimeouts(uint32 connection, uint32 inactivity,
uint32 total) {
TF_RETURN_IF_ERROR(CheckInitialized());
diff --git a/tensorflow/core/platform/cloud/curl_http_request.h b/tensorflow/core/platform/cloud/curl_http_request.h
index e4c91dac8d..0686b692cb 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.h
+++ b/tensorflow/core/platform/cloud/curl_http_request.h
@@ -103,6 +103,26 @@ class CurlHttpRequest : public HttpRequest {
/// read. Existing content of the vector will be cleared.
Status SetResultBuffer(std::vector<char>* out_buffer) override;
+ /// \brief Specifies the buffer for receiving the response body, when the
+ /// caller knows the maximum size of the response body.
+ ///
+ /// This method allows the caller to receive the response body without an
+ /// additional intermediate buffer allocation and copy. This method should
+ /// be called before calling Send(). After Send() has succeeded, the caller
+ /// should use the GetResultBufferDirectBytesTransferred() method in order
+ /// to learn how many bytes were transferred.
+ ///
+ /// Using this method is mutually exclusive with using SetResultBuffer().
+ Status SetResultBufferDirect(char* buffer, size_t size) override;
+
+ /// \brief Returns the number of bytes (of the response body) that were
+ /// transferred, when using the SetResultBufferDirect() method. The returned
+ /// value will always be less than or equal to the 'size' parameter that
+ /// was passed to SetResultBufferDirect(). If the actual HTTP response body
+ /// was greater than 'size' bytes, then this transfer method will only copy
+ /// the first 'size' bytes, and the rest will be ignored.
+ size_t GetResultBufferDirectBytesTransferred() override;
+
/// \brief Returns the response headers of a completed request.
///
/// If the header is not found, returns an empty string.
@@ -127,6 +147,10 @@ class CurlHttpRequest : public HttpRequest {
/// A write callback in the form which can be accepted by libcurl.
static size_t WriteCallback(const void* ptr, size_t size, size_t nmemb,
void* userdata);
+
+ /// Processes response body content received when using SetResultBufferDirect.
+ static size_t WriteCallbackDirect(const void* ptr, size_t size, size_t nmemb,
+ void* userdata);
/// A read callback in the form which can be accepted by libcurl.
static size_t ReadCallback(void* ptr, size_t size, size_t nmemb,
FILE* userdata);
@@ -150,6 +174,14 @@ class CurlHttpRequest : public HttpRequest {
size_t post_body_read_ = 0;
std::vector<char>* response_buffer_ = nullptr;
+
+ struct DirectResponseState {  // Bookkeeping for SetResultBufferDirect().
+ char* buffer_;  // Caller-supplied destination for the response body.
+ size_t buffer_size_;  // Size in bytes of the memory at buffer_.
+ size_t bytes_transferred_;  // Bytes copied into buffer_ so far.
+ };
+ DirectResponseState direct_response_ = {};  // Null buffer until one is set.
+
CURL* curl_ = nullptr;
curl_slist* curl_headers_ = nullptr;
curl_slist* resolve_list_ = nullptr;
diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc
index 2d3e46edaf..d108849c0f 100644
--- a/tensorflow/core/platform/cloud/curl_http_request_test.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc
@@ -288,6 +288,39 @@ TEST(CurlHttpRequestTest, GetRequest) {
EXPECT_EQ(200, http_request.GetResponseCode());
}
+// Verifies that a GET whose body is received through SetResultBufferDirect()
+// lands in the caller's buffer and that the transferred byte count is
+// reported correctly.
+TEST(CurlHttpRequestTest, GetRequest_Direct) {
+  FakeLibCurl libcurl("get response", 200);
+  CurlHttpRequest http_request(&libcurl);
+  TF_EXPECT_OK(http_request.Init());
+
+  std::vector<char> scratch(100, 0);
+
+  TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(http_request.AddAuthBearerHeader("fake-bearer"));
+  TF_EXPECT_OK(http_request.SetRange(100, 199));
+  // Pass size(), not capacity(): only the first size() elements of the vector
+  // are constructed and therefore valid to write through data().
+  TF_EXPECT_OK(
+      http_request.SetResultBufferDirect(scratch.data(), scratch.size()));
+  TF_EXPECT_OK(http_request.Send());
+
+  string expected_response = "get response";
+  size_t response_bytes_transferred =
+      http_request.GetResultBufferDirectBytesTransferred();
+  EXPECT_EQ(response_bytes_transferred, expected_response.size());
+  EXPECT_EQ(
+      expected_response,
+      string(scratch.begin(), scratch.begin() + response_bytes_transferred));
+
+  // Check interactions with libcurl.
+  EXPECT_TRUE(libcurl.is_initialized_);
+  EXPECT_EQ("http://www.testuri.com", libcurl.url_);
+  EXPECT_EQ("100-199", libcurl.range_);
+  EXPECT_EQ("", libcurl.custom_request_);
+  EXPECT_EQ(1, libcurl.headers_->size());
+  EXPECT_EQ("Authorization: Bearer fake-bearer", (*libcurl.headers_)[0]);
+  EXPECT_FALSE(libcurl.is_post_);
+  EXPECT_EQ(200, http_request.GetResponseCode());
+}
+
TEST(CurlHttpRequestTest, GetRequest_Empty) {
FakeLibCurl libcurl("", 200);
CurlHttpRequest http_request(&libcurl);
diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc
index e1afc7b308..e6fa93890f 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache.cc
@@ -123,8 +123,12 @@ Status FileBlockCache::MaybeFetch(const Key& key,
case FetchState::CREATED:
block->state = FetchState::FETCHING;
block->mu.unlock(); // Release the lock while making the API call.
- status.Update(
- block_fetcher_(key.first, key.second, block_size_, &block->data));
+        // Give the fetcher a full block-sized buffer to write into.
+        block->data.clear();
+        block->data.resize(block_size_, 0);
+        // Start at zero so that a fetcher which fails before writing its
+        // out-parameter leaves the block empty, instead of resizing it to an
+        // indeterminate length.
+        size_t bytes_transferred = 0;
+        status.Update(block_fetcher_(key.first, key.second, block_size_,
+                                     block->data.data(), &bytes_transferred));
+        // Trim the block to the bytes the fetcher reported.
+        block->data.resize(bytes_transferred, 0);
block->mu.lock(); // Reacquire the lock immediately afterwards
if (status.ok()) {
downloaded_block = true;
@@ -150,15 +154,15 @@ Status FileBlockCache::MaybeFetch(const Key& key,
}
Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
- out->clear();
+ char* buffer, size_t* bytes_transferred) {
+ *bytes_transferred = 0;
if (n == 0) {
return Status::OK();
}
if (block_size_ == 0 || max_bytes_ == 0) {
// The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks.
- return block_fetcher_(filename, offset, n, out);
+ return block_fetcher_(filename, offset, n, buffer, bytes_transferred);
}
// Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_);
@@ -166,6 +170,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
if (finish < offset + n) {
finish += block_size_;
}
+ size_t total_bytes_transferred = 0;
// Now iterate through the blocks, reading them one at a time.
for (size_t pos = start; pos < finish; pos += block_size_) {
Key key = std::make_pair(filename, pos);
@@ -181,6 +186,7 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
+ *bytes_transferred = total_bytes_transferred;
return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
" at position ", pos, "with data size ",
data.size());
@@ -196,13 +202,16 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
end -= (pos + data.size()) - (offset + n);
}
if (begin < end) {
- out->insert(out->end(), begin, end);
+ size_t bytes_to_copy = end - begin;
+ memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
+ total_bytes_transferred += bytes_to_copy;
}
if (data.size() < block_size_) {
// The block was a partial block and thus signals EOF at its upper bound.
break;
}
}
+ *bytes_transferred = total_bytes_transferred;
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/file_block_cache.h b/tensorflow/core/platform/cloud/file_block_cache.h
index 36dbf9db83..74e792a625 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.h
+++ b/tensorflow/core/platform/cloud/file_block_cache.h
@@ -43,8 +43,9 @@ class FileBlockCache {
/// cache is constructed. The returned Status should be OK as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
- typedef std::function<Status(const string&, size_t, size_t,
- std::vector<char>*)>
+ typedef std::function<Status(const string& filename, size_t offset,
+ size_t buffer_size, char* buffer,
+ size_t* bytes_transferred)>
BlockFetcher;
FileBlockCache(size_t block_size, size_t max_bytes, uint64 max_staleness,
@@ -83,8 +84,8 @@ class FileBlockCache {
/// placed in `out`.
/// 4) OK otherwise (i.e. the read succeeded, and at least one byte was placed
/// in `out`).
- Status Read(const string& filename, size_t offset, size_t n,
- std::vector<char>* out);
+ Status Read(const string& filename, size_t offset, size_t n, char* buffer,
+ size_t* bytes_transferred);
/// Remove all cached blocks for `filename`.
void RemoveFile(const string& filename) LOCKS_EXCLUDED(mu_);
diff --git a/tensorflow/core/platform/cloud/file_block_cache_test.cc b/tensorflow/core/platform/cloud/file_block_cache_test.cc
index 081b32af64..12ad44011f 100644
--- a/tensorflow/core/platform/cloud/file_block_cache_test.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache_test.cc
@@ -25,6 +25,18 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Test helper that adapts the raw-buffer FileBlockCache::Read() interface to a
+// std::vector: reads up to `n` bytes of `filename` starting at `offset` and
+// leaves `*out` holding exactly the bytes that were transferred. Returns the
+// status reported by the cache.
+Status ReadCache(FileBlockCache* cache, const string& filename, size_t offset,
+                 size_t n, std::vector<char>* out) {
+  out->clear();
+  out->resize(n, 0);
+  size_t bytes_transferred = 0;
+  Status status =
+      cache->Read(filename, offset, n, out->data(), &bytes_transferred);
+  EXPECT_LE(bytes_transferred, n);
+  // Trim to what was actually read. The second argument of resize() is the
+  // fill value for new elements, so it must be a char, not the length `n`.
+  out->resize(bytes_transferred, 0);
+  return status;
+}
+
TEST(FileBlockCacheTest, PassThrough) {
const string want_filename = "foo/bar";
const size_t want_offset = 42;
@@ -32,12 +44,13 @@ TEST(FileBlockCacheTest, PassThrough) {
int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset,
- size_t got_n, std::vector<char>* out) {
+ size_t got_n, char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n);
calls++;
- out->resize(got_n, 'x');
+ memset(buffer, 'x', got_n);
+ *bytes_transferred = got_n;
return Status::OK();
};
// If block_size, max_bytes, or both are zero, the cache is a pass-through.
@@ -45,11 +58,11 @@ TEST(FileBlockCacheTest, PassThrough) {
FileBlockCache cache2(0, 1, 0, fetcher);
FileBlockCache cache3(0, 0, 0, fetcher);
std::vector<char> out;
- TF_EXPECT_OK(cache1.Read(want_filename, want_offset, want_n, &out));
+ TF_EXPECT_OK(ReadCache(&cache1, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 1);
- TF_EXPECT_OK(cache2.Read(want_filename, want_offset, want_n, &out));
+ TF_EXPECT_OK(ReadCache(&cache2, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 2);
- TF_EXPECT_OK(cache3.Read(want_filename, want_offset, want_n, &out));
+ TF_EXPECT_OK(ReadCache(&cache3, want_filename, want_offset, want_n, &out));
EXPECT_EQ(calls, 3);
}
@@ -63,13 +76,13 @@ TEST(FileBlockCacheTest, BlockAlignment) {
}
// The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
if (offset < buf.size()) {
- if (offset + n > buf.size()) {
- out->insert(out->end(), buf.begin() + offset, buf.end());
- } else {
- out->insert(out->end(), buf.begin() + offset, buf.begin() + offset + n);
- }
+ size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
+ memcpy(buffer, buf.data() + offset, bytes_to_copy);
+ *bytes_transferred = bytes_to_copy;
+ } else {
+ *bytes_transferred = 0;
}
return Status::OK();
};
@@ -80,7 +93,7 @@ TEST(FileBlockCacheTest, BlockAlignment) {
for (size_t offset = 0; offset < 10; offset++) {
for (size_t n = block_size - 2; n <= block_size + 2; n++) {
std::vector<char> got;
- TF_EXPECT_OK(cache.Read("", offset, n, &got));
+ TF_EXPECT_OK(ReadCache(&cache, "", offset, n, &got));
// Verify the size of the read.
if (offset + n <= size) {
// Expect a full read.
@@ -108,24 +121,27 @@ TEST(FileBlockCacheTest, CacheHits) {
const size_t block_size = 16;
std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
- size_t n, std::vector<char>* out) {
+ size_t n, char* buffer,
+ size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset);
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 256;
FileBlockCache cache(block_size, block_count * block_size, 0, fetcher);
std::vector<char> out;
+ out.resize(block_count, 0);
// The cache has space for `block_count` blocks. The loop with i = 0 should
// fill the cache, and the loop with i = 1 should be all cache hits. The
// fetcher checks that it is called once and only once for each offset (to
// fetch the corresponding block).
for (int i = 0; i < 2; i++) {
for (int j = 0; j < block_count; j++) {
- TF_EXPECT_OK(cache.Read("", block_size * j, block_size, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size * j, block_size, &out));
}
}
}
@@ -138,36 +154,39 @@ TEST(FileBlockCacheTest, OutOfRange) {
bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
+ size_t bytes_to_copy = 0;
if (offset == 0) {
// The first block (16 bytes) of the file.
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ bytes_to_copy = n;
first_block = true;
} else if (offset == block_size) {
// The second block (8 bytes) of the file.
- out->resize(file_size - block_size, 'x');
+ bytes_to_copy = file_size - block_size;
+ memset(buffer, 'x', bytes_to_copy);
second_block = true;
}
+ *bytes_transferred = bytes_to_copy;
return Status::OK();
};
FileBlockCache cache(block_size, block_size, 0, fetcher);
std::vector<char> out;
// Reading the first 16 bytes should be fine.
- TF_EXPECT_OK(cache.Read("", 0, block_size, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size, &out));
EXPECT_TRUE(first_block);
EXPECT_EQ(out.size(), block_size);
// Reading at offset file_size + 4 will read the second block (since the read
// at file_size + 4 = 28 will be aligned to an offset of 16) but will return
// OutOfRange because the offset is past the end of the 24-byte file.
- Status status = cache.Read("", file_size + 4, 4, &out);
+ Status status = ReadCache(&cache, "", file_size + 4, 4, &out);
EXPECT_EQ(status.code(), error::OUT_OF_RANGE);
EXPECT_TRUE(second_block);
- EXPECT_EQ(out.size(), 0);
// Reading the second full block will return 8 bytes, from a cache hit.
second_block = false;
- TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_FALSE(second_block);
EXPECT_EQ(out.size(), file_size - block_size);
}
@@ -178,20 +197,22 @@ TEST(FileBlockCacheTest, Inconsistent) {
const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
- out->resize(1, 'x');
+ EXPECT_GE(n, 1);
+ memset(buffer, 'x', 1);
+ *bytes_transferred = 1;
return Status::OK();
};
FileBlockCache cache(block_size, 2 * block_size, 0, fetcher);
std::vector<char> out;
// Read the second block; this should yield an OK status and a single byte.
- TF_EXPECT_OK(cache.Read("", block_size, block_size, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size, block_size, &out));
EXPECT_EQ(out.size(), 1);
// Now read the first block; this should yield an INTERNAL error because we
// had already cached a partial block at a later position.
- Status status = cache.Read("", 0, block_size, &out);
+ Status status = ReadCache(&cache, "", 0, block_size, &out);
EXPECT_EQ(status.code(), error::INTERNAL);
}
@@ -199,14 +220,16 @@ TEST(FileBlockCacheTest, LRU) {
const size_t block_size = 16;
std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
- size_t n, std::vector<char>* out) {
+ size_t n, char* buffer,
+ size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) {
EXPECT_EQ(offset, calls.front());
calls.pop_front();
}
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
return Status::OK();
};
const uint32 block_count = 2;
@@ -216,38 +239,39 @@ TEST(FileBlockCacheTest, LRU) {
// fetcher calls that the cache makes.
calls.push_back(0);
// Cache miss - drains an element from `calls`.
- TF_EXPECT_OK(cache.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Cache hit - does not drain an element from `calls`.
- TF_EXPECT_OK(cache.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
calls.push_back(block_size);
// Cache miss followed by cache hit.
- TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
- TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
calls.push_back(2 * block_size);
// Cache miss followed by cache hit. Causes eviction of LRU element.
- TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
- TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// LRU element was at offset 0. Cache miss.
calls.push_back(0);
- TF_EXPECT_OK(cache.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
// Element at 2 * block_size is still in cache, and this read should update
// its position in the LRU list so it doesn't get evicted by the next read.
- TF_EXPECT_OK(cache.Read("", 2 * block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 2 * block_size, 1, &out));
// Element at block_size was evicted. Reading this element will also cause
// the LRU element (at 0) to be evicted.
calls.push_back(block_size);
- TF_EXPECT_OK(cache.Read("", block_size, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size, 1, &out));
// Element at 0 was evicted again.
calls.push_back(0);
- TF_EXPECT_OK(cache.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, 1, &out));
}
TEST(FileBlockCacheTest, MaxStaleness) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
calls++;
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@@ -256,14 +280,14 @@ TEST(FileBlockCacheTest, MaxStaleness) {
// expected.
FileBlockCache cache1(8, 16, 2 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
- TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Now advance the clock one second at a time and redo the read. The call
// count should advance every 3 seconds (i.e. every time the staleness is
// greater than 2).
for (int i = 1; i <= 10; i++) {
env->SetNowSeconds(i + 1);
- TF_EXPECT_OK(cache1.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache1, "", 0, 1, &out));
EXPECT_EQ(calls, 1 + i / 3);
}
// Now create a cache with max staleness of 0, and verify that it also works
@@ -272,27 +296,27 @@ TEST(FileBlockCacheTest, MaxStaleness) {
env->SetNowSeconds(0);
FileBlockCache cache2(8, 16, 0 /* max staleness */, fetcher, env.get());
// Execute the first read to load the block.
- TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
// Advance the clock by a huge amount and verify that the cached block is
// used to satisfy the read.
env->SetNowSeconds(365 * 24 * 60 * 60); // ~1 year, just for fun.
- TF_EXPECT_OK(cache2.Read("", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache2, "", 0, 1, &out));
EXPECT_EQ(calls, 1);
}
TEST(FileBlockCacheTest, RemoveFile) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) {
// The first block is lower case and all subsequent blocks are upper case.
c = toupper(c);
}
- out->clear();
- out->resize(n, c);
+ memset(buffer, c, n);
+ *bytes_transferred = n;
return Status::OK();
};
// This cache has space for 4 blocks; we'll read from two files.
@@ -304,41 +328,41 @@ TEST(FileBlockCacheTest, RemoveFile) {
std::vector<char> A(n, 'A');
std::vector<char> B(n, 'B');
// Fill the cache.
- TF_EXPECT_OK(cache.Read("a", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 1);
- TF_EXPECT_OK(cache.Read("a", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 2);
- TF_EXPECT_OK(cache.Read("b", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
EXPECT_EQ(calls, 3);
- TF_EXPECT_OK(cache.Read("b", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// All four blocks should be in the cache now.
- TF_EXPECT_OK(cache.Read("a", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
- TF_EXPECT_OK(cache.Read("a", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
- TF_EXPECT_OK(cache.Read("b", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
- TF_EXPECT_OK(cache.Read("b", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// Remove the blocks from "a".
cache.RemoveFile("a");
// Both blocks from "b" should still be there.
- TF_EXPECT_OK(cache.Read("b", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, n, &out));
EXPECT_EQ(out, b);
- TF_EXPECT_OK(cache.Read("b", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 8, n, &out));
EXPECT_EQ(out, B);
EXPECT_EQ(calls, 4);
// The blocks from "a" should not be there.
- TF_EXPECT_OK(cache.Read("a", 0, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 0, n, &out));
EXPECT_EQ(out, a);
EXPECT_EQ(calls, 5);
- TF_EXPECT_OK(cache.Read("a", 8, n, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 8, n, &out));
EXPECT_EQ(out, A);
EXPECT_EQ(calls, 6);
}
@@ -346,10 +370,10 @@ TEST(FileBlockCacheTest, RemoveFile) {
TEST(FileBlockCacheTest, Prune) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
calls++;
- out->clear();
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
return Status::OK();
};
std::vector<char> out;
@@ -360,20 +384,20 @@ TEST(FileBlockCacheTest, Prune) {
FileBlockCache cache(8, 32, 1 /* max staleness */, fetcher, env.get());
// Read three blocks into the cache, and advance the timestamp by one second
// with each read. Start with a block of "a" at the current timestamp `now`.
- TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
// Now load a block of a different file "b" at timestamp `now` + 1
env->SetNowSeconds(now + 1);
- TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
// Now load a different block of file "a" at timestamp `now` + 1. When the
// first block of "a" expires, this block should also be removed because it
// also belongs to file "a".
- TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
// Ensure that all blocks are in the cache (i.e. reads are cache hits).
EXPECT_EQ(cache.CacheSize(), 24);
EXPECT_EQ(calls, 3);
- TF_EXPECT_OK(cache.Read("a", 0, 1, &out));
- TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
- TF_EXPECT_OK(cache.Read("a", 8, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "a", 8, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake timestamp so that "a" becomes stale via its first block.
env->SetNowSeconds(now + 2);
@@ -389,7 +413,7 @@ TEST(FileBlockCacheTest, Prune) {
// There should be one block left in the cache, and it should be the first
// block of "b".
EXPECT_EQ(cache.CacheSize(), 8);
- TF_EXPECT_OK(cache.Read("b", 0, 1, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "b", 0, 1, &out));
EXPECT_EQ(calls, 3);
// Advance the fake time to `now` + 3, at which point "b" becomes stale.
env->SetNowSeconds(now + 3);
@@ -409,14 +433,14 @@ TEST(FileBlockCacheTest, ParallelReads) {
const int callers = 4;
BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug.
return errors::FailedPrecondition("desired concurrency not reached");
}
- out->clear();
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
return Status::OK();
};
const int block_size = 8;
@@ -426,7 +450,8 @@ TEST(FileBlockCacheTest, ParallelReads) {
threads.emplace_back(
Env::Default()->StartThread({}, "caller", [&cache, i, block_size]() {
std::vector<char> out;
- TF_EXPECT_OK(cache.Read("a", i * block_size, block_size, &out));
+ TF_EXPECT_OK(
+ ReadCache(&cache, "a", i * block_size, block_size, &out));
std::vector<char> x(block_size, 'x');
EXPECT_EQ(out, x);
}));
@@ -443,11 +468,12 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
Notification notification;
auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
+ char* buffer, size_t* bytes_transferred) {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0);
num_requests++;
- out->resize(n, 'x');
+ memset(buffer, 'x', n);
+ *bytes_transferred = n;
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
@@ -458,17 +484,16 @@ TEST(FileBlockCacheTest, CoalesceConcurrentReads) {
std::unique_ptr<Thread> concurrent(
Env::Default()->StartThread({}, "concurrent", [&cache] {
std::vector<char> out;
- TF_EXPECT_OK(cache.Read("", 0, block_size / 2, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", 0, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
}));
EXPECT_TRUE(WaitForNotificationWithTimeout(&notification, 10000))
<< "Timeout waiting for concurrent thread to start.";
std::vector<char> out;
- TF_EXPECT_OK(cache.Read("", block_size / 2, block_size / 2, &out));
+ TF_EXPECT_OK(ReadCache(&cache, "", block_size / 2, block_size / 2, &out));
EXPECT_EQ(out.size(), block_size / 2);
EXPECT_EQ(1, num_requests);
}
-
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
index 2c3819f1e2..c96d364228 100644
--- a/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_dns_cache_test.cc
@@ -58,6 +58,10 @@ class TestHttpRequest : public HttpRequest {
Status SetResultBuffer(std::vector<char>* out_buffer) override {
return Status::OK();
}
+ Status SetResultBufferDirect(char* buffer, size_t size) override {  // Stub: ignores the buffer.
+ return Status::OK();
+ }
+ size_t GetResultBufferDirectBytesTransferred() override { return 0; }  // Stub: always reports 0 bytes.
string GetResponseHeader(const string& name) const override { return ""; }
uint64 GetResponseCode() const override { return 0; }
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index a183fe6fa8..ec66ab01d6 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -285,11 +285,11 @@ class GcsRandomAccessFile : public RandomAccessFile {
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
*result = StringPiece();
- std::vector<char> out;
- TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, &out));
- std::memcpy(scratch, out.data(), std::min(out.size(), n));
- *result = StringPiece(scratch, std::min(out.size(), n));
- if (result->size() < n) {
+ size_t bytes_transferred;
+ TF_RETURN_IF_ERROR(file_block_cache_->Read(filename_, offset, n, scratch,
+ &bytes_transferred));
+ *result = StringPiece(scratch, bytes_transferred);
+ if (bytes_transferred < n) {
// This is not an error per se. The RandomAccessFile interface expects
// that Read returns OutOfRange if fewer bytes were read than requested.
return errors::OutOfRange("EOF reached, ", result->size(),
@@ -721,15 +721,17 @@ std::unique_ptr<FileBlockCache> GcsFileSystem::MakeFileBlockCache(
std::unique_ptr<FileBlockCache> file_block_cache(
new FileBlockCache(block_size, max_bytes, max_staleness,
[this](const string& filename, size_t offset, size_t n,
- std::vector<char>* out) {
- return LoadBufferFromGCS(filename, offset, n, out);
+ char* buffer, size_t* bytes_transferred) {
+ return LoadBufferFromGCS(filename, offset, n, buffer,
+ bytes_transferred);
}));
return file_block_cache;
}
// A helper function to actually read the data from GCS.
Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
- size_t n, std::vector<char>* out) {
+ size_t n, char* buffer,
+ size_t* bytes_transferred) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(filename, false, &bucket, &object));
@@ -739,21 +741,23 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
request->SetUri(strings::StrCat("https://", kStorageHost, "/", bucket,
"/", request->EscapeString(object))));
TF_RETURN_IF_ERROR(request->SetRange(offset, offset + n - 1));
- TF_RETURN_IF_ERROR(request->SetResultBuffer(out));
+ TF_RETURN_IF_ERROR(request->SetResultBufferDirect(buffer, n));
TF_RETURN_IF_ERROR(
request->SetTimeouts(timeouts_.connect, timeouts_.idle, timeouts_.read));
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
+ size_t bytes_read = request->GetResultBufferDirectBytesTransferred();
+ *bytes_transferred = bytes_read;
VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
- << offset << " of size: " << out->size();
+ << offset << " of size: " << bytes_read;
- if (out->size() < block_size()) {
+ if (bytes_read < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
if (stat_cache_->Lookup(filename, &stat)) {
- if (offset + out->size() < stat.length) {
+ if (offset + bytes_read < stat.length) {
return errors::Internal(strings::Printf(
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index f4190b3f1e..731f97a4aa 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -177,7 +177,7 @@ class GcsFileSystem : public FileSystem {
/// Loads file contents from GCS for a given filename, offset, and length.
Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n,
- std::vector<char>* out);
+ char* buffer, size_t* bytes_transferred);
std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 95a436c622..6b13ac475e 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -101,6 +101,20 @@ class HttpRequest {
/// read. Existing content of the vector will be cleared.
virtual Status SetResultBuffer(std::vector<char>* out_buffer) = 0;
+ /// \brief Specifies the buffer for receiving the response body.
+ ///
+ /// This method should be used when a caller knows the upper bound of the
+ /// size of the response data. The caller provides a pre-allocated buffer
+ /// and its size. After the Send() method is called, the
+ /// GetResultBufferDirectBytesTransferred() method may be used to learn the
+ /// number of bytes that were transferred using this method.
+ virtual Status SetResultBufferDirect(char* buffer, size_t size) = 0;
+
+ /// \brief Returns the number of bytes transferred, when using
+ /// SetResultBufferDirect(). This method may only be used when using
+ /// SetResultBufferDirect().
+ virtual size_t GetResultBufferDirectBytesTransferred() = 0;
+
/// \brief Returns the response headers of a completed request.
///
/// If the header is not found, returns an empty string.
diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h
index f65c15dac7..72292dca53 100644
--- a/tensorflow/core/platform/cloud/http_request_fake.h
+++ b/tensorflow/core/platform/cloud/http_request_fake.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
+#include <algorithm>
#include <fstream>
#include <string>
#include <vector>
@@ -130,12 +131,25 @@ class FakeHttpRequest : public CurlHttpRequest {
buffer_ = buffer;
return Status::OK();
}
+ Status SetResultBufferDirect(char* buffer, size_t size) override {
+ direct_result_buffer_ = buffer;
+ direct_result_buffer_size_ = size;
+ return Status::OK();
+ }
+ size_t GetResultBufferDirectBytesTransferred() override {
+ return direct_result_bytes_transferred_;
+ }
Status Send() override {
EXPECT_EQ(expected_request_, actual_request())
<< "Unexpected HTTP request.";
if (buffer_) {
- buffer_->insert(buffer_->begin(), response_.c_str(),
- response_.c_str() + response_.size());
+ buffer_->insert(buffer_->begin(), response_.data(),
+ response_.data() + response_.size());
+ } else if (direct_result_buffer_ != nullptr) {
+ size_t bytes_to_copy =
+ std::min<size_t>(direct_result_buffer_size_, response_.size());
+ memcpy(direct_result_buffer_, response_.data(), bytes_to_copy);
+ direct_result_bytes_transferred_ += bytes_to_copy;
}
return response_status_;
}
@@ -178,6 +192,9 @@ class FakeHttpRequest : public CurlHttpRequest {
}
std::vector<char>* buffer_ = nullptr;
+ char* direct_result_buffer_ = nullptr;
+ size_t direct_result_buffer_size_ = 0;
+ size_t direct_result_bytes_transferred_ = 0;
string expected_request_;
string actual_uri_;
string actual_request_;