diff options
author | Craig Tiller <ctiller@google.com> | 2017-02-02 12:02:37 -0800 |
---|---|---|
committer | Craig Tiller <ctiller@google.com> | 2017-02-02 12:02:37 -0800 |
commit | 6c5d08f3ac6bb4fbdeb48723af891eac9e277156 (patch) | |
tree | 5471e400c9009e343ae881671d7a2069ac987a95 /test/cpp | |
parent | f7af2a9a0517ada19e0ce1a092ca064985f604e5 (diff) | |
parent | 6c1d43bdd2242e68b7e494cb7c6ff511c5dc76ef (diff) |
Merge branch 'bm_stream' into bm_perf
Diffstat (limited to 'test/cpp')
-rw-r--r-- | test/cpp/common/alarm_cpp_test.cc | 18 | ||||
-rw-r--r-- | test/cpp/grpclb/grpclb_test.cc | 12 | ||||
-rwxr-xr-x | test/cpp/qps/gen_build_yaml.py | 8 | ||||
-rw-r--r-- | test/cpp/qps/qps_openloop_test.cc | 2 | ||||
-rw-r--r-- | test/cpp/util/cli_call.cc | 175 | ||||
-rw-r--r-- | test/cpp/util/cli_call.h | 51 | ||||
-rw-r--r-- | test/cpp/util/grpc_cli.cc | 6 | ||||
-rw-r--r-- | test/cpp/util/grpc_tool.cc | 253 | ||||
-rw-r--r-- | test/cpp/util/grpc_tool_test.cc | 193 | ||||
-rw-r--r-- | test/cpp/util/proto_file_parser.cc | 32 | ||||
-rw-r--r-- | test/cpp/util/proto_file_parser.h | 3 |
11 files changed, 640 insertions, 113 deletions
diff --git a/test/cpp/common/alarm_cpp_test.cc b/test/cpp/common/alarm_cpp_test.cc index a05ac30b1c..41085174a4 100644 --- a/test/cpp/common/alarm_cpp_test.cc +++ b/test/cpp/common/alarm_cpp_test.cc @@ -43,12 +43,12 @@ namespace { TEST(AlarmTest, RegularExpiry) { CompletionQueue cq; void* junk = reinterpret_cast<void*>(1618033); - Alarm alarm(&cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(1), junk); + Alarm alarm(&cq, grpc_timeout_seconds_to_deadline(1), junk); void* output_tag; bool ok; const CompletionQueue::NextStatus status = cq.AsyncNext( - (void**)&output_tag, &ok, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(2)); + (void**)&output_tag, &ok, grpc_timeout_seconds_to_deadline(2)); EXPECT_EQ(status, CompletionQueue::GOT_EVENT); EXPECT_TRUE(ok); @@ -65,7 +65,7 @@ TEST(AlarmTest, RegularExpiryChrono) { void* output_tag; bool ok; const CompletionQueue::NextStatus status = cq.AsyncNext( - (void**)&output_tag, &ok, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(2)); + (void**)&output_tag, &ok, grpc_timeout_seconds_to_deadline(2)); EXPECT_EQ(status, CompletionQueue::GOT_EVENT); EXPECT_TRUE(ok); @@ -75,12 +75,12 @@ TEST(AlarmTest, RegularExpiryChrono) { TEST(AlarmTest, ZeroExpiry) { CompletionQueue cq; void* junk = reinterpret_cast<void*>(1618033); - Alarm alarm(&cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(0), junk); + Alarm alarm(&cq, grpc_timeout_seconds_to_deadline(0), junk); void* output_tag; bool ok; const CompletionQueue::NextStatus status = cq.AsyncNext( - (void**)&output_tag, &ok, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(0)); + (void**)&output_tag, &ok, grpc_timeout_seconds_to_deadline(0)); EXPECT_EQ(status, CompletionQueue::GOT_EVENT); EXPECT_TRUE(ok); @@ -90,12 +90,12 @@ TEST(AlarmTest, ZeroExpiry) { TEST(AlarmTest, NegativeExpiry) { CompletionQueue cq; void* junk = reinterpret_cast<void*>(1618033); - Alarm alarm(&cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(-1), junk); + Alarm alarm(&cq, grpc_timeout_seconds_to_deadline(-1), junk); void* output_tag; bool ok; const CompletionQueue::NextStatus status = cq.AsyncNext( - (void**)&output_tag, &ok, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(0)); + (void**)&output_tag, &ok, grpc_timeout_seconds_to_deadline(0)); EXPECT_EQ(status, CompletionQueue::GOT_EVENT); EXPECT_TRUE(ok); @@ -105,13 +105,13 @@ TEST(AlarmTest, NegativeExpiry) { TEST(AlarmTest, Cancellation) { CompletionQueue cq; void* junk = reinterpret_cast<void*>(1618033); - Alarm alarm(&cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(2), junk); + Alarm alarm(&cq, grpc_timeout_seconds_to_deadline(2), junk); alarm.Cancel(); void* output_tag; bool ok; const CompletionQueue::NextStatus status = cq.AsyncNext( - (void**)&output_tag, &ok, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(1)); + (void**)&output_tag, &ok, grpc_timeout_seconds_to_deadline(1)); EXPECT_EQ(status, CompletionQueue::GOT_EVENT); EXPECT_FALSE(ok); diff --git a/test/cpp/grpclb/grpclb_test.cc b/test/cpp/grpclb/grpclb_test.cc index a10cda7ddb..4b8a434c78 100644 --- a/test/cpp/grpclb/grpclb_test.cc +++ b/test/cpp/grpclb/grpclb_test.cc @@ -170,7 +170,7 @@ static grpc_slice build_response_payload_slice( static void drain_cq(grpc_completion_queue *cq) { grpc_event ev; do { - ev = grpc_completion_queue_next(cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(5), + ev = grpc_completion_queue_next(cq, grpc_timeout_seconds_to_deadline(5), NULL); } while (ev.type != GRPC_QUEUE_SHUTDOWN); } @@ -336,7 +336,7 @@ static void start_backend_server(server_fixture *sf) { GPR_ASSERT(GRPC_CALL_OK == error); gpr_log(GPR_INFO, "Server[%s] up", sf->servers_hostport); ev = grpc_completion_queue_next(sf->cq, - GRPC_TIMEOUT_SECONDS_TO_DEADLINE(60), NULL); + grpc_timeout_seconds_to_deadline(60), NULL); if (!ev.success) { gpr_log(GPR_INFO, "Server[%s] being torn down", sf->servers_hostport); cq_verifier_destroy(cqv); @@ -380,7 +380,7 @@ static void start_backend_server(server_fixture *sf) { error = grpc_call_start_batch(s, ops, (size_t)(op - ops), tag(102), NULL); GPR_ASSERT(GRPC_CALL_OK == error); ev = grpc_completion_queue_next( - sf->cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(3), NULL); + sf->cq, grpc_timeout_seconds_to_deadline(3), NULL); if (ev.type == GRPC_OP_COMPLETE && ev.success) { GPR_ASSERT(ev.tag = tag(102)); if (request_payload_recv == NULL) { @@ -410,7 +410,7 @@ static void start_backend_server(server_fixture *sf) { grpc_call_start_batch(s, ops, (size_t)(op - ops), tag(103), NULL); GPR_ASSERT(GRPC_CALL_OK == error); ev = grpc_completion_queue_next( - sf->cq, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(3), NULL); + sf->cq, grpc_timeout_seconds_to_deadline(3), NULL); if (ev.type == GRPC_OP_COMPLETE && ev.success) { GPR_ASSERT(ev.tag = tag(103)); } else { @@ -477,7 +477,7 @@ static void perform_request(client_fixture *cf) { grpc_slice host = grpc_slice_from_static_string("foo.test.google.fr:1234"); c = grpc_channel_create_call(cf->client, NULL, GRPC_PROPAGATE_DEFAULTS, cf->cq, grpc_slice_from_static_string("/foo"), - &host, GRPC_TIMEOUT_SECONDS_TO_DEADLINE(5), + &host, grpc_timeout_seconds_to_deadline(5), NULL); gpr_log(GPR_INFO, "Call 0x%" PRIxPTR " created", (intptr_t)c); GPR_ASSERT(c); @@ -605,7 +605,7 @@ static void teardown_server(server_fixture *sf) { gpr_log(GPR_INFO, "Server[%s] shutting down", sf->servers_hostport); grpc_server_shutdown_and_notify(sf->server, sf->cq, tag(1000)); GPR_ASSERT(grpc_completion_queue_pluck( - sf->cq, tag(1000), GRPC_TIMEOUT_SECONDS_TO_DEADLINE(5), NULL) + sf->cq, tag(1000), grpc_timeout_seconds_to_deadline(5), NULL) .type == GRPC_OP_COMPLETE); grpc_server_destroy(sf->server); gpr_thd_join(sf->tid); diff --git a/test/cpp/qps/gen_build_yaml.py b/test/cpp/qps/gen_build_yaml.py index 188d6196e5..2f035abedd 100755 --- a/test/cpp/qps/gen_build_yaml.py +++ b/test/cpp/qps/gen_build_yaml.py @@ -92,7 +92,8 @@ print yaml.dump({ 'defaults': 'boringssl', 'cpu_cost': guess_cpu(scenario_json, False), 'exclude_configs': ['tsan', 'asan'], - 'timeout_seconds': 6*60 + 'timeout_seconds': 6*60, + 'excluded_poll_engines': scenario_json.get('EXCLUDED_POLL_ENGINES', []) } for scenario_json in scenario_config.CXXLanguage().scenarios() if 'scalable' in scenario_json.get('CATEGORIES', []) @@ -109,8 +110,9 @@ print yaml.dump({ 'defaults': 'boringssl', 'cpu_cost': guess_cpu(scenario_json, True), 'exclude_configs': sorted(c for c in configs_from_yaml if c not in ('tsan', 'asan')), - 'timeout_seconds': 6*60 - } + 'timeout_seconds': 6*60, + 'excluded_poll_engines': scenario_json.get('EXCLUDED_POLL_ENGINES', []) + } for scenario_json in scenario_config.CXXLanguage().scenarios() if 'scalable' in scenario_json.get('CATEGORIES', []) ] diff --git a/test/cpp/qps/qps_openloop_test.cc b/test/cpp/qps/qps_openloop_test.cc index 8dc50ac6d8..70e2709ac0 100644 --- a/test/cpp/qps/qps_openloop_test.cc +++ b/test/cpp/qps/qps_openloop_test.cc @@ -56,7 +56,7 @@ static void RunQPS() { client_config.set_async_client_threads(8); client_config.set_rpc_type(STREAMING); client_config.mutable_load_params()->mutable_poisson()->set_offered_load( - 1000.0 / GRPC_TEST_SLOWDOWN_FACTOR); + 1000.0 / grpc_test_slowdown_factor()); ServerConfig server_config; server_config.set_server_type(ASYNC_SERVER); diff --git a/test/cpp/util/cli_call.cc b/test/cpp/util/cli_call.cc index a02a8b2ee2..4d045da098 100644 --- a/test/cpp/util/cli_call.cc +++ b/test/cpp/util/cli_call.cc @@ -37,8 +37,6 @@ #include <grpc++/channel.h> #include <grpc++/client_context.h> -#include <grpc++/completion_queue.h> -#include <grpc++/generic/generic_stub.h> #include <grpc++/support/byte_buffer.h> #include <grpc/grpc.h> #include <grpc/slice.h> @@ -56,55 +54,172 @@ Status CliCall::Call(std::shared_ptr<grpc::Channel> channel, const OutgoingMetadataContainer& metadata, IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_trailing_metadata) { - std::unique_ptr<grpc::GenericStub> stub(new grpc::GenericStub(channel)); - grpc::ClientContext ctx; + CliCall call(channel, method, metadata); + call.Write(request); + call.WritesDone(); + if (!call.Read(response, server_initial_metadata)) { + fprintf(stderr, "Failed to read response.\n"); + } + return call.Finish(server_trailing_metadata); +} + +CliCall::CliCall(std::shared_ptr<grpc::Channel> channel, + const grpc::string& method, + const OutgoingMetadataContainer& metadata) + : stub_(new grpc::GenericStub(channel)) { + gpr_mu_init(&write_mu_); + gpr_cv_init(&write_cv_); if (!metadata.empty()) { for (OutgoingMetadataContainer::const_iterator iter = metadata.begin(); iter != metadata.end(); ++iter) { - ctx.AddMetadata(iter->first, iter->second); + ctx_.AddMetadata(iter->first, iter->second); } } - grpc::CompletionQueue cq; - std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call( - stub->Call(&ctx, method, &cq, tag(1))); + call_ = stub_->Call(&ctx_, method, &cq_, tag(1)); void* got_tag; bool ok; - cq.Next(&got_tag, &ok); + cq_.Next(&got_tag, &ok); GPR_ASSERT(ok); +} + +CliCall::~CliCall() { + gpr_cv_destroy(&write_cv_); + gpr_mu_destroy(&write_mu_); +} + +void CliCall::Write(const grpc::string& request) { + void* got_tag; + bool ok; grpc_slice s = grpc_slice_from_copied_string(request.c_str()); grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); grpc::ByteBuffer send_buffer(&req_slice, 1); - call->Write(send_buffer, tag(2)); - cq.Next(&got_tag, &ok); - GPR_ASSERT(ok); - call->WritesDone(tag(3)); - cq.Next(&got_tag, &ok); + call_->Write(send_buffer, tag(2)); + cq_.Next(&got_tag, &ok); GPR_ASSERT(ok); +} + +bool CliCall::Read(grpc::string* response, + IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + grpc::ByteBuffer recv_buffer; - call->Read(&recv_buffer, tag(4)); - cq.Next(&got_tag, &ok); - if (!ok) { - std::cout << "Failed to read response." << std::endl; + call_->Read(&recv_buffer, tag(3)); + + if (!cq_.Next(&got_tag, &ok) || !ok) { + return false; } - grpc::Status status; - call->Finish(&status, tag(5)); - cq.Next(&got_tag, &ok); + std::vector<grpc::Slice> slices; + recv_buffer.Dump(&slices); + + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast<const char*>(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +void CliCall::WritesDone() { + void* got_tag; + bool ok; + + call_->WritesDone(tag(4)); + cq_.Next(&got_tag, &ok); GPR_ASSERT(ok); +} - if (status.ok()) { - std::vector<grpc::Slice> slices; - (void)recv_buffer.Dump(&slices); +void CliCall::WriteAndWait(const grpc::string& request) { + grpc_slice s = grpc_slice_from_copied_string(request.c_str()); + grpc::Slice req_slice(s, grpc::Slice::STEAL_REF); + grpc::ByteBuffer send_buffer(&req_slice, 1); + + gpr_mu_lock(&write_mu_); + call_->Write(send_buffer, tag(2)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&write_mu_); +} + +void CliCall::WritesDoneAndWait() { + gpr_mu_lock(&write_mu_); + call_->WritesDone(tag(4)); + write_done_ = false; + while (!write_done_) { + gpr_cv_wait(&write_cv_, &write_mu_, gpr_inf_future(GPR_CLOCK_REALTIME)); + } + gpr_mu_unlock(&write_mu_); +} - response->clear(); - for (size_t i = 0; i < slices.size(); i++) { - response->append(reinterpret_cast<const char*>(slices[i].begin()), - slices[i].size()); +bool CliCall::ReadAndMaybeNotifyWrite( + grpc::string* response, + IncomingMetadataContainer* server_initial_metadata) { + void* got_tag; + bool ok; + grpc::ByteBuffer recv_buffer; + + call_->Read(&recv_buffer, tag(3)); + bool cq_result = cq_.Next(&got_tag, &ok); + + while (got_tag != tag(3)) { + gpr_mu_lock(&write_mu_); + write_done_ = true; + gpr_cv_signal(&write_cv_); + gpr_mu_unlock(&write_mu_); + + cq_result = cq_.Next(&got_tag, &ok); + if (got_tag == tag(2)) { + GPR_ASSERT(ok); } } - *server_initial_metadata = ctx.GetServerInitialMetadata(); - *server_trailing_metadata = ctx.GetServerTrailingMetadata(); + if (!cq_result || !ok) { + // If the RPC is ended on the server side, we should still wait for the + // pending write on the client side to be done. + if (!ok) { + gpr_mu_lock(&write_mu_); + if (!write_done_) { + cq_.Next(&got_tag, &ok); + GPR_ASSERT(got_tag != tag(2)); + write_done_ = true; + gpr_cv_signal(&write_cv_); + } + gpr_mu_unlock(&write_mu_); + } + return false; + } + + std::vector<grpc::Slice> slices; + recv_buffer.Dump(&slices); + response->clear(); + for (size_t i = 0; i < slices.size(); i++) { + response->append(reinterpret_cast<const char*>(slices[i].begin()), + slices[i].size()); + } + if (server_initial_metadata) { + *server_initial_metadata = ctx_.GetServerInitialMetadata(); + } + return true; +} + +Status CliCall::Finish(IncomingMetadataContainer* server_trailing_metadata) { + void* got_tag; + bool ok; + grpc::Status status; + + call_->Finish(&status, tag(5)); + cq_.Next(&got_tag, &ok); + GPR_ASSERT(ok); + if (server_trailing_metadata) { + *server_trailing_metadata = ctx_.GetServerTrailingMetadata(); + } + return status; } diff --git a/test/cpp/util/cli_call.h b/test/cpp/util/cli_call.h index 65da86bd4e..91f0dbc9ed 100644 --- a/test/cpp/util/cli_call.h +++ b/test/cpp/util/cli_call.h @@ -37,23 +37,74 @@ #include <map> #include <grpc++/channel.h> +#include <grpc++/completion_queue.h> +#include <grpc++/generic/generic_stub.h> #include <grpc++/support/status.h> #include <grpc++/support/string_ref.h> namespace grpc { + +class ClientContext; + namespace testing { +// CliCall handles the sending and receiving of generic messages given the name +// of the remote method. This class is only used by GrpcTool. Its thread-safe +// and thread-unsafe methods should not be used together. class CliCall final { public: typedef std::multimap<grpc::string, grpc::string> OutgoingMetadataContainer; typedef std::multimap<grpc::string_ref, grpc::string_ref> IncomingMetadataContainer; + + CliCall(std::shared_ptr<grpc::Channel> channel, const grpc::string& method, + const OutgoingMetadataContainer& metadata); + ~CliCall(); + + // Perform an unary generic RPC. static Status Call(std::shared_ptr<grpc::Channel> channel, const grpc::string& method, const grpc::string& request, grpc::string* response, const OutgoingMetadataContainer& metadata, IncomingMetadataContainer* server_initial_metadata, IncomingMetadataContainer* server_trailing_metadata); + + // Send a generic request message in a synchronous manner. NOT thread-safe. + void Write(const grpc::string& request); + + // Send a generic request message in a synchronous manner. NOT thread-safe. + void WritesDone(); + + // Receive a generic response message in a synchronous manner.NOT thread-safe. + bool Read(grpc::string* response, + IncomingMetadataContainer* server_initial_metadata); + + // Thread-safe write. Must be used with ReadAndMaybeNotifyWrite. Send out a + // generic request message and wait for ReadAndMaybeNotifyWrite to finish it. + void WriteAndWait(const grpc::string& request); + + // Thread-safe WritesDone. Must be used with ReadAndMaybeNotifyWrite. Send out + // WritesDone for gereneric request messages and wait for + // ReadAndMaybeNotifyWrite to finish it. + void WritesDoneAndWait(); + + // Thread-safe Read. Blockingly receive a generic response message. Notify + // writes if they are finished when this read is waiting for a resposne. + bool ReadAndMaybeNotifyWrite( + grpc::string* response, + IncomingMetadataContainer* server_initial_metadata); + + // Finish the RPC. + Status Finish(IncomingMetadataContainer* server_trailing_metadata); + + private: + std::unique_ptr<grpc::GenericStub> stub_; + grpc::ClientContext ctx_; + std::unique_ptr<grpc::GenericClientAsyncReaderWriter> call_; + grpc::CompletionQueue cq_; + gpr_mu write_mu_; + gpr_cv write_cv_; // Protected by write_mu_; + bool write_done_; // Portected by write_mu_; }; } // namespace testing diff --git a/test/cpp/util/grpc_cli.cc b/test/cpp/util/grpc_cli.cc index fe248601ee..a78bed4b90 100644 --- a/test/cpp/util/grpc_cli.cc +++ b/test/cpp/util/grpc_cli.cc @@ -83,10 +83,10 @@ DEFINE_string(outfile, "", "Output file (default is stdout)"); static bool SimplePrint(const grpc::string& outfile, const grpc::string& output) { if (outfile.empty()) { - std::cout << output; + std::cout << output << std::endl; } else { - std::ofstream output_file(outfile, std::ios::trunc | std::ios::binary); - output_file << output; + std::ofstream output_file(outfile, std::ios::app | std::ios::binary); + output_file << output << std::endl; output_file.close(); } return true; diff --git a/test/cpp/util/grpc_tool.cc b/test/cpp/util/grpc_tool.cc index b9900ca1b7..39acd8eb4b 100644 --- a/test/cpp/util/grpc_tool.cc +++ b/test/cpp/util/grpc_tool.cc @@ -39,6 +39,7 @@ #include <memory> #include <sstream> #include <string> +#include <thread> #include <gflags/gflags.h> #include <grpc++/channel.h> @@ -159,6 +160,36 @@ void PrintMetadata(const T& m, const grpc::string& message) { } } +void ReadResponse(CliCall* call, const grpc::string& method_name, + GrpcToolOutputCallback callback, ProtoFileParser* parser, + gpr_mu* parser_mu, bool print_mode) { + grpc::string serialized_response_proto; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata; + + for (bool receive_initial_metadata = true; call->ReadAndMaybeNotifyWrite( + &serialized_response_proto, + receive_initial_metadata ? &server_initial_metadata : nullptr); + receive_initial_metadata = false) { + fprintf(stderr, "got response.\n"); + if (!FLAGS_binary_output) { + gpr_mu_lock(parser_mu); + serialized_response_proto = parser->GetTextFormatFromMethod( + method_name, serialized_response_proto, false /* is_request */); + if (parser->HasError() && print_mode) { + fprintf(stderr, "Failed to parse response.\n"); + } + gpr_mu_unlock(parser_mu); + } + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto) && print_mode) { + fprintf(stderr, "Failed to output response.\n"); + } + } +} + struct Command { const char* command; std::function<bool(GrpcTool*, int, const char**, const CliCredentials&, @@ -416,85 +447,191 @@ bool GrpcTool::CallMethod(int argc, const char** argv, grpc::string server_address(argv[0]); grpc::string method_name(argv[1]); grpc::string formatted_method_name; - std::unique_ptr<grpc::testing::ProtoFileParser> parser; + std::unique_ptr<ProtoFileParser> parser; grpc::string serialized_request_proto; + bool print_mode = false; - if (argc == 3) { - request_text = argv[2]; - if (!FLAGS_infile.empty()) { - fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); - } + std::shared_ptr<grpc::Channel> channel = + FLAGS_remotedb + ? grpc::CreateChannel(server_address, cred.GetCredentials()) + : nullptr; + + parser.reset(new grpc::testing::ProtoFileParser(channel, FLAGS_proto_path, + FLAGS_protofiles)); + + if (FLAGS_binary_input) { + formatted_method_name = method_name; } else { - std::stringstream input_stream; + formatted_method_name = parser->GetFormattedMethodName(method_name); + } + + if (parser->HasError()) { + return false; + } + + if (parser->IsStreaming(method_name, true /* is_request */)) { + std::istream* input_stream; + std::ifstream input_file; + + if (argc == 3) { + request_text = argv[2]; + } + + std::multimap<grpc::string, grpc::string> client_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata); + if (FLAGS_infile.empty()) { if (isatty(STDIN_FILENO)) { - fprintf(stderr, "reading request message from stdin...\n"); + print_mode = true; + fprintf(stderr, "reading streaming request message from stdin...\n"); } - input_stream << std::cin.rdbuf(); + input_stream = &std::cin; } else { - std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); - input_stream << input_file.rdbuf(); + input_file.open(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream = &input_file; + } + + gpr_mu parser_mu; + gpr_mu_init(&parser_mu); + std::thread read_thread(ReadResponse, &call, method_name, callback, + parser.get(), &parser_mu, print_mode); + + std::stringstream request_ss; + grpc::string line; + while (!request_text.empty() || + (!input_stream->eof() && getline(*input_stream, line))) { + if (!request_text.empty()) { + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + request_text.clear(); + } else { + gpr_mu_lock(&parser_mu); + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */); + request_text.clear(); + if (parser->HasError()) { + if (print_mode) { + fprintf(stderr, "Failed to parse request.\n"); + } + gpr_mu_unlock(&parser_mu); + continue; + } + gpr_mu_unlock(&parser_mu); + } + + call.WriteAndWait(serialized_request_proto); + if (print_mode) { + fprintf(stderr, "Request sent.\n"); + } + } else { + if (line.length() == 0) { + request_text = request_ss.str(); + request_ss.str(grpc::string()); + request_ss.clear(); + } else { + request_ss << line << ' '; + } + } + } + if (input_file.is_open()) { input_file.close(); } - request_text = input_stream.str(); - } - std::shared_ptr<grpc::Channel> channel = - grpc::CreateChannel(server_address, cred.GetCredentials()); - if (!FLAGS_binary_input || !FLAGS_binary_output) { - parser.reset( - new grpc::testing::ProtoFileParser(FLAGS_remotedb ? channel : nullptr, - FLAGS_proto_path, FLAGS_protofiles)); - if (parser->HasError()) { + call.WritesDoneAndWait(); + read_thread.join(); + + std::multimap<grpc::string_ref, grpc::string_ref> server_trailing_metadata; + Status status = call.Finish(&server_trailing_metadata); + PrintMetadata(server_trailing_metadata, + "Received trailing metadata from server:"); + + if (status.ok()) { + fprintf(stderr, "Stream RPC succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); return false; } - } - if (FLAGS_binary_input) { - serialized_request_proto = request_text; - formatted_method_name = method_name; - } else { - formatted_method_name = parser->GetFormattedMethodName(method_name); - serialized_request_proto = parser->GetSerializedProtoFromMethod( - method_name, request_text, true /* is_request */); - if (parser->HasError()) { - return false; + } else { // parser->IsStreaming(method_name, true /* is_request */) + if (argc == 3) { + request_text = argv[2]; + if (!FLAGS_infile.empty()) { + fprintf(stderr, "warning: request given in argv, ignoring --infile\n"); + } + } else { + std::stringstream input_stream; + if (FLAGS_infile.empty()) { + if (isatty(STDIN_FILENO)) { + fprintf(stderr, "reading request message from stdin...\n"); + } + input_stream << std::cin.rdbuf(); + } else { + std::ifstream input_file(FLAGS_infile, std::ios::in | std::ios::binary); + input_stream << input_file.rdbuf(); + input_file.close(); + } + request_text = input_stream.str(); } - } - fprintf(stderr, "connecting to %s\n", server_address.c_str()); - grpc::string serialized_response_proto; - std::multimap<grpc::string, grpc::string> client_metadata; - std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, - server_trailing_metadata; - ParseMetadataFlag(&client_metadata); - PrintMetadata(client_metadata, "Sending client initial metadata:"); - grpc::Status status = grpc::testing::CliCall::Call( - channel, formatted_method_name, serialized_request_proto, - &serialized_response_proto, client_metadata, &server_initial_metadata, - &server_trailing_metadata); - PrintMetadata(server_initial_metadata, - "Received initial metadata from server:"); - PrintMetadata(server_trailing_metadata, - "Received trailing metadata from server:"); - if (status.ok()) { - fprintf(stderr, "Rpc succeeded with OK status\n"); - if (FLAGS_binary_output) { - output_ss << serialized_response_proto; + if (FLAGS_binary_input) { + serialized_request_proto = request_text; + // formatted_method_name = method_name; } else { - grpc::string response_text = parser->GetTextFormatFromMethod( - method_name, serialized_response_proto, false /* is_request */); + // formatted_method_name = parser->GetFormattedMethodName(method_name); + serialized_request_proto = parser->GetSerializedProtoFromMethod( + method_name, request_text, true /* is_request */); if (parser->HasError()) { return false; } - output_ss << "Response: \n " << response_text << std::endl; } - } else { - fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", - status.error_code(), status.error_message().c_str()); + fprintf(stderr, "connecting to %s\n", server_address.c_str()); + + grpc::string serialized_response_proto; + std::multimap<grpc::string, grpc::string> client_metadata; + std::multimap<grpc::string_ref, grpc::string_ref> server_initial_metadata, + server_trailing_metadata; + ParseMetadataFlag(&client_metadata); + PrintMetadata(client_metadata, "Sending client initial metadata:"); + + CliCall call(channel, formatted_method_name, client_metadata); + call.Write(serialized_request_proto); + call.WritesDone(); + + for (bool receive_initial_metadata = true; call.Read( + &serialized_response_proto, + receive_initial_metadata ? &server_initial_metadata : nullptr); + receive_initial_metadata = false) { + if (!FLAGS_binary_output) { + serialized_response_proto = parser->GetTextFormatFromMethod( + method_name, serialized_response_proto, false /* is_request */); + if (parser->HasError()) { + return false; + } + } + if (receive_initial_metadata) { + PrintMetadata(server_initial_metadata, + "Received initial metadata from server:"); + } + if (!callback(serialized_response_proto)) { + return false; + } + } + Status status = call.Finish(&server_trailing_metadata); + if (status.ok()) { + fprintf(stderr, "Rpc succeeded with OK status\n"); + return true; + } else { + fprintf(stderr, "Rpc failed with status code %d, error message: %s\n", + status.error_code(), status.error_message().c_str()); + return false; + } } - - return callback(output_ss.str()); + GPR_UNREACHABLE_CODE(return false); } bool GrpcTool::ParseMessage(int argc, const char** argv, diff --git a/test/cpp/util/grpc_tool_test.cc b/test/cpp/util/grpc_tool_test.cc index 33ce611a60..26e2b1f502 100644 --- a/test/cpp/util/grpc_tool_test.cc +++ b/test/cpp/util/grpc_tool_test.cc @@ -102,6 +102,8 @@ DECLARE_bool(l); namespace { +const int kNumResponseStreamsMsgs = 3; + class TestCliCredentials final : public grpc::testing::CliCredentials { public: std::shared_ptr<grpc::ChannelCredentials> GetCredentials() const override { @@ -137,6 +139,71 @@ class TestServiceImpl : public ::grpc::testing::EchoTestService::Service { response->set_message(request->message()); return Status::OK; } + + Status RequestStream(ServerContext* context, + ServerReader<EchoRequest>* reader, + EchoResponse* response) override { + EchoRequest request; + response->set_message(""); + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + while (reader->Read(&request)) { + response->mutable_message()->append(request.message()); + } + + return Status::OK; + } + + Status ResponseStream(ServerContext* context, const EchoRequest* request, + ServerWriter<EchoResponse>* writer) override { + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + EchoResponse response; + for (int i = 0; i < kNumResponseStreamsMsgs; i++) { + response.set_message(request->message() + grpc::to_string(i)); + writer->Write(response); + } + + return Status::OK; + } + + Status BidiStream( + ServerContext* context, + ServerReaderWriter<EchoResponse, EchoRequest>* stream) override { + EchoRequest request; + EchoResponse response; + if (!context->client_metadata().empty()) { + for (std::multimap<grpc::string_ref, grpc::string_ref>::const_iterator + iter = context->client_metadata().begin(); + iter != context->client_metadata().end(); ++iter) { + context->AddInitialMetadata(ToString(iter->first), + ToString(iter->second)); + } + } + context->AddTrailingMetadata("trailing_key", "trailing_value"); + + while (stream->Read(&request)) { + response.set_message(request.message()); + stream->Write(response); + } + + return Status::OK; + } }; } // namespace @@ -347,6 +414,132 @@ TEST_F(GrpcToolTest, CallCommand) { ShutdownServer(); } +TEST_F(GrpcToolTest, CallCommandRequestStream) { + // Test input: grpc_cli call localhost:<port> RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const grpc::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello1Hello2\"" + EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(), + "message: \"Hello0Hello1Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandRequestStreamWithBadRequest) { + // Test input: grpc_cli call localhost:<port> RequestStream "message: + // 'Hello0'" + std::stringstream output_stream; + + const grpc::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "RequestStream", "message: 'Hello0'"}; + + // Mock std::cin input "bad_field: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("bad_field: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0Hello2\"" + EXPECT_TRUE(NULL != + strstr(output_stream.str().c_str(), "message: \"Hello0Hello2\"")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandResponseStream) { + // Test input: grpc_cli call localhost:<port> ResponseStream "message: + // 'Hello'" + std::stringstream output_stream; + + const grpc::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "ResponseStream", "message: 'Hello'"}; + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello{n}\"" + for (int i = 0; i < kNumResponseStreamsMsgs; i++) { + grpc::string expected_response_text = + "message: \"Hello" + grpc::to_string(i) + "\"\n"; + EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(), + expected_response_text.c_str())); + } + + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStream) { + // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'" + std::stringstream output_stream; + + const grpc::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 'Hello1'\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: " + "\"Hello1\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + ShutdownServer(); +} + +TEST_F(GrpcToolTest, CallCommandBidiStreamWithBadRequest) { + // Test input: grpc_cli call localhost:<port> BidiStream "message: 'Hello0'" + std::stringstream output_stream; + + const grpc::string server_address = SetUpServer(); + const char* argv[] = {"grpc_cli", "call", server_address.c_str(), + "BidiStream", "message: 'Hello0'"}; + + // Mock std::cin input "message: 'Hello1'\n\n message: 'Hello2'\n\n" + std::streambuf* orig = std::cin.rdbuf(); + std::istringstream ss("message: 1.0\n\n message: 'Hello2'\n\n"); + std::cin.rdbuf(ss.rdbuf()); + + EXPECT_TRUE(0 == GrpcToolMainLib(ArraySize(argv), argv, TestCliCredentials(), + std::bind(PrintStream, &output_stream, + std::placeholders::_1))); + + // Expected output: "message: \"Hello0\"\nmessage: \"Hello1\"\nmessage: + // \"Hello2\"\n\n" + EXPECT_TRUE(NULL != strstr(output_stream.str().c_str(), + "message: \"Hello0\"\nmessage: \"Hello2\"\n")); + std::cin.rdbuf(orig); + + ShutdownServer(); +} + TEST_F(GrpcToolTest, ParseCommand) { // Test input "grpc_cli parse localhost:<port> grpc.testing.EchoResponse // ECHO_RESPONSE_MESSAGE" diff --git a/test/cpp/util/proto_file_parser.cc b/test/cpp/util/proto_file_parser.cc index bc8a6083f4..d501c3697b 100644 --- a/test/cpp/util/proto_file_parser.cc +++ b/test/cpp/util/proto_file_parser.cc @@ -81,8 +81,9 @@ class ErrorPrinter : public protobuf::compiler::MultiFileErrorCollector { ProtoFileParser::ProtoFileParser(std::shared_ptr<grpc::Channel> channel, const grpc::string& proto_path, const grpc::string& protofiles) - : has_error_(false) { - std::vector<grpc::string> service_list; + : has_error_(false), + dynamic_factory_(new protobuf::DynamicMessageFactory()) { + std::vector<std::string> service_list; if (channel) { reflection_db_.reset(new grpc::ProtoReflectionDescriptorDatabase(channel)); reflection_db_->GetServices(&service_list); @@ -127,7 +128,6 @@ ProtoFileParser::ProtoFileParser(std::shared_ptr<grpc::Channel> channel, } desc_pool_.reset(new protobuf::DescriptorPool(desc_db_.get())); - dynamic_factory_.reset(new protobuf::DynamicMessageFactory(desc_pool_.get())); for (auto it = service_list.begin(); it != service_list.end(); it++) { if (known_services.find(*it) == known_services.end()) { @@ -144,6 +144,11 @@ ProtoFileParser::~ProtoFileParser() {} grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { has_error_ = false; + + if (known_methods_.find(method) != known_methods_.end()) { + return known_methods_[method]; + } + const protobuf::MethodDescriptor* method_descriptor = nullptr; for (auto it = service_desc_list_.begin(); it != service_desc_list_.end(); it++) { @@ -169,6 +174,8 @@ grpc::string ProtoFileParser::GetFullMethodName(const grpc::string& method) { return ""; } + known_methods_[method] = method_descriptor->full_name(); + return method_descriptor->full_name(); } @@ -205,6 +212,25 @@ grpc::string ProtoFileParser::GetMessageTypeFromMethod( : method_desc->output_type()->full_name(); } +bool ProtoFileParser::IsStreaming(const grpc::string& method, bool is_request) { + has_error_ = false; + + grpc::string full_method_name = GetFullMethodName(method); + if (has_error_) { + return false; + } + + const protobuf::MethodDescriptor* method_desc = + desc_pool_->FindMethodByName(full_method_name); + if (!method_desc) { + LogError("Method not found"); + return false; + } + + return is_request ? method_desc->client_streaming() + : method_desc->server_streaming(); +} + grpc::string ProtoFileParser::GetSerializedProtoFromMethod( const grpc::string& method, const grpc::string& text_format_proto, bool is_request) { diff --git a/test/cpp/util/proto_file_parser.h b/test/cpp/util/proto_file_parser.h index c1070a37b5..23d311ef8f 100644 --- a/test/cpp/util/proto_file_parser.h +++ b/test/cpp/util/proto_file_parser.h @@ -84,6 +84,8 @@ class ProtoFileParser { const grpc::string& message_type_name, const grpc::string& serialized_proto); + bool IsStreaming(const grpc::string& method, bool is_request); + bool HasError() const { return has_error_; } void LogError(const grpc::string& error_msg); @@ -104,6 +106,7 @@ class ProtoFileParser { std::unique_ptr<protobuf::DynamicMessageFactory> dynamic_factory_; std::unique_ptr<grpc::protobuf::Message> request_prototype_; std::unique_ptr<grpc::protobuf::Message> response_prototype_; + std::unordered_map<grpc::string, grpc::string> known_methods_; std::vector<const protobuf::ServiceDescriptor*> service_desc_list_; }; |