diff options
-rw-r--r-- | include/grpc++/impl/call.h | 4 | ||||
-rw-r--r-- | include/grpc++/server_context.h | 2 | ||||
-rw-r--r-- | include/grpc++/stream.h | 55 | ||||
-rw-r--r-- | src/compiler/cpp_generator.cc | 6 | ||||
-rw-r--r-- | src/cpp/client/client_unary_call.cc | 9 | ||||
-rw-r--r-- | src/cpp/common/call.cc | 11 | ||||
-rw-r--r-- | test/cpp/end2end/async_end2end_test.cc | 172 |
7 files changed, 190 insertions, 69 deletions
diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h index 7aa22ee7c2..af1c710098 100644 --- a/include/grpc++/impl/call.h +++ b/include/grpc++/impl/call.h @@ -68,7 +68,7 @@ class CallOpBuffer : public CompletionQueueTag { void AddRecvInitialMetadata( std::multimap<grpc::string, grpc::string> *metadata); void AddSendMessage(const google::protobuf::Message &message); - void AddRecvMessage(google::protobuf::Message *message, bool* got_message); + void AddRecvMessage(google::protobuf::Message *message); void AddClientSendClose(); void AddClientRecvStatus(std::multimap<grpc::string, grpc::string> *metadata, Status *status); @@ -84,6 +84,7 @@ class CallOpBuffer : public CompletionQueueTag { // Called by completion queue just prior to returning from Next() or Pluck() void FinalizeResult(void **tag, bool *status) override; + bool got_message = false; private: void *return_tag_ = nullptr; // Send initial metadata @@ -98,7 +99,6 @@ class CallOpBuffer : public CompletionQueueTag { grpc_byte_buffer* send_message_buf_ = nullptr; // Recv message google::protobuf::Message* recv_message_ = nullptr; - bool* got_message_ = nullptr; grpc_byte_buffer* recv_message_buf_ = nullptr; // Client send close bool client_send_close_ = false; diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h index 64091a4505..423ebf2337 100644 --- a/include/grpc++/server_context.h +++ b/include/grpc++/server_context.h @@ -45,7 +45,7 @@ struct grpc_call; namespace grpc { -template <class R> +template <class W, class R> class ServerAsyncReader; template <class W> class ServerAsyncWriter; diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h index 359a272e7b..6ee550bd64 100644 --- a/include/grpc++/stream.h +++ b/include/grpc++/stream.h @@ -119,10 +119,9 @@ class ClientReader final : public ClientStreamingInterface, buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_); context_->initial_metadata_received_ = true; } - bool got_message; - buf.AddRecvMessage(msg, &got_message); + buf.AddRecvMessage(msg); call_.PerformOps(&buf); - return cq_.Pluck(&buf) && got_message; + return cq_.Pluck(&buf) && buf.got_message; } virtual Status Finish() override { @@ -174,11 +173,10 @@ class ClientWriter final : public ClientStreamingInterface, virtual Status Finish() override { CallOpBuffer buf; Status status; - bool got_message; - buf.AddRecvMessage(response_, &got_message); + buf.AddRecvMessage(response_); buf.AddClientRecvStatus(&context_->trailing_metadata_, &status); call_.PerformOps(&buf); - GPR_ASSERT(cq_.Pluck(&buf) && got_message); + GPR_ASSERT(cq_.Pluck(&buf) && buf.got_message); return status; } @@ -225,10 +223,9 @@ class ClientReaderWriter final : public ClientStreamingInterface, buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_); context_->initial_metadata_received_ = true; } - bool got_message; - buf.AddRecvMessage(msg, &got_message); + buf.AddRecvMessage(msg); call_.PerformOps(&buf); - return cq_.Pluck(&buf) && got_message; + return cq_.Pluck(&buf) && buf.got_message; } virtual bool Write(const W& msg) override { @@ -277,10 +274,9 @@ class ServerReader final : public ReaderInterface<R> { virtual bool Read(R* msg) override { CallOpBuffer buf; - bool got_message; - buf.AddRecvMessage(msg, &got_message); + buf.AddRecvMessage(msg); call_->PerformOps(&buf); - return call_->cq()->Pluck(&buf) && got_message; + return call_->cq()->Pluck(&buf) && buf.got_message; } private: @@ -338,10 +334,9 @@ class ServerReaderWriter final : public WriterInterface<W>, virtual bool Read(R* msg) override { CallOpBuffer buf; - bool got_message; - buf.AddRecvMessage(msg, &got_message); + buf.AddRecvMessage(msg); call_->PerformOps(&buf); - return call_->cq()->Pluck(&buf) && got_message; + return call_->cq()->Pluck(&buf) && buf.got_message; } virtual bool Write(const W& msg) override { @@ -420,8 +415,7 @@ class ClientAsyncReader final : public ClientAsyncStreamingInterface, read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_); context_->initial_metadata_received_ = true; } - bool ignore; - read_buf_.AddRecvMessage(msg, &ignore); + read_buf_.AddRecvMessage(msg); call_.PerformOps(&read_buf_); } @@ -485,8 +479,7 @@ class ClientAsyncWriter final : public ClientAsyncStreamingInterface, finish_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_); context_->initial_metadata_received_ = true; } - bool ignore; - finish_buf_.AddRecvMessage(response_, &ignore); + finish_buf_.AddRecvMessage(response_); finish_buf_.AddClientRecvStatus(&context_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } @@ -494,7 +487,6 @@ class ClientAsyncWriter final : public ClientAsyncStreamingInterface, private: ClientContext* context_ = nullptr; google::protobuf::Message *const response_; - bool got_message_; Call call_; CallOpBuffer init_buf_; CallOpBuffer meta_buf_; @@ -532,8 +524,7 @@ class ClientAsyncReaderWriter final : public ClientAsyncStreamingInterface, read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_); context_->initial_metadata_received_ = true; } - bool ignore; - read_buf_.AddRecvMessage(msg, &ignore); + read_buf_.AddRecvMessage(msg); call_.PerformOps(&read_buf_); } @@ -624,7 +615,7 @@ class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface { CallOpBuffer finish_buf_; }; -template <class R> +template <class W, class R> class ServerAsyncReader : public ServerAsyncStreamingInterface, public AsyncReaderInterface<R> { public: @@ -646,18 +637,34 @@ class ServerAsyncReader : public ServerAsyncStreamingInterface, call_.PerformOps(&read_buf_); } - void Finish(const Status& status, void* tag) { + void Finish(const W& msg, const Status& status, void* tag) { finish_buf_.Reset(tag); if (!ctx_->sent_initial_metadata_) { finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); ctx_->sent_initial_metadata_ = true; } + // The response is dropped if the status is not OK. + if (status.IsOk()) { + finish_buf_.AddSendMessage(msg); + } bool cancelled = false; finish_buf_.AddServerRecvClose(&cancelled); finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); call_.PerformOps(&finish_buf_); } + void FinishWithError(const Status& status, void* tag) { + GPR_ASSERT(!status.IsOk()); + finish_buf_.Reset(tag); + if (!ctx_->sent_initial_metadata_) { + finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_); + ctx_->sent_initial_metadata_ = true; + } + bool cancelled = false; + finish_buf_.AddServerRecvClose(&cancelled); + finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status); + call_.PerformOps(&finish_buf_); + } private: void BindCall(Call *call) override { call_ = *call; } diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc index a34aa4e568..2a9895e43c 100644 --- a/src/compiler/cpp_generator.cc +++ b/src/compiler/cpp_generator.cc @@ -133,7 +133,7 @@ std::string GetHeaderIncludes(const google::protobuf::FileDescriptor *file) { temp.append("template <class OutMessage> class ClientWriter;\n"); temp.append("template <class InMessage> class ServerReader;\n"); temp.append("template <class OutMessage> class ClientAsyncWriter;\n"); - temp.append("template <class InMessage> class ServerAsyncReader;\n"); + temp.append("template <class OutMessage, class InMessage> class ServerAsyncReader;\n"); } if (HasServerOnlyStreaming(file)) { temp.append("template <class InMessage> class ClientReader;\n"); @@ -267,7 +267,7 @@ void PrintHeaderServerMethodAsync( printer->Print(*vars, "void Request$Method$(" "::grpc::ServerContext* context, " - "::grpc::ServerAsyncReader< $Request$>* reader, " + "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, " "::grpc::CompletionQueue* cq, void *tag);\n"); } else if (ServerOnlyStreaming(method)) { printer->Print(*vars, @@ -538,7 +538,7 @@ void PrintSourceServerAsyncMethod( printer->Print(*vars, "void $Service$::AsyncService::Request$Method$(" "::grpc::ServerContext* context, " - "::grpc::ServerAsyncReader< $Request$>* reader, " + "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, " "::grpc::CompletionQueue* cq, void* tag) {\n"); printer->Print( *vars, diff --git a/src/cpp/client/client_unary_call.cc b/src/cpp/client/client_unary_call.cc index b6bd81d93f..d68d7a9242 100644 --- a/src/cpp/client/client_unary_call.cc +++ b/src/cpp/client/client_unary_call.cc @@ -53,21 +53,18 @@ Status BlockingUnaryCall(ChannelInterface *channel, const RpcMethod &method, buf.AddSendInitialMetadata(context); buf.AddSendMessage(request); buf.AddRecvInitialMetadata(&context->recv_initial_metadata_); - bool got_message; - buf.AddRecvMessage(result, &got_message); + buf.AddRecvMessage(result); buf.AddClientSendClose(); buf.AddClientRecvStatus(&context->trailing_metadata_, &status); call.PerformOps(&buf); - GPR_ASSERT(cq.Pluck(&buf) && (got_message || !status.IsOk())); + GPR_ASSERT(cq.Pluck(&buf) && (buf.got_message || !status.IsOk())); return status; } class ClientAsyncRequest final : public CallOpBuffer { public: - bool got_message = false; void FinalizeResult(void** tag, bool* status) override { CallOpBuffer::FinalizeResult(tag, status); - *status &= got_message; delete this; } }; @@ -83,7 +80,7 @@ void AsyncUnaryCall(ChannelInterface *channel, const RpcMethod &method, buf->AddSendInitialMetadata(context); buf->AddSendMessage(request); buf->AddRecvInitialMetadata(&context->recv_initial_metadata_); - buf->AddRecvMessage(result, &buf->got_message); + buf->AddRecvMessage(result); buf->AddClientSendClose(); buf->AddClientRecvStatus(&context->trailing_metadata_, status); call.PerformOps(buf); diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc index d706ec45e5..fe8859de94 100644 --- a/src/cpp/common/call.cc +++ b/src/cpp/common/call.cc @@ -57,7 +57,7 @@ void CallOpBuffer::Reset(void* next_return_tag) { } recv_message_ = nullptr; - got_message_ = nullptr; + got_message = false; if (recv_message_buf_) { grpc_byte_buffer_destroy(recv_message_buf_); recv_message_buf_ = nullptr; @@ -142,9 +142,8 @@ void CallOpBuffer::AddSendMessage(const google::protobuf::Message& message) { send_message_ = &message; } -void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message, bool* got_message) { +void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message) { recv_message_ = message; - got_message_ = got_message; } void CallOpBuffer::AddClientSendClose() { @@ -256,12 +255,14 @@ void CallOpBuffer::FinalizeResult(void **tag, bool *status) { // Parse received message if any. if (recv_message_) { if (recv_message_buf_) { - *got_message_ = true; + got_message = true; *status = DeserializeProto(recv_message_buf_, recv_message_); grpc_byte_buffer_destroy(recv_message_buf_); recv_message_buf_ = nullptr; } else { - *got_message_ = false; + // Read failed + got_message = false; + *status = false; } } // Parse received status. diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc index 52fb80e8db..b85aabf09e 100644 --- a/test/cpp/end2end/async_end2end_test.cc +++ b/test/cpp/end2end/async_end2end_test.cc @@ -64,9 +64,21 @@ namespace testing { namespace { +void* tag(int i) { + return (void*)(gpr_intptr)i; +} + +void verify_ok(CompletionQueue* cq, int i, bool expect_ok) { + bool ok; + void* got_tag; + EXPECT_TRUE(cq->Next(&got_tag, &ok)); + EXPECT_EQ(expect_ok, ok); + EXPECT_EQ(tag(i), got_tag); +} + class End2endTest : public ::testing::Test { protected: - End2endTest() : service_(&cq_) {} + End2endTest() : service_(&srv_cq_) {} void SetUp() override { int port = grpc_pick_unused_port_or_die(); @@ -86,20 +98,30 @@ class End2endTest : public ::testing::Test { stub_.reset(grpc::cpp::test::util::TestService::NewStub(channel)); } - CompletionQueue cq_; + void server_ok(int i) { + verify_ok(&srv_cq_, i, true); + } + void client_ok(int i) { + verify_ok(&cli_cq_, i , true); + } + void server_fail(int i) { + verify_ok(&srv_cq_, i, false); + } + void client_fail(int i) { + verify_ok(&cli_cq_, i, false); + } + + CompletionQueue cli_cq_; + CompletionQueue srv_cq_; std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_; std::unique_ptr<Server> server_; grpc::cpp::test::util::TestService::AsyncService service_; std::ostringstream server_address_; }; -void* tag(int i) { - return (void*)(gpr_intptr)i; -} - TEST_F(End2endTest, SimpleRpc) { ResetStub(); - + EchoRequest send_request; EchoRequest recv_request; EchoResponse send_response; @@ -110,34 +132,128 @@ TEST_F(End2endTest, SimpleRpc) { grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx); send_request.set_message("Hello"); - stub_->Echo(&cli_ctx, send_request, &recv_response, &recv_status, &cq_, tag(1)); + stub_->Echo( + &cli_ctx, send_request, &recv_response, &recv_status, &cli_cq_, tag(1)); - service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, &cq_, tag(2)); + service_.RequestEcho( + &srv_ctx, &recv_request, &response_writer, &srv_cq_, tag(2)); - void *got_tag; - bool ok; - EXPECT_TRUE(cq_.Next(&got_tag, &ok)); - EXPECT_TRUE(ok); - EXPECT_EQ(got_tag, tag(2)); - EXPECT_EQ(recv_request.message(), "Hello"); + server_ok(2); + EXPECT_EQ(send_request.message(), recv_request.message()); send_response.set_message(recv_request.message()); response_writer.Finish(send_response, Status::OK, tag(3)); - EXPECT_TRUE(cq_.Next(&got_tag, &ok)); - EXPECT_TRUE(ok); - if (got_tag == tag(3)) { - EXPECT_TRUE(cq_.Next(&got_tag, &ok)); - EXPECT_TRUE(ok); - EXPECT_EQ(got_tag, tag(1)); - } else { - EXPECT_EQ(got_tag, tag(1)); - EXPECT_TRUE(cq_.Next(&got_tag, &ok)); - EXPECT_TRUE(ok); - EXPECT_EQ(got_tag, tag(3)); - } + server_ok(3); + + client_ok(1); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.IsOk()); +} + +TEST_F(End2endTest, SimpleClientStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + ClientAsyncWriter<EchoRequest>* cli_stream = + stub_->RequestStream(&cli_ctx, &recv_response, &cli_cq_, tag(1)); + + service_.RequestRequestStream( + &srv_ctx, &srv_stream, &srv_cq_, tag(2)); + + server_ok(2); + client_ok(1); + + cli_stream->Write(send_request, tag(3)); + client_ok(3); + + srv_stream.Read(&recv_request, tag(4)); + server_ok(4); + EXPECT_EQ(send_request.message(), recv_request.message()); + + cli_stream->Write(send_request, tag(5)); + client_ok(5); + + srv_stream.Read(&recv_request, tag(6)); + server_ok(6); + + EXPECT_EQ(send_request.message(), recv_request.message()); + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_request, tag(8)); + server_fail(8); + + send_response.set_message(recv_request.message()); + srv_stream.Finish(send_response, Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); + + EXPECT_EQ(send_response.message(), recv_response.message()); + EXPECT_TRUE(recv_status.IsOk()); +} + +TEST_F(End2endTest, SimpleBidiStreaming) { + ResetStub(); + + EchoRequest send_request; + EchoRequest recv_request; + EchoResponse send_response; + EchoResponse recv_response; + Status recv_status; + ClientContext cli_ctx; + ServerContext srv_ctx; + ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx); + + send_request.set_message("Hello"); + ClientAsyncReaderWriter<EchoRequest, EchoResponse>* cli_stream = + stub_->BidiStream(&cli_ctx, &cli_cq_, tag(1)); + + service_.RequestBidiStream( + &srv_ctx, &srv_stream, &srv_cq_, tag(2)); + + server_ok(2); + client_ok(1); + + cli_stream->Write(send_request, tag(3)); + client_ok(3); + + srv_stream.Read(&recv_request, tag(4)); + server_ok(4); + EXPECT_EQ(send_request.message(), recv_request.message()); + + send_response.set_message(recv_request.message()); + srv_stream.Write(send_response, tag(5)); + server_ok(5); + + cli_stream->Read(&recv_response, tag(6)); + client_ok(6); + EXPECT_EQ(send_response.message(), recv_response.message()); + + cli_stream->WritesDone(tag(7)); + client_ok(7); + + srv_stream.Read(&recv_request, tag(8)); + server_fail(8); + + srv_stream.Finish(Status::OK, tag(9)); + server_ok(9); + + cli_stream->Finish(&recv_status, tag(10)); + client_ok(10); - EXPECT_EQ(recv_response.message(), "Hello"); EXPECT_TRUE(recv_status.IsOk()); } |