aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexey Surkov <surkov@google.com>2017-05-11 11:09:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-11 11:45:35 -0700
commit31415474123c27197c8553314366425ff2331d36 (patch)
tree755d7c6e1872c018dc284bbdfd652608a3b1c986
parent1f69227b32d9505439aa132667922091be5fca7d (diff)
Automatically abort stuck transmissions.
HttpRequest now uses a progress callback from libcurl to detect and abort transmissions which haven't made any progress for a significant time. PiperOrigin-RevId: 155770703
-rw-r--r--tensorflow/core/platform/cloud/http_request.cc47
-rw-r--r--tensorflow/core/platform/cloud/http_request.h20
-rw-r--r--tensorflow/core/platform/cloud/http_request_test.cc77
3 files changed, 142 insertions, 2 deletions
diff --git a/tensorflow/core/platform/cloud/http_request.cc b/tensorflow/core/platform/cloud/http_request.cc
index 825741f614..2d0141e50e 100644
--- a/tensorflow/core/platform/cloud/http_request.cc
+++ b/tensorflow/core/platform/cloud/http_request.cc
@@ -35,6 +35,10 @@ constexpr uint32 kRequestTimeoutSeconds = 3600; // 1 hour
// Timeout for the connection phase.
constexpr uint32 kConnectTimeoutSeconds = 120; // 2 minutes
+// The maximum period of request inactivity, after which the request
+// is terminated. Expressed in seconds; a request counts as inactive while
+// the sum of its downloaded and uploaded bytes does not grow.
+constexpr uint64 kInactivityTimeoutSeconds = 60; // 1 minute
+
// Proxy to the real libcurl implementation.
class LibCurlProxy : public LibCurl {
public:
@@ -75,6 +79,13 @@ class LibCurlProxy : public LibCurl {
return ::curl_easy_setopt(curl, option, param);
}
+ // Overload of curl_easy_setopt for options whose value is a transfer
+ // progress callback (the signature HttpRequest::ProgressCallback has).
+ // Forwards all arguments unchanged to the global ::curl_easy_setopt and
+ // returns its result.
+ CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+ int (*param)(void* clientp, curl_off_t dltotal,
+ curl_off_t dlnow, curl_off_t ultotal,
+ curl_off_t ulnow)) override {
+ return ::curl_easy_setopt(curl, option, param);
+ }
+
CURLcode curl_easy_perform(CURL* curl) override {
return ::curl_easy_perform(curl);
}
@@ -111,7 +122,8 @@ class LibCurlProxy : public LibCurl {
HttpRequest::HttpRequest() : HttpRequest(LibCurlProxy::Load()) {}
-HttpRequest::HttpRequest(LibCurl* libcurl) : libcurl_(libcurl) {
+HttpRequest::HttpRequest(LibCurl* libcurl, Env* env)
+ : libcurl_(libcurl), env_(env) {
default_response_buffer_.reserve(CURL_MAX_WRITE_SIZE);
}
@@ -152,6 +164,12 @@ Status HttpRequest::Init() {
libcurl_->curl_easy_setopt(curl_, CURLOPT_HTTP_VERSION,
CURL_HTTP_VERSION_2_0);
+ // Set up the progress meter.
+ libcurl_->curl_easy_setopt(curl_, CURLOPT_NOPROGRESS, 0ULL);
+ libcurl_->curl_easy_setopt(curl_, CURLOPT_XFERINFODATA, this);
+ libcurl_->curl_easy_setopt(curl_, CURLOPT_XFERINFOFUNCTION,
+ &HttpRequest::ProgressCallback);
+
// If response buffer is not set, libcurl will print results to stdout,
// so we always set it.
is_initialized_ = true;
@@ -470,4 +488,31 @@ string HttpRequest::GetResponseHeader(const string& name) const {
uint64 HttpRequest::GetResponseCode() const { return response_code_; }
+// Progress callback handed to libcurl; cancels the transmission if no
+// progress has been made for too long.
+//
+// `this_object` is the HttpRequest whose transfer is being monitored.
+// `dlnow` and `ulnow` are the numbers of bytes downloaded and uploaded so
+// far; their sum is the progress measure. `dltotal` and `ultotal` are not
+// used.
+//
+// Returns 0 to let the transfer continue, or 1 to abort it once the progress
+// measure has not grown for more than kInactivityTimeoutSeconds.
+int HttpRequest::ProgressCallback(void* this_object, curl_off_t dltotal,
+                                  curl_off_t dlnow, curl_off_t ultotal,
+                                  curl_off_t ulnow) {
+  // The pointer only round-trips through void*, so static_cast is enough.
+  auto* that = static_cast<HttpRequest*>(this_object);
+  const uint64 now = that->env_->NowSeconds();
+  const curl_off_t current_progress = dlnow + ulnow;
+  if (that->last_progress_timestamp_ == 0 ||
+      current_progress > that->last_progress_bytes_) {
+    // This is the first time the callback is called or some progress
+    // was made since the last tick.
+    that->last_progress_timestamp_ = now;
+    that->last_progress_bytes_ = current_progress;
+    return 0;
+  }
+
+  if (now < that->last_progress_timestamp_) {
+    // The clock reported by env_ went backwards. Both values are unsigned, so
+    // subtracting them would wrap around to a huge "idle" time and abort a
+    // transfer that is not actually stuck. Restart the inactivity window
+    // from the new clock value instead.
+    that->last_progress_timestamp_ = now;
+    return 0;
+  }
+
+  const uint64 idle_seconds = now - that->last_progress_timestamp_;
+  if (idle_seconds > kInactivityTimeoutSeconds) {
+    LOG(ERROR) << "The transmission has been stuck at " << current_progress
+               << " bytes for " << idle_seconds
+               << " seconds and will be aborted.";
+    return 1;  // Will abort the request.
+  }
+
+  // No progress was made since the last call, but we should wait a bit longer.
+  return 0;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 5365c45ca9..afcbb9f35c 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
@@ -50,7 +51,9 @@ class HttpRequest {
};
HttpRequest();
- explicit HttpRequest(LibCurl* libcurl);
+ explicit HttpRequest(LibCurl* libcurl)
+ : HttpRequest(libcurl, Env::Default()) {}
+ HttpRequest(LibCurl* libcurl, Env* env);
virtual ~HttpRequest();
virtual Status Init();
@@ -123,11 +126,16 @@ class HttpRequest {
/// A header callback in the form which can be accepted by libcurl.
static size_t HeaderCallback(const void* ptr, size_t size, size_t nmemb,
void* this_object);
+ /// A progress meter callback in the form which can be accepted by libcurl.
+ static int ProgressCallback(void* this_object, curl_off_t dltotal,
+ curl_off_t dlnow, curl_off_t ultotal,
+ curl_off_t ulnow);
Status CheckInitialized() const;
Status CheckMethodNotSet() const;
Status CheckNotSent() const;
LibCurl* libcurl_;
+ Env* env_;
FILE* put_body_ = nullptr;
@@ -144,6 +152,12 @@ class HttpRequest {
std::unordered_map<string, string> response_headers_;
uint64 response_code_ = 0;
+ // The timestamp of the last activity related to the request execution, in
+ // seconds since epoch.
+ uint64 last_progress_timestamp_ = 0;
+ // The last progress in terms of bytes transmitted.
+ curl_off_t last_progress_bytes_ = 0;
+
// Members to enforce the usage flow.
bool is_initialized_ = false;
bool is_uri_set_ = false;
@@ -173,6 +187,10 @@ class LibCurl {
virtual CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
size_t (*param)(const void*, size_t, size_t,
void*)) = 0;
+ virtual CURLcode curl_easy_setopt(
+ CURL* curl, CURLoption option,
+ int (*param)(void* clientp, curl_off_t dltotal, curl_off_t dlnow,
+ curl_off_t ultotal, curl_off_t ulnow)) = 0;
virtual CURLcode curl_easy_perform(CURL* curl) = 0;
virtual CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
uint64* value) = 0;
diff --git a/tensorflow/core/platform/cloud/http_request_test.cc b/tensorflow/core/platform/cloud/http_request_test.cc
index b918a3a8cd..6d66dfdee1 100644
--- a/tensorflow/core/platform/cloud/http_request_test.cc
+++ b/tensorflow/core/platform/cloud/http_request_test.cc
@@ -25,12 +25,27 @@ namespace {
const string kTestContent = "random original scratch content";
+// An Env for tests with a controllable clock: NowSeconds() reports whatever
+// is currently stored in `now_`. Everything else is inherited from an
+// EnvWrapper constructed around Env::Default().
+class FakeEnv : public EnvWrapper {
+ public:
+  FakeEnv() : EnvWrapper(Env::Default()), now_(10000) {}
+
+  // Returns the fake current time, in seconds.
+  uint64 NowSeconds() override { return now_; }
+
+  // The fake current time. Public so that tests and fakes can set it directly.
+  uint64 now_;
+};
+
// A fake proxy that pretends to be libcurl.
class FakeLibCurl : public LibCurl {
public:
FakeLibCurl(const string& response_content, uint64 response_code)
: response_content_(response_content), response_code_(response_code) {}
FakeLibCurl(const string& response_content, uint64 response_code,
+ std::vector<std::tuple<uint64, curl_off_t>> progress_ticks,
+ FakeEnv* env)
+ : response_content_(response_content),
+ response_code_(response_code),
+ progress_ticks_(std::move(progress_ticks)),
+ env_(env) {}
+ FakeLibCurl(const string& response_content, uint64 response_code,
const std::vector<string>& response_headers)
: response_content_(response_content),
response_code_(response_code),
@@ -86,6 +101,9 @@ class FakeLibCurl : public LibCurl {
case CURLOPT_READDATA:
read_data_ = reinterpret_cast<FILE*>(param);
break;
+ case CURLOPT_XFERINFODATA:
+ progress_data_ = param;
+ break;
default:
break;
}
@@ -112,6 +130,13 @@ class FakeLibCurl : public LibCurl {
}
return CURLE_OK;
}
+  // Overload for options whose value is a transfer progress callback.
+  // Remembers the callback so that curl_easy_perform can replay the
+  // configured progress ticks through it. Only CURLOPT_XFERINFOFUNCTION is
+  // recorded; any other option is accepted and ignored, matching how the
+  // other curl_easy_setopt overloads of this fake treat unknown options.
+  // Always returns CURLE_OK.
+  CURLcode curl_easy_setopt(CURL* curl, CURLoption option,
+                            int (*param)(void* clientp, curl_off_t dltotal,
+                                         curl_off_t dlnow, curl_off_t ultotal,
+                                         curl_off_t ulnow)) override {
+    switch (option) {
+      case CURLOPT_XFERINFOFUNCTION:
+        progress_callback_ = param;
+        break;
+      default:
+        break;
+    }
+    return CURLE_OK;
+  }
CURLcode curl_easy_perform(CURL* curl) override {
if (read_data_) {
char buffer[3];
@@ -134,6 +159,12 @@ class FakeLibCurl : public LibCurl {
strncpy(error_buffer_, curl_easy_perform_error_message_.c_str(),
curl_easy_perform_error_message_.size() + 1);
}
+ for (const auto& tick : progress_ticks_) {
+ env_->now_ = std::get<0>(tick);
+ if (progress_callback_(progress_data_, 0, std::get<1>(tick), 0, 0)) {
+ return CURLE_ABORTED_BY_CALLBACK;
+ }
+ }
return curl_easy_perform_result_;
}
CURLcode curl_easy_getinfo(CURL* curl, CURLINFO info,
@@ -212,10 +243,17 @@ class FakeLibCurl : public LibCurl {
FILE* read_data_ = nullptr;
size_t (*read_callback_)(void* ptr, size_t size, size_t nmemb,
FILE* userdata) = &fread;
+ int (*progress_callback_)(void* clientp, curl_off_t dltotal, curl_off_t dlnow,
+ curl_off_t ultotal, curl_off_t ulnow) = nullptr;
+ void* progress_data_ = nullptr;
// Outcome of performing the request.
string posted_content_;
CURLcode curl_easy_perform_result_ = CURLE_OK;
string curl_easy_perform_error_message_;
+ // A vector of <timestamp, progress in bytes> pairs that represent the
+ // progress of a transmission.
+ std::vector<std::tuple<uint64, curl_off_t>> progress_ticks_;
+ FakeEnv* env_ = nullptr;
};
TEST(HttpRequestTest, GetRequest) {
@@ -547,5 +585,44 @@ TEST(HttpRequestTest, ErrorReturnsNoResponse) {
EXPECT_EQ("", string(scratch.begin(), scratch.end()));
}
+TEST(HttpRequestTest, ProgressIsOk) {
+  // A transfer that keeps moving: the byte count is flat for only 10 seconds
+  // (t=100 to t=110) and has grown by t=200, so the request must succeed.
+  FakeEnv env;
+  std::vector<std::tuple<uint64, curl_off_t>> ticks;
+  ticks.emplace_back(100, 0);    // timestamp 100, 0 bytes
+  ticks.emplace_back(110, 0);    // timestamp 110, 0 bytes
+  ticks.emplace_back(200, 100);  // timestamp 200, 100 bytes
+  FakeLibCurl libcurl("test", 200, std::move(ticks), &env);
+
+  HttpRequest http_request(&libcurl, &env);
+  TF_EXPECT_OK(http_request.Init());
+  TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
+  TF_EXPECT_OK(http_request.Send());
+}
+
+TEST(HttpRequestTest, ProgressIsStuck) {
+  // A transfer that stalls: the byte count stays at 10 from t=100 to t=170,
+  // i.e. for 70 seconds, which exceeds the one-minute inactivity timeout.
+  FakeEnv env;
+  std::vector<std::tuple<uint64, curl_off_t>> ticks;
+  ticks.emplace_back(100, 10);  // timestamp 100, 10 bytes
+  ticks.emplace_back(130, 10);  // timestamp 130, 10 bytes
+  ticks.emplace_back(170, 10);  // timestamp 170, 10 bytes
+  FakeLibCurl libcurl("test", 200, std::move(ticks), &env);
+
+  HttpRequest http_request(&libcurl, &env);
+  TF_EXPECT_OK(http_request.Init());
+  TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
+
+  // The fake reports CURLE_ABORTED_BY_CALLBACK once the callback returns
+  // non-zero, and Send() is expected to surface that as UNAVAILABLE.
+  const auto status = http_request.Send();
+  EXPECT_EQ(error::UNAVAILABLE, status.code());
+  EXPECT_EQ(
+      "Error executing an HTTP request (HTTP response code 200, "
+      "error code 42, error message '')",
+      status.error_message());
+}
+
} // namespace
} // namespace tensorflow