aboutsummaryrefslogtreecommitdiffhomepage
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/core/client_channel/BUILD11
-rw-r--r--test/core/client_channel/parse_address_with_named_scope_id_test.cc126
-rw-r--r--test/core/end2end/fixtures/h2_full+trace.cc9
-rw-r--r--test/core/end2end/fixtures/h2_sockpair+trace.cc9
-rw-r--r--test/core/fling/BUILD12
-rw-r--r--test/core/iomgr/BUILD20
-rw-r--r--test/core/iomgr/resolve_address_posix_test.cc81
-rw-r--r--test/core/iomgr/tcp_server_posix_test.cc5
-rw-r--r--test/core/memory_usage/BUILD8
-rw-r--r--test/cpp/end2end/client_interceptors_end2end_test.cc299
-rw-r--r--test/cpp/end2end/interceptors_util.cc10
-rw-r--r--test/cpp/end2end/interceptors_util.h3
-rw-r--r--test/cpp/end2end/server_interceptors_end2end_test.cc81
-rw-r--r--test/cpp/microbenchmarks/bm_call_create.cc2
-rw-r--r--test/cpp/qps/client.h118
-rw-r--r--test/cpp/qps/client_callback.cc201
-rwxr-xr-xtest/cpp/qps/gen_build_yaml.py2
17 files changed, 891 insertions, 106 deletions
diff --git a/test/core/client_channel/BUILD b/test/core/client_channel/BUILD
index 04485f5240..57e5191af4 100644
--- a/test/core/client_channel/BUILD
+++ b/test/core/client_channel/BUILD
@@ -44,6 +44,17 @@ grpc_cc_test(
)
grpc_cc_test(
+ name = "parse_address_with_named_scope_id_test",
+ srcs = ["parse_address_with_named_scope_id_test.cc"],
+ language = "C++",
+ deps = [
+ "//:gpr",
+ "//:grpc",
+ "//test/core/util:grpc_test_util",
+ ],
+)
+
+grpc_cc_test(
name = "uri_parser_test",
srcs = ["uri_parser_test.cc"],
language = "C++",
diff --git a/test/core/client_channel/parse_address_with_named_scope_id_test.cc b/test/core/client_channel/parse_address_with_named_scope_id_test.cc
new file mode 100644
index 0000000000..bfafa74517
--- /dev/null
+++ b/test/core/client_channel/parse_address_with_named_scope_id_test.cc
@@ -0,0 +1,126 @@
+/*
+ *
+ * Copyright 2017 gRPC authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include "src/core/ext/filters/client_channel/parse_address.h"
+#include "src/core/lib/iomgr/sockaddr.h"
+#include "src/core/lib/iomgr/socket_utils.h"
+
+#include <net/if.h>
+#include <string.h>
+#ifdef GRPC_HAVE_UNIX_SOCKET
+#include <sys/un.h>
+#endif
+
+#include <grpc/grpc.h>
+#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
+
+#include "src/core/lib/gpr/host_port.h"
+#include "src/core/lib/iomgr/exec_ctx.h"
+#include "src/core/lib/iomgr/socket_utils.h"
+#include "test/core/util/test_config.h"
+
+static void test_grpc_parse_ipv6_parity_with_getaddrinfo(
+ const char* target, const struct sockaddr_in6 result_from_getaddrinfo) {
+ // Get the sockaddr that gRPC's ipv6 resolver resolves this too.
+ grpc_core::ExecCtx exec_ctx;
+ grpc_uri* uri = grpc_uri_parse(target, 0);
+ grpc_resolved_address addr;
+ GPR_ASSERT(1 == grpc_parse_ipv6(uri, &addr));
+ grpc_sockaddr_in6* result_from_grpc_parser =
+ reinterpret_cast<grpc_sockaddr_in6*>(addr.addr);
+ // Compare the sockaddr returned from gRPC's ipv6 resolver with that returned
+ // from getaddrinfo.
+ GPR_ASSERT(result_from_grpc_parser->sin6_family == AF_INET6);
+ GPR_ASSERT(result_from_getaddrinfo.sin6_family == AF_INET6);
+ GPR_ASSERT(memcmp(&result_from_grpc_parser->sin6_addr,
+ &result_from_getaddrinfo.sin6_addr, sizeof(in6_addr)) == 0);
+ GPR_ASSERT(result_from_grpc_parser->sin6_scope_id ==
+ result_from_getaddrinfo.sin6_scope_id);
+ GPR_ASSERT(result_from_grpc_parser->sin6_scope_id != 0);
+ // TODO: compare sin6_flow_info fields? parse_ipv6 zero's this field as is.
+ // Cleanup
+ grpc_uri_destroy(uri);
+}
+
+struct sockaddr_in6 resolve_with_gettaddrinfo(const char* uri_text) {
+ grpc_uri* uri = grpc_uri_parse(uri_text, 0);
+ char* host = nullptr;
+ char* port = nullptr;
+ gpr_split_host_port(uri->path, &host, &port);
+ struct addrinfo hints;
+ memset(&hints, 0, sizeof(hints));
+ hints.ai_family = AF_INET6;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_flags = AI_NUMERICHOST;
+ struct addrinfo* result;
+ int res = getaddrinfo(host, port, &hints, &result);
+ if (res != 0) {
+ gpr_log(GPR_ERROR,
+ "getaddrinfo failed to resolve host:%s port:%s. Error: %d.", host,
+ port, res);
+ abort();
+ }
+ size_t num_addrs_from_getaddrinfo = 0;
+ for (struct addrinfo* resp = result; resp != nullptr; resp = resp->ai_next) {
+ num_addrs_from_getaddrinfo++;
+ }
+ GPR_ASSERT(num_addrs_from_getaddrinfo == 1);
+ GPR_ASSERT(result->ai_family == AF_INET6);
+ struct sockaddr_in6 out =
+ *reinterpret_cast<struct sockaddr_in6*>(result->ai_addr);
+ // Cleanup
+ freeaddrinfo(result);
+ gpr_free(host);
+ gpr_free(port);
+ grpc_uri_destroy(uri);
+ return out;
+}
+
+int main(int argc, char** argv) {
+ grpc_test_init(argc, argv);
+ grpc_init();
+ char* arbitrary_interface_name = static_cast<char*>(gpr_zalloc(IF_NAMESIZE));
+ // Per RFC 3493, an interface index is a "small positive integer starts at 1".
+ // Probe candidate interface index numbers until we find one that the
+ // system recognizes, and then use that for the test.
+ for (size_t i = 1; i < 65536; i++) {
+ if (if_indextoname(i, arbitrary_interface_name) != nullptr) {
+ gpr_log(
+ GPR_DEBUG,
+ "Found interface at index %d named %s. Will use this for the test",
+ (int)i, arbitrary_interface_name);
+ break;
+ }
+ }
+ GPR_ASSERT(strlen(arbitrary_interface_name) > 0);
+ char* target = nullptr;
+ gpr_asprintf(&target, "ipv6:[fe80::1234%%%s]:12345",
+ arbitrary_interface_name);
+ struct sockaddr_in6 result_from_getaddrinfo =
+ resolve_with_gettaddrinfo(target);
+ // Run the test
+ gpr_log(GPR_DEBUG,
+ "Run test_grpc_parse_ipv6_parity_with_getaddrinfo with target: %s",
+ target);
+ test_grpc_parse_ipv6_parity_with_getaddrinfo(target, result_from_getaddrinfo);
+ // Cleanup
+ gpr_free(target);
+ gpr_free(arbitrary_interface_name);
+ grpc_shutdown();
+}
diff --git a/test/core/end2end/fixtures/h2_full+trace.cc b/test/core/end2end/fixtures/h2_full+trace.cc
index 2bbad48701..ce8f6bf13a 100644
--- a/test/core/end2end/fixtures/h2_full+trace.cc
+++ b/test/core/end2end/fixtures/h2_full+trace.cc
@@ -113,6 +113,15 @@ int main(int argc, char** argv) {
g_fixture_slowdown_factor = 10;
#endif
+#ifdef GPR_WINDOWS
+ /* on Windows, writing logs to stderr is very slow
+ when stderr is redirected to a disk file.
+ The "trace" tests fixtures generates large amount
+ of logs, so setting a buffer for stderr prevents certain
+ test cases from timing out. */
+ setvbuf(stderr, NULL, _IOLBF, 1024);
+#endif
+
grpc::testing::TestEnvironment env(argc, argv);
grpc_end2end_tests_pre_init();
grpc_init();
diff --git a/test/core/end2end/fixtures/h2_sockpair+trace.cc b/test/core/end2end/fixtures/h2_sockpair+trace.cc
index 45f78b5964..4494d5c474 100644
--- a/test/core/end2end/fixtures/h2_sockpair+trace.cc
+++ b/test/core/end2end/fixtures/h2_sockpair+trace.cc
@@ -140,6 +140,15 @@ int main(int argc, char** argv) {
g_fixture_slowdown_factor = 10;
#endif
+#ifdef GPR_WINDOWS
+ /* on Windows, writing logs to stderr is very slow
+ when stderr is redirected to a disk file.
+ The "trace" tests fixtures generates large amount
+ of logs, so setting a buffer for stderr prevents certain
+ test cases from timing out. */
+ setvbuf(stderr, NULL, _IOLBF, 1024);
+#endif
+
grpc::testing::TestEnvironment env(argc, argv);
grpc_end2end_tests_pre_init();
grpc_init();
diff --git a/test/core/fling/BUILD b/test/core/fling/BUILD
index 5c6930cc85..0c16b2a879 100644
--- a/test/core/fling/BUILD
+++ b/test/core/fling/BUILD
@@ -21,7 +21,7 @@ licenses(["notice"]) # Apache v2
load("//test/core/util:grpc_fuzzer.bzl", "grpc_fuzzer")
grpc_cc_binary(
- name = "client",
+ name = "fling_client",
testonly = 1,
srcs = ["client.cc"],
language = "C++",
@@ -34,7 +34,7 @@ grpc_cc_binary(
)
grpc_cc_binary(
- name = "server",
+ name = "fling_server",
testonly = 1,
srcs = ["server.cc"],
language = "C++",
@@ -50,8 +50,8 @@ grpc_cc_test(
name = "fling",
srcs = ["fling_test.cc"],
data = [
- ":client",
- ":server",
+ ":fling_client",
+ ":fling_server",
],
deps = [
"//:gpr",
@@ -65,8 +65,8 @@ grpc_cc_test(
name = "fling_stream",
srcs = ["fling_stream_test.cc"],
data = [
- ":client",
- ":server",
+ ":fling_client",
+ ":fling_server",
],
deps = [
"//:gpr",
diff --git a/test/core/iomgr/BUILD b/test/core/iomgr/BUILD
index e920ceacf0..7daabd5052 100644
--- a/test/core/iomgr/BUILD
+++ b/test/core/iomgr/BUILD
@@ -128,8 +128,25 @@ grpc_cc_test(
)
grpc_cc_test(
- name = "resolve_address_posix_test",
+ name = "resolve_address_using_ares_resolver_posix_test",
srcs = ["resolve_address_posix_test.cc"],
+ args = [
+ "--resolver=ares",
+ ],
+ language = "C++",
+ deps = [
+ "//:gpr",
+ "//:grpc",
+ "//test/core/util:grpc_test_util",
+ ],
+)
+
+grpc_cc_test(
+ name = "resolve_address_using_native_resolver_posix_test",
+ srcs = ["resolve_address_posix_test.cc"],
+ args = [
+ "--resolver=native",
+ ],
language = "C++",
deps = [
"//:gpr",
@@ -237,7 +254,6 @@ grpc_cc_test(
name = "tcp_server_posix_test",
srcs = ["tcp_server_posix_test.cc"],
language = "C++",
- tags = ["manual"], # TODO(adelez): Remove once this works on Foundry.
deps = [
"//:gpr",
"//:grpc",
diff --git a/test/core/iomgr/resolve_address_posix_test.cc b/test/core/iomgr/resolve_address_posix_test.cc
index 5785c73e22..826c7e1faf 100644
--- a/test/core/iomgr/resolve_address_posix_test.cc
+++ b/test/core/iomgr/resolve_address_posix_test.cc
@@ -18,12 +18,14 @@
#include "src/core/lib/iomgr/resolve_address.h"
+#include <net/if.h>
#include <string.h>
#include <sys/un.h>
#include <grpc/grpc.h>
#include <grpc/support/alloc.h>
#include <grpc/support/log.h>
+#include <grpc/support/string_util.h>
#include <grpc/support/sync.h>
#include <grpc/support/time.h>
@@ -33,6 +35,7 @@
#include "src/core/lib/gprpp/thd.h"
#include "src/core/lib/iomgr/executor.h"
#include "src/core/lib/iomgr/iomgr.h"
+#include "test/core/util/cmdline.h"
#include "test/core/util/test_config.h"
static gpr_timespec test_deadline(void) {
@@ -117,12 +120,18 @@ static void must_succeed(void* argsp, grpc_error* err) {
GPR_ASSERT(args->addrs != nullptr);
GPR_ASSERT(args->addrs->naddrs > 0);
gpr_atm_rel_store(&args->done_atm, 1);
+ gpr_mu_lock(args->mu);
+ GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr));
+ gpr_mu_unlock(args->mu);
}
static void must_fail(void* argsp, grpc_error* err) {
args_struct* args = static_cast<args_struct*>(argsp);
GPR_ASSERT(err != GRPC_ERROR_NONE);
gpr_atm_rel_store(&args->done_atm, 1);
+ gpr_mu_lock(args->mu);
+ GRPC_LOG_IF_ERROR("pollset_kick", grpc_pollset_kick(args->pollset, nullptr));
+ gpr_mu_unlock(args->mu);
}
static void test_unix_socket(void) {
@@ -159,22 +168,92 @@ static void test_unix_socket_path_name_too_long(void) {
args_finish(&args);
}
+static void resolve_address_must_succeed(const char* target) {
+ grpc_core::ExecCtx exec_ctx;
+ args_struct args;
+ args_init(&args);
+ poll_pollset_until_request_done(&args);
+ grpc_resolve_address(
+ target, "1" /* port number */, args.pollset_set,
+ GRPC_CLOSURE_CREATE(must_succeed, &args, grpc_schedule_on_exec_ctx),
+ &args.addrs);
+ grpc_core::ExecCtx::Get()->Flush();
+ args_finish(&args);
+}
+
+static void test_named_and_numeric_scope_ids(void) {
+ char* arbitrary_interface_name = static_cast<char*>(gpr_zalloc(IF_NAMESIZE));
+ int interface_index = 0;
+ // Probe candidate interface index numbers until we find one that the
+ // system recognizes, and then use that for the test.
+ for (size_t i = 1; i < 65536; i++) {
+ if (if_indextoname(i, arbitrary_interface_name) != nullptr) {
+ gpr_log(
+ GPR_DEBUG,
+ "Found interface at index %d named %s. Will use this for the test",
+ (int)i, arbitrary_interface_name);
+ interface_index = (int)i;
+ break;
+ }
+ }
+ GPR_ASSERT(strlen(arbitrary_interface_name) > 0);
+ // Test resolution of an ipv6 address with a named scope ID
+ gpr_log(GPR_DEBUG, "test resolution with a named scope ID");
+ char* target_with_named_scope_id = nullptr;
+ gpr_asprintf(&target_with_named_scope_id, "fe80::1234%%%s",
+ arbitrary_interface_name);
+ resolve_address_must_succeed(target_with_named_scope_id);
+ gpr_free(target_with_named_scope_id);
+ gpr_free(arbitrary_interface_name);
+ // Test resolution of an ipv6 address with a numeric scope ID
+ gpr_log(GPR_DEBUG, "test resolution with a numeric scope ID");
+ char* target_with_numeric_scope_id = nullptr;
+ gpr_asprintf(&target_with_numeric_scope_id, "fe80::1234%%%d",
+ interface_index);
+ resolve_address_must_succeed(target_with_numeric_scope_id);
+ gpr_free(target_with_numeric_scope_id);
+}
+
int main(int argc, char** argv) {
+ // First set the resolver type based off of --resolver
+ const char* resolver_type = nullptr;
+ gpr_cmdline* cl = gpr_cmdline_create("resolve address test");
+ gpr_cmdline_add_string(cl, "resolver", "Resolver type (ares or native)",
+ &resolver_type);
+ // In case that there are more than one argument on the command line,
+ // --resolver will always be the first one, so only parse the first argument
+ // (other arguments may be unknown to cl)
+ gpr_cmdline_parse(cl, argc > 2 ? 2 : argc, argv);
+ const char* cur_resolver = gpr_getenv("GRPC_DNS_RESOLVER");
+ if (cur_resolver != nullptr && strlen(cur_resolver) != 0) {
+ gpr_log(GPR_INFO, "Warning: overriding resolver setting of %s",
+ cur_resolver);
+ }
+ if (gpr_stricmp(resolver_type, "native") == 0) {
+ gpr_setenv("GRPC_DNS_RESOLVER", "native");
+ } else if (gpr_stricmp(resolver_type, "ares") == 0) {
+ gpr_setenv("GRPC_DNS_RESOLVER", "ares");
+ } else {
+ gpr_log(GPR_ERROR, "--resolver_type was not set to ares or native");
+ abort();
+ }
grpc::testing::TestEnvironment env(argc, argv);
grpc_init();
{
grpc_core::ExecCtx exec_ctx;
- char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER");
+ test_named_and_numeric_scope_ids();
// c-ares resolver doesn't support UDS (ability for native DNS resolver
// to handle this is only expected to be used by servers, which
// unconditionally use the native DNS resolver).
+ char* resolver_env = gpr_getenv("GRPC_DNS_RESOLVER");
if (resolver_env == nullptr || gpr_stricmp(resolver_env, "native") == 0) {
test_unix_socket();
test_unix_socket_path_name_too_long();
}
gpr_free(resolver_env);
}
+ gpr_cmdline_destroy(cl);
grpc_shutdown();
return 0;
diff --git a/test/core/iomgr/tcp_server_posix_test.cc b/test/core/iomgr/tcp_server_posix_test.cc
index 2c66cdec77..81e26b20cd 100644
--- a/test/core/iomgr/tcp_server_posix_test.cc
+++ b/test/core/iomgr/tcp_server_posix_test.cc
@@ -439,6 +439,11 @@ int main(int argc, char** argv) {
static_cast<test_addrs*>(gpr_zalloc(sizeof(*dst_addrs)));
grpc::testing::TestEnvironment env(argc, argv);
grpc_init();
+ // wait a few seconds to make sure IPv6 link-local addresses can be bound
+ // if we are running under docker container that has just started.
+ // See https://github.com/moby/moby/issues/38491
+ // See https://github.com/grpc/grpc/issues/15610
+ gpr_sleep_until(grpc_timeout_seconds_to_deadline(4));
{
grpc_core::ExecCtx exec_ctx;
g_pollset = static_cast<grpc_pollset*>(gpr_zalloc(grpc_pollset_size()));
diff --git a/test/core/memory_usage/BUILD b/test/core/memory_usage/BUILD
index 2fe94dfa12..38b088c75c 100644
--- a/test/core/memory_usage/BUILD
+++ b/test/core/memory_usage/BUILD
@@ -19,7 +19,7 @@ grpc_package(name = "test/core/memory_usage")
licenses(["notice"]) # Apache v2
grpc_cc_library(
- name = "client",
+ name = "memory_usage_client",
testonly = 1,
srcs = ["client.cc"],
deps = [
@@ -30,7 +30,7 @@ grpc_cc_library(
)
grpc_cc_library(
- name = "server",
+ name = "memory_usage_server",
testonly = 1,
srcs = ["server.cc"],
deps = [
@@ -45,8 +45,8 @@ grpc_cc_test(
name = "memory_usage_test",
srcs = ["memory_usage_test.cc"],
data = [
- ":client",
- ":server",
+ ":memory_usage_client",
+ ":memory_usage_server",
],
language = "C++",
deps = [
diff --git a/test/cpp/end2end/client_interceptors_end2end_test.cc b/test/cpp/end2end/client_interceptors_end2end_test.cc
index 8abf4eb3f4..177922f457 100644
--- a/test/cpp/end2end/client_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/client_interceptors_end2end_test.cc
@@ -68,7 +68,7 @@ class HijackingInterceptor : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
- auto* buffer = methods->GetSendMessage();
+ auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
@@ -173,7 +173,7 @@ class HijackingInterceptorMakesAnotherCall : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
- auto* buffer = methods->GetSendMessage();
+ auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
@@ -270,6 +270,235 @@ class HijackingInterceptorMakesAnotherCallFactory
}
};
+class BidiStreamingRpcHijackingInterceptor : public experimental::Interceptor {
+ public:
+ BidiStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ CheckMetadata(*methods->GetSendInitialMetadata(), "testkey", "testvalue");
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message().find("Hello"), 0u);
+ msg = req.message();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ CheckMetadata(*methods->GetRecvTrailingMetadata(), "testkey",
+ "testvalue");
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message(msg);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ EXPECT_EQ(static_cast<EchoResponse*>(methods->GetRecvMessage())
+ ->message()
+ .find("Hello"),
+ 0u);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ grpc::string msg;
+};
+
+class ClientStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ClientStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedSendMessage();
+ }
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_SEND_MESSAGE)) {
+ EXPECT_FALSE(got_failed_send_);
+ got_failed_send_ = !methods->GetSendMessageStatus();
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::UNAVAILABLE, "Done sending 10 messages");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedSend() { return got_failed_send_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ int count_ = 0;
+ static bool got_failed_send_;
+};
+
+bool ClientStreamingRpcHijackingInterceptor::got_failed_send_ = false;
+
+class ClientStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ClientStreamingRpcHijackingInterceptor(info);
+ }
+};
+
+class ServerStreamingRpcHijackingInterceptor
+ : public experimental::Interceptor {
+ public:
+ ServerStreamingRpcHijackingInterceptor(experimental::ClientRpcInfo* info) {
+ info_ = info;
+ }
+
+ virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ bool hijack = false;
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
+ auto* map = methods->GetSendInitialMetadata();
+ // Check that we can see the test metadata
+ ASSERT_EQ(map->size(), static_cast<unsigned>(1));
+ auto iterator = map->begin();
+ EXPECT_EQ("testkey", iterator->first);
+ EXPECT_EQ("testvalue", iterator->second);
+ hijack = true;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ EchoRequest req;
+ auto* buffer = methods->GetSerializedSendMessage();
+ auto copied_buffer = *buffer;
+ EXPECT_TRUE(
+ SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
+ .ok());
+ EXPECT_EQ(req.message(), "Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
+ // Got nothing to do here for now
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ bool found = false;
+ // Check that we received the metadata as an echo
+ for (const auto& pair : *map) {
+ found = pair.first.starts_with("testkey") &&
+ pair.second.starts_with("testvalue");
+ if (found) break;
+ }
+ EXPECT_EQ(found, true);
+ auto* status = methods->GetRecvStatus();
+ EXPECT_EQ(status->ok(), true);
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_MESSAGE)) {
+ if (++count_ > 10) {
+ methods->FailHijackedRecvMessage();
+ }
+ EchoResponse* resp =
+ static_cast<EchoResponse*>(methods->GetRecvMessage());
+ resp->set_message("Hello");
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
+ // Only the last message will be a failure
+ EXPECT_FALSE(got_failed_message_);
+ got_failed_message_ = methods->GetRecvMessage() == nullptr;
+ }
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_RECV_STATUS)) {
+ auto* map = methods->GetRecvTrailingMetadata();
+ // insert the metadata that we want
+ EXPECT_EQ(map->size(), static_cast<unsigned>(0));
+ map->insert(std::make_pair("testkey", "testvalue"));
+ auto* status = methods->GetRecvStatus();
+ *status = Status(StatusCode::OK, "");
+ }
+ if (hijack) {
+ methods->Hijack();
+ } else {
+ methods->Proceed();
+ }
+ }
+
+ static bool GotFailedMessage() { return got_failed_message_; }
+
+ private:
+ experimental::ClientRpcInfo* info_;
+ static bool got_failed_message_;
+ int count_ = 0;
+};
+
+bool ServerStreamingRpcHijackingInterceptor::got_failed_message_ = false;
+
+class ServerStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new ServerStreamingRpcHijackingInterceptor(info);
+ }
+};
+
+class BidiStreamingRpcHijackingInterceptorFactory
+ : public experimental::ClientInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateClientInterceptor(
+ experimental::ClientRpcInfo* info) override {
+ return new BidiStreamingRpcHijackingInterceptor(info);
+ }
+};
+
class LoggingInterceptor : public experimental::Interceptor {
public:
LoggingInterceptor(experimental::ClientRpcInfo* info) { info_ = info; }
@@ -287,12 +516,16 @@ class LoggingInterceptor : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
- auto* buffer = methods->GetSendMessage();
+ EXPECT_EQ(static_cast<const EchoRequest*>(methods->GetSendMessage())
+ ->message()
+ .find("Hello"),
+ 0u);
+ auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
.ok());
- EXPECT_TRUE(req.message().find("Hello") == 0);
+ EXPECT_TRUE(req.message().find("Hello") == 0u);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_CLOSE)) {
@@ -308,7 +541,7 @@ class LoggingInterceptor : public experimental::Interceptor {
experimental::InterceptionHookPoints::POST_RECV_MESSAGE)) {
EchoResponse* resp =
static_cast<EchoResponse*>(methods->GetRecvMessage());
- EXPECT_TRUE(resp->message().find("Hello") == 0);
+ EXPECT_TRUE(resp->message().find("Hello") == 0u);
}
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::POST_RECV_STATUS)) {
@@ -546,6 +779,62 @@ TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingTest) {
EXPECT_EQ(DummyInterceptor::GetNumTimesRun(), 20);
}
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ClientStreamingHijackingTest) {
+ ChannelArguments args;
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ClientStreamingRpcHijackingInterceptorFactory>(
+ new ClientStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+
+ auto stub = grpc::testing::EchoTestService::NewStub(channel);
+ ClientContext ctx;
+ EchoRequest req;
+ EchoResponse resp;
+ req.mutable_param()->set_echo_metadata(true);
+ req.set_message("Hello");
+ string expected_resp = "";
+ auto writer = stub->RequestStream(&ctx, &resp);
+ for (int i = 0; i < 10; i++) {
+ EXPECT_TRUE(writer->Write(req));
+ expected_resp += "Hello";
+ }
+ // The interceptor will reject the 11th message
+ writer->Write(req);
+ Status s = writer->Finish();
+ EXPECT_EQ(s.ok(), false);
+ EXPECT_TRUE(ClientStreamingRpcHijackingInterceptor::GotFailedSend());
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, ServerStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<ServerStreamingRpcHijackingInterceptorFactory>(
+ new ServerStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeServerStreamingCall(channel);
+ EXPECT_TRUE(ServerStreamingRpcHijackingInterceptor::GotFailedMessage());
+}
+
+TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingHijackingTest) {
+ ChannelArguments args;
+ DummyInterceptor::Reset();
+ std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
+ creators;
+ creators.push_back(
+ std::unique_ptr<BidiStreamingRpcHijackingInterceptorFactory>(
+ new BidiStreamingRpcHijackingInterceptorFactory()));
+ auto channel = experimental::CreateCustomChannelWithInterceptors(
+ server_address_, InsecureChannelCredentials(), args, std::move(creators));
+ MakeBidiStreamingCall(channel);
+}
+
TEST_F(ClientInterceptorsStreamingEnd2endTest, BidiStreamingTest) {
ChannelArguments args;
DummyInterceptor::Reset();
diff --git a/test/cpp/end2end/interceptors_util.cc b/test/cpp/end2end/interceptors_util.cc
index e0ad7d1526..900f02b5f3 100644
--- a/test/cpp/end2end/interceptors_util.cc
+++ b/test/cpp/end2end/interceptors_util.cc
@@ -132,6 +132,16 @@ bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
return false;
}
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value) {
+ for (const auto& pair : map) {
+ if (pair.first == key && pair.second == value) {
+ return true;
+ }
+ }
+ return false;
+}
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors() {
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
diff --git a/test/cpp/end2end/interceptors_util.h b/test/cpp/end2end/interceptors_util.h
index 659e613d2e..419845e5f6 100644
--- a/test/cpp/end2end/interceptors_util.h
+++ b/test/cpp/end2end/interceptors_util.h
@@ -165,6 +165,9 @@ void MakeCallbackCall(const std::shared_ptr<Channel>& channel);
bool CheckMetadata(const std::multimap<grpc::string_ref, grpc::string_ref>& map,
const string& key, const string& value);
+bool CheckMetadata(const std::multimap<grpc::string, grpc::string>& map,
+ const string& key, const string& value);
+
std::vector<std::unique_ptr<experimental::ClientInterceptorFactoryInterface>>
CreateDummyClientInterceptors();
diff --git a/test/cpp/end2end/server_interceptors_end2end_test.cc b/test/cpp/end2end/server_interceptors_end2end_test.cc
index 53d8c4dc96..82f142ba91 100644
--- a/test/cpp/end2end/server_interceptors_end2end_test.cc
+++ b/test/cpp/end2end/server_interceptors_end2end_test.cc
@@ -73,7 +73,7 @@ class LoggingInterceptor : public experimental::Interceptor {
type == experimental::ServerRpcInfo::Type::BIDI_STREAMING));
}
- virtual void Intercept(experimental::InterceptorBatchMethods* methods) {
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_INITIAL_METADATA)) {
auto* map = methods->GetSendInitialMetadata();
@@ -83,7 +83,7 @@ class LoggingInterceptor : public experimental::Interceptor {
if (methods->QueryInterceptionHookPoint(
experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
EchoRequest req;
- auto* buffer = methods->GetSendMessage();
+ auto* buffer = methods->GetSerializedSendMessage();
auto copied_buffer = *buffer;
EXPECT_TRUE(
SerializationTraits<EchoRequest>::Deserialize(&copied_buffer, &req)
@@ -142,6 +142,71 @@ class LoggingInterceptorFactory
}
};
+// Test if SendMessage function family works as expected for sync/callback apis
+class SyncSendMessageTester : public experimental::Interceptor {
+ public:
+ SyncSendMessageTester(experimental::ServerRpcInfo* info) {}
+
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ string old_msg =
+ static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
+ EXPECT_EQ(old_msg.find("Hello"), 0u);
+ new_msg_.set_message("World" + old_msg);
+ methods->ModifySendMessage(&new_msg_);
+ }
+ methods->Proceed();
+ }
+
+ private:
+ EchoRequest new_msg_;
+};
+
+class SyncSendMessageTesterFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new SyncSendMessageTester(info);
+ }
+};
+
+// Test if SendMessage function family works as expected for sync/callback apis
+class SyncSendMessageVerifier : public experimental::Interceptor {
+ public:
+ SyncSendMessageVerifier(experimental::ServerRpcInfo* info) {}
+
+ void Intercept(experimental::InterceptorBatchMethods* methods) override {
+ if (methods->QueryInterceptionHookPoint(
+ experimental::InterceptionHookPoints::PRE_SEND_MESSAGE)) {
+ // Make sure that the changes made in SyncSendMessageTester persisted
+ string old_msg =
+ static_cast<const EchoRequest*>(methods->GetSendMessage())->message();
+ EXPECT_EQ(old_msg.find("World"), 0u);
+
+ // Remove the "World" part of the string that we added earlier
+ new_msg_.set_message(old_msg.erase(0, 5));
+ methods->ModifySendMessage(&new_msg_);
+
+ // LoggingInterceptor verifies that changes got reverted
+ }
+ methods->Proceed();
+ }
+
+ private:
+ EchoRequest new_msg_;
+};
+
+class SyncSendMessageVerifierFactory
+ : public experimental::ServerInterceptorFactoryInterface {
+ public:
+ virtual experimental::Interceptor* CreateServerInterceptor(
+ experimental::ServerRpcInfo* info) override {
+ return new SyncSendMessageVerifier(info);
+ }
+};
+
void MakeBidiStreamingCall(const std::shared_ptr<Channel>& channel) {
auto stub = grpc::testing::EchoTestService::NewStub(channel);
ClientContext ctx;
@@ -175,6 +240,12 @@ class ServerInterceptorsEnd2endSyncUnaryTest : public ::testing::Test {
creators;
creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageTesterFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageVerifierFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory()));
// Add 20 dummy interceptor factories and null interceptor factories
for (auto i = 0; i < 20; i++) {
@@ -215,6 +286,12 @@ class ServerInterceptorsEnd2endSyncStreamingTest : public ::testing::Test {
creators;
creators.push_back(
std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageTesterFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
+ new SyncSendMessageVerifierFactory()));
+ creators.push_back(
+ std::unique_ptr<experimental::ServerInterceptorFactoryInterface>(
new LoggingInterceptorFactory()));
for (auto i = 0; i < 20; i++) {
creators.push_back(std::unique_ptr<DummyInterceptorFactory>(
diff --git a/test/cpp/microbenchmarks/bm_call_create.cc b/test/cpp/microbenchmarks/bm_call_create.cc
index 8d12606434..125b1ce5c4 100644
--- a/test/cpp/microbenchmarks/bm_call_create.cc
+++ b/test/cpp/microbenchmarks/bm_call_create.cc
@@ -326,7 +326,7 @@ class FakeClientChannelFactory : public grpc_client_channel_factory {
static void NoRef(grpc_client_channel_factory* factory) {}
static void NoUnref(grpc_client_channel_factory* factory) {}
static grpc_subchannel* CreateSubchannel(grpc_client_channel_factory* factory,
- const grpc_subchannel_args* args) {
+ const grpc_channel_args* args) {
return nullptr;
}
static grpc_channel* CreateClientChannel(grpc_client_channel_factory* factory,
diff --git a/test/cpp/qps/client.h b/test/cpp/qps/client.h
index 73f91eed2d..ceb5cdd710 100644
--- a/test/cpp/qps/client.h
+++ b/test/cpp/qps/client.h
@@ -236,58 +236,7 @@ class Client {
return 0;
}
- protected:
- bool closed_loop_;
- gpr_atm thread_pool_done_;
- double median_latency_collection_interval_seconds_; // In seconds
-
- void StartThreads(size_t num_threads) {
- gpr_atm_rel_store(&thread_pool_done_, static_cast<gpr_atm>(false));
- threads_remaining_ = num_threads;
- for (size_t i = 0; i < num_threads; i++) {
- threads_.emplace_back(new Thread(this, i));
- }
- }
-
- void EndThreads() {
- MaybeStartRequests();
- threads_.clear();
- }
-
- virtual void DestroyMultithreading() = 0;
-
- void SetupLoadTest(const ClientConfig& config, size_t num_threads) {
- // Set up the load distribution based on the number of threads
- const auto& load = config.load_params();
-
- std::unique_ptr<RandomDistInterface> random_dist;
- switch (load.load_case()) {
- case LoadParams::kClosedLoop:
- // Closed-loop doesn't use random dist at all
- break;
- case LoadParams::kPoisson:
- random_dist.reset(
- new ExpDist(load.poisson().offered_load() / num_threads));
- break;
- default:
- GPR_ASSERT(false);
- }
-
- // Set closed_loop_ based on whether or not random_dist is set
- if (!random_dist) {
- closed_loop_ = true;
- } else {
- closed_loop_ = false;
- // set up interarrival timer according to random dist
- interarrival_timer_.init(*random_dist, num_threads);
- const auto now = gpr_now(GPR_CLOCK_MONOTONIC);
- for (size_t i = 0; i < num_threads; i++) {
- next_time_.push_back(gpr_time_add(
- now,
- gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN)));
- }
- }
- }
+ bool IsClosedLoop() { return closed_loop_; }
gpr_timespec NextIssueTime(int thread_idx) {
const gpr_timespec result = next_time_[thread_idx];
@@ -297,9 +246,9 @@ class Client {
GPR_TIMESPAN));
return result;
}
- std::function<gpr_timespec()> NextIssuer(int thread_idx) {
- return closed_loop_ ? std::function<gpr_timespec()>()
- : std::bind(&Client::NextIssueTime, this, thread_idx);
+
+ bool ThreadCompleted() {
+ return static_cast<bool>(gpr_atm_acq_load(&thread_pool_done_));
}
class Thread {
@@ -380,8 +329,62 @@ class Client {
double interval_start_time_;
};
- bool ThreadCompleted() {
- return static_cast<bool>(gpr_atm_acq_load(&thread_pool_done_));
+ protected:
+ bool closed_loop_;
+ gpr_atm thread_pool_done_;
+ double median_latency_collection_interval_seconds_; // In seconds
+
+ void StartThreads(size_t num_threads) {
+ gpr_atm_rel_store(&thread_pool_done_, static_cast<gpr_atm>(false));
+ threads_remaining_ = num_threads;
+ for (size_t i = 0; i < num_threads; i++) {
+ threads_.emplace_back(new Thread(this, i));
+ }
+ }
+
+ void EndThreads() {
+ MaybeStartRequests();
+ threads_.clear();
+ }
+
+ virtual void DestroyMultithreading() = 0;
+
+ void SetupLoadTest(const ClientConfig& config, size_t num_threads) {
+ // Set up the load distribution based on the number of threads
+ const auto& load = config.load_params();
+
+ std::unique_ptr<RandomDistInterface> random_dist;
+ switch (load.load_case()) {
+ case LoadParams::kClosedLoop:
+ // Closed-loop doesn't use random dist at all
+ break;
+ case LoadParams::kPoisson:
+ random_dist.reset(
+ new ExpDist(load.poisson().offered_load() / num_threads));
+ break;
+ default:
+ GPR_ASSERT(false);
+ }
+
+ // Set closed_loop_ based on whether or not random_dist is set
+ if (!random_dist) {
+ closed_loop_ = true;
+ } else {
+ closed_loop_ = false;
+ // set up interarrival timer according to random dist
+ interarrival_timer_.init(*random_dist, num_threads);
+ const auto now = gpr_now(GPR_CLOCK_MONOTONIC);
+ for (size_t i = 0; i < num_threads; i++) {
+ next_time_.push_back(gpr_time_add(
+ now,
+ gpr_time_from_nanos(interarrival_timer_.next(i), GPR_TIMESPAN)));
+ }
+ }
+ }
+
+ std::function<gpr_timespec()> NextIssuer(int thread_idx) {
+ return closed_loop_ ? std::function<gpr_timespec()>()
+ : std::bind(&Client::NextIssueTime, this, thread_idx);
}
virtual void ThreadFunc(size_t thread_idx, Client::Thread* t) = 0;
@@ -436,6 +439,7 @@ class ClientImpl : public Client {
config.payload_config());
}
virtual ~ClientImpl() {}
+ const RequestType* request() { return &request_; }
void WaitForChannelsToConnect() {
int connect_deadline_seconds = 10;
diff --git a/test/cpp/qps/client_callback.cc b/test/cpp/qps/client_callback.cc
index 87889e36dc..4a06325f2b 100644
--- a/test/cpp/qps/client_callback.cc
+++ b/test/cpp/qps/client_callback.cc
@@ -66,13 +66,35 @@ class CallbackClient
config, BenchmarkStubCreator) {
num_threads_ = NumThreads(config);
rpcs_done_ = 0;
- SetupLoadTest(config, num_threads_);
+
+ // Don't divide the fixed load among threads as the user threads
+ // only bootstrap the RPCs
+ SetupLoadTest(config, 1);
total_outstanding_rpcs_ =
config.client_channels() * config.outstanding_rpcs_per_channel();
}
virtual ~CallbackClient() {}
+ /**
+ * The main thread of the benchmark will be waiting on DestroyMultithreading.
+ * Increment the rpcs_done_ variable to signify that the Callback RPC
+ * after thread completion is done. When the last outstanding rpc increments
+ * the counter it should also signal the main thread's conditional variable.
+ */
+ void NotifyMainThreadOfThreadCompletion() {
+ std::lock_guard<std::mutex> l(shutdown_mu_);
+ rpcs_done_++;
+ if (rpcs_done_ == total_outstanding_rpcs_) {
+ shutdown_cv_.notify_one();
+ }
+ }
+
+ gpr_timespec NextRPCIssueTime() {
+ std::lock_guard<std::mutex> l(next_issue_time_mu_);
+ return Client::NextIssueTime(0);
+ }
+
protected:
size_t num_threads_;
size_t total_outstanding_rpcs_;
@@ -93,24 +115,9 @@ class CallbackClient
ThreadFuncImpl(t, thread_idx);
}
- virtual void ScheduleRpc(Thread* t, size_t thread_idx,
- size_t ctx_vector_idx) = 0;
-
- /**
- * The main thread of the benchmark will be waiting on DestroyMultithreading.
- * Increment the rpcs_done_ variable to signify that the Callback RPC
- * after thread completion is done. When the last outstanding rpc increments
- * the counter it should also signal the main thread's conditional variable.
- */
- void NotifyMainThreadOfThreadCompletion() {
- std::lock_guard<std::mutex> l(shutdown_mu_);
- rpcs_done_++;
- if (rpcs_done_ == total_outstanding_rpcs_) {
- shutdown_cv_.notify_one();
- }
- }
-
private:
+ std::mutex next_issue_time_mu_; // Used by next issue time
+
int NumThreads(const ClientConfig& config) {
int num_threads = config.async_client_threads();
if (num_threads <= 0) { // Use dynamic sizing
@@ -149,7 +156,7 @@ class CallbackUnaryClient final : public CallbackClient {
bool ThreadFuncImpl(Thread* t, size_t thread_idx) override {
for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_;
vector_idx += num_threads_) {
- ScheduleRpc(t, thread_idx, vector_idx);
+ ScheduleRpc(t, vector_idx);
}
return true;
}
@@ -157,26 +164,26 @@ class CallbackUnaryClient final : public CallbackClient {
void InitThreadFuncImpl(size_t thread_idx) override { return; }
private:
- void ScheduleRpc(Thread* t, size_t thread_idx, size_t vector_idx) override {
+ void ScheduleRpc(Thread* t, size_t vector_idx) {
if (!closed_loop_) {
- gpr_timespec next_issue_time = NextIssueTime(thread_idx);
+ gpr_timespec next_issue_time = NextRPCIssueTime();
// Start an alarm callback to run the internal callback after
// next_issue_time
ctx_[vector_idx]->alarm_.experimental().Set(
- next_issue_time, [this, t, thread_idx, vector_idx](bool ok) {
- IssueUnaryCallbackRpc(t, thread_idx, vector_idx);
+ next_issue_time, [this, t, vector_idx](bool ok) {
+ IssueUnaryCallbackRpc(t, vector_idx);
});
} else {
- IssueUnaryCallbackRpc(t, thread_idx, vector_idx);
+ IssueUnaryCallbackRpc(t, vector_idx);
}
}
- void IssueUnaryCallbackRpc(Thread* t, size_t thread_idx, size_t vector_idx) {
+ void IssueUnaryCallbackRpc(Thread* t, size_t vector_idx) {
GPR_TIMER_SCOPE("CallbackUnaryClient::ThreadFunc", 0);
double start = UsageTimer::Now();
ctx_[vector_idx]->stub_->experimental_async()->UnaryCall(
(&ctx_[vector_idx]->context_), &request_, &ctx_[vector_idx]->response_,
- [this, t, thread_idx, start, vector_idx](grpc::Status s) {
+ [this, t, start, vector_idx](grpc::Status s) {
// Update Histogram with data from the callback run
HistogramEntry entry;
if (s.ok()) {
@@ -193,17 +200,157 @@ class CallbackUnaryClient final : public CallbackClient {
ctx_[vector_idx].reset(
new CallbackClientRpcContext(ctx_[vector_idx]->stub_));
// Schedule a new RPC
- ScheduleRpc(t, thread_idx, vector_idx);
+ ScheduleRpc(t, vector_idx);
}
});
}
};
+class CallbackStreamingClient : public CallbackClient {
+ public:
+ CallbackStreamingClient(const ClientConfig& config)
+ : CallbackClient(config),
+ messages_per_stream_(config.messages_per_stream()) {
+ for (int ch = 0; ch < config.client_channels(); ch++) {
+ for (int i = 0; i < config.outstanding_rpcs_per_channel(); i++) {
+ ctx_.emplace_back(
+ new CallbackClientRpcContext(channels_[ch].get_stub()));
+ }
+ }
+ StartThreads(num_threads_);
+ }
+ ~CallbackStreamingClient() {}
+
+ void AddHistogramEntry(double start_, bool ok, Thread* thread_ptr) {
+ // Update Histogram with data from the callback run
+ HistogramEntry entry;
+ if (ok) {
+ entry.set_value((UsageTimer::Now() - start_) * 1e9);
+ }
+ thread_ptr->UpdateHistogram(&entry);
+ }
+
+ int messages_per_stream() { return messages_per_stream_; }
+
+ protected:
+ const int messages_per_stream_;
+};
+
+class CallbackStreamingPingPongClient : public CallbackStreamingClient {
+ public:
+ CallbackStreamingPingPongClient(const ClientConfig& config)
+ : CallbackStreamingClient(config) {}
+ ~CallbackStreamingPingPongClient() {}
+};
+
+class CallbackStreamingPingPongReactor final
+ : public grpc::experimental::ClientBidiReactor<SimpleRequest,
+ SimpleResponse> {
+ public:
+ CallbackStreamingPingPongReactor(
+ CallbackStreamingPingPongClient* client,
+ std::unique_ptr<CallbackClientRpcContext> ctx)
+ : client_(client), ctx_(std::move(ctx)), messages_issued_(0) {}
+
+ void StartNewRpc() {
+ if (client_->ThreadCompleted()) return;
+ start_ = UsageTimer::Now();
+ ctx_->stub_->experimental_async()->StreamingCall(&(ctx_->context_), this);
+ StartWrite(client_->request());
+ StartCall();
+ }
+
+ void OnWriteDone(bool ok) override {
+ if (!ok || client_->ThreadCompleted()) {
+ if (!ok) gpr_log(GPR_ERROR, "Error writing RPC");
+ StartWritesDone();
+ return;
+ }
+ StartRead(&ctx_->response_);
+ }
+
+ void OnReadDone(bool ok) override {
+ client_->AddHistogramEntry(start_, ok, thread_ptr_);
+
+ if (client_->ThreadCompleted() || !ok ||
+ (client_->messages_per_stream() != 0 &&
+ ++messages_issued_ >= client_->messages_per_stream())) {
+ if (!ok) {
+ gpr_log(GPR_ERROR, "Error reading RPC");
+ }
+ StartWritesDone();
+ return;
+ }
+ StartWrite(client_->request());
+ }
+
+ void OnDone(const Status& s) override {
+ if (client_->ThreadCompleted() || !s.ok()) {
+ client_->NotifyMainThreadOfThreadCompletion();
+ return;
+ }
+ ctx_.reset(new CallbackClientRpcContext(ctx_->stub_));
+ ScheduleRpc();
+ }
+
+ void ScheduleRpc() {
+ if (client_->ThreadCompleted()) return;
+
+ if (!client_->IsClosedLoop()) {
+ gpr_timespec next_issue_time = client_->NextRPCIssueTime();
+ // Start an alarm callback to run the internal callback after
+ // next_issue_time
+ ctx_->alarm_.experimental().Set(next_issue_time,
+ [this](bool ok) { StartNewRpc(); });
+ } else {
+ StartNewRpc();
+ }
+ }
+
+ void set_thread_ptr(Client::Thread* ptr) { thread_ptr_ = ptr; }
+
+ CallbackStreamingPingPongClient* client_;
+ std::unique_ptr<CallbackClientRpcContext> ctx_;
+ Client::Thread* thread_ptr_; // Needed to update histogram entries
+ double start_; // Track message start time
+ int messages_issued_; // Messages issued by this stream
+};
+
+class CallbackStreamingPingPongClientImpl final
+ : public CallbackStreamingPingPongClient {
+ public:
+ CallbackStreamingPingPongClientImpl(const ClientConfig& config)
+ : CallbackStreamingPingPongClient(config) {
+ for (size_t i = 0; i < total_outstanding_rpcs_; i++)
+ reactor_.emplace_back(
+ new CallbackStreamingPingPongReactor(this, std::move(ctx_[i])));
+ }
+ ~CallbackStreamingPingPongClientImpl() {}
+
+ bool ThreadFuncImpl(Client::Thread* t, size_t thread_idx) override {
+ for (size_t vector_idx = thread_idx; vector_idx < total_outstanding_rpcs_;
+ vector_idx += num_threads_) {
+ reactor_[vector_idx]->set_thread_ptr(t);
+ reactor_[vector_idx]->ScheduleRpc();
+ }
+ return true;
+ }
+
+ void InitThreadFuncImpl(size_t thread_idx) override {}
+
+ private:
+ std::vector<std::unique_ptr<CallbackStreamingPingPongReactor>> reactor_;
+};
+
+// TODO(mhaidry) : Implement Streaming from client, server and both ways
+
std::unique_ptr<Client> CreateCallbackClient(const ClientConfig& config) {
switch (config.rpc_type()) {
case UNARY:
return std::unique_ptr<Client>(new CallbackUnaryClient(config));
case STREAMING:
+ return std::unique_ptr<Client>(
+ new CallbackStreamingPingPongClientImpl(config));
case STREAMING_FROM_CLIENT:
case STREAMING_FROM_SERVER:
case STREAMING_BOTH_WAYS:
diff --git a/test/cpp/qps/gen_build_yaml.py b/test/cpp/qps/gen_build_yaml.py
index fb2caf5486..8ca0dc6a62 100755
--- a/test/cpp/qps/gen_build_yaml.py
+++ b/test/cpp/qps/gen_build_yaml.py
@@ -131,4 +131,4 @@ def generate_yaml():
}
-print(yaml.dump(generate_yaml())) \ No newline at end of file
+print(yaml.dump(generate_yaml()))