aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--include/grpc++/impl/call.h4
-rw-r--r--include/grpc++/server_context.h2
-rw-r--r--include/grpc++/stream.h55
-rw-r--r--src/compiler/cpp_generator.cc6
-rw-r--r--src/cpp/client/client_unary_call.cc9
-rw-r--r--src/cpp/common/call.cc11
-rw-r--r--test/cpp/end2end/async_end2end_test.cc172
7 files changed, 190 insertions, 69 deletions
diff --git a/include/grpc++/impl/call.h b/include/grpc++/impl/call.h
index 7aa22ee7c2..af1c710098 100644
--- a/include/grpc++/impl/call.h
+++ b/include/grpc++/impl/call.h
@@ -68,7 +68,7 @@ class CallOpBuffer : public CompletionQueueTag {
void AddRecvInitialMetadata(
std::multimap<grpc::string, grpc::string> *metadata);
void AddSendMessage(const google::protobuf::Message &message);
- void AddRecvMessage(google::protobuf::Message *message, bool* got_message);
+ void AddRecvMessage(google::protobuf::Message *message);
void AddClientSendClose();
void AddClientRecvStatus(std::multimap<grpc::string, grpc::string> *metadata,
Status *status);
@@ -84,6 +84,7 @@ class CallOpBuffer : public CompletionQueueTag {
// Called by completion queue just prior to returning from Next() or Pluck()
void FinalizeResult(void **tag, bool *status) override;
+ bool got_message = false;
private:
void *return_tag_ = nullptr;
// Send initial metadata
@@ -98,7 +99,6 @@ class CallOpBuffer : public CompletionQueueTag {
grpc_byte_buffer* send_message_buf_ = nullptr;
// Recv message
google::protobuf::Message* recv_message_ = nullptr;
- bool* got_message_ = nullptr;
grpc_byte_buffer* recv_message_buf_ = nullptr;
// Client send close
bool client_send_close_ = false;
diff --git a/include/grpc++/server_context.h b/include/grpc++/server_context.h
index 64091a4505..423ebf2337 100644
--- a/include/grpc++/server_context.h
+++ b/include/grpc++/server_context.h
@@ -45,7 +45,7 @@ struct grpc_call;
namespace grpc {
-template <class R>
+template <class W, class R>
class ServerAsyncReader;
template <class W>
class ServerAsyncWriter;
diff --git a/include/grpc++/stream.h b/include/grpc++/stream.h
index 359a272e7b..6ee550bd64 100644
--- a/include/grpc++/stream.h
+++ b/include/grpc++/stream.h
@@ -119,10 +119,9 @@ class ClientReader final : public ClientStreamingInterface,
buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
context_->initial_metadata_received_ = true;
}
- bool got_message;
- buf.AddRecvMessage(msg, &got_message);
+ buf.AddRecvMessage(msg);
call_.PerformOps(&buf);
- return cq_.Pluck(&buf) && got_message;
+ return cq_.Pluck(&buf) && buf.got_message;
}
virtual Status Finish() override {
@@ -174,11 +173,10 @@ class ClientWriter final : public ClientStreamingInterface,
virtual Status Finish() override {
CallOpBuffer buf;
Status status;
- bool got_message;
- buf.AddRecvMessage(response_, &got_message);
+ buf.AddRecvMessage(response_);
buf.AddClientRecvStatus(&context_->trailing_metadata_, &status);
call_.PerformOps(&buf);
- GPR_ASSERT(cq_.Pluck(&buf) && got_message);
+ GPR_ASSERT(cq_.Pluck(&buf) && buf.got_message);
return status;
}
@@ -225,10 +223,9 @@ class ClientReaderWriter final : public ClientStreamingInterface,
buf.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
context_->initial_metadata_received_ = true;
}
- bool got_message;
- buf.AddRecvMessage(msg, &got_message);
+ buf.AddRecvMessage(msg);
call_.PerformOps(&buf);
- return cq_.Pluck(&buf) && got_message;
+ return cq_.Pluck(&buf) && buf.got_message;
}
virtual bool Write(const W& msg) override {
@@ -277,10 +274,9 @@ class ServerReader final : public ReaderInterface<R> {
virtual bool Read(R* msg) override {
CallOpBuffer buf;
- bool got_message;
- buf.AddRecvMessage(msg, &got_message);
+ buf.AddRecvMessage(msg);
call_->PerformOps(&buf);
- return call_->cq()->Pluck(&buf) && got_message;
+ return call_->cq()->Pluck(&buf) && buf.got_message;
}
private:
@@ -338,10 +334,9 @@ class ServerReaderWriter final : public WriterInterface<W>,
virtual bool Read(R* msg) override {
CallOpBuffer buf;
- bool got_message;
- buf.AddRecvMessage(msg, &got_message);
+ buf.AddRecvMessage(msg);
call_->PerformOps(&buf);
- return call_->cq()->Pluck(&buf) && got_message;
+ return call_->cq()->Pluck(&buf) && buf.got_message;
}
virtual bool Write(const W& msg) override {
@@ -420,8 +415,7 @@ class ClientAsyncReader final : public ClientAsyncStreamingInterface,
read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
context_->initial_metadata_received_ = true;
}
- bool ignore;
- read_buf_.AddRecvMessage(msg, &ignore);
+ read_buf_.AddRecvMessage(msg);
call_.PerformOps(&read_buf_);
}
@@ -485,8 +479,7 @@ class ClientAsyncWriter final : public ClientAsyncStreamingInterface,
finish_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
context_->initial_metadata_received_ = true;
}
- bool ignore;
- finish_buf_.AddRecvMessage(response_, &ignore);
+ finish_buf_.AddRecvMessage(response_);
finish_buf_.AddClientRecvStatus(&context_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
@@ -494,7 +487,6 @@ class ClientAsyncWriter final : public ClientAsyncStreamingInterface,
private:
ClientContext* context_ = nullptr;
google::protobuf::Message *const response_;
- bool got_message_;
Call call_;
CallOpBuffer init_buf_;
CallOpBuffer meta_buf_;
@@ -532,8 +524,7 @@ class ClientAsyncReaderWriter final : public ClientAsyncStreamingInterface,
read_buf_.AddRecvInitialMetadata(&context_->recv_initial_metadata_);
context_->initial_metadata_received_ = true;
}
- bool ignore;
- read_buf_.AddRecvMessage(msg, &ignore);
+ read_buf_.AddRecvMessage(msg);
call_.PerformOps(&read_buf_);
}
@@ -624,7 +615,7 @@ class ServerAsyncResponseWriter final : public ServerAsyncStreamingInterface {
CallOpBuffer finish_buf_;
};
-template <class R>
+template <class W, class R>
class ServerAsyncReader : public ServerAsyncStreamingInterface,
public AsyncReaderInterface<R> {
public:
@@ -646,18 +637,34 @@ class ServerAsyncReader : public ServerAsyncStreamingInterface,
call_.PerformOps(&read_buf_);
}
- void Finish(const Status& status, void* tag) {
+ void Finish(const W& msg, const Status& status, void* tag) {
finish_buf_.Reset(tag);
if (!ctx_->sent_initial_metadata_) {
finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
ctx_->sent_initial_metadata_ = true;
}
+ // The response is dropped if the status is not OK.
+ if (status.IsOk()) {
+ finish_buf_.AddSendMessage(msg);
+ }
bool cancelled = false;
finish_buf_.AddServerRecvClose(&cancelled);
finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
call_.PerformOps(&finish_buf_);
}
+ void FinishWithError(const Status& status, void* tag) {
+ GPR_ASSERT(!status.IsOk());
+ finish_buf_.Reset(tag);
+ if (!ctx_->sent_initial_metadata_) {
+ finish_buf_.AddSendInitialMetadata(&ctx_->initial_metadata_);
+ ctx_->sent_initial_metadata_ = true;
+ }
+ bool cancelled = false;
+ finish_buf_.AddServerRecvClose(&cancelled);
+ finish_buf_.AddServerSendStatus(&ctx_->trailing_metadata_, status);
+ call_.PerformOps(&finish_buf_);
+ }
private:
void BindCall(Call *call) override { call_ = *call; }
diff --git a/src/compiler/cpp_generator.cc b/src/compiler/cpp_generator.cc
index a34aa4e568..2a9895e43c 100644
--- a/src/compiler/cpp_generator.cc
+++ b/src/compiler/cpp_generator.cc
@@ -133,7 +133,7 @@ std::string GetHeaderIncludes(const google::protobuf::FileDescriptor *file) {
temp.append("template <class OutMessage> class ClientWriter;\n");
temp.append("template <class InMessage> class ServerReader;\n");
temp.append("template <class OutMessage> class ClientAsyncWriter;\n");
- temp.append("template <class InMessage> class ServerAsyncReader;\n");
+ temp.append("template <class OutMessage, class InMessage> class ServerAsyncReader;\n");
}
if (HasServerOnlyStreaming(file)) {
temp.append("template <class InMessage> class ClientReader;\n");
@@ -267,7 +267,7 @@ void PrintHeaderServerMethodAsync(
printer->Print(*vars,
"void Request$Method$("
"::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReader< $Request$>* reader, "
+ "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
"::grpc::CompletionQueue* cq, void *tag);\n");
} else if (ServerOnlyStreaming(method)) {
printer->Print(*vars,
@@ -538,7 +538,7 @@ void PrintSourceServerAsyncMethod(
printer->Print(*vars,
"void $Service$::AsyncService::Request$Method$("
"::grpc::ServerContext* context, "
- "::grpc::ServerAsyncReader< $Request$>* reader, "
+ "::grpc::ServerAsyncReader< $Response$, $Request$>* reader, "
"::grpc::CompletionQueue* cq, void* tag) {\n");
printer->Print(
*vars,
diff --git a/src/cpp/client/client_unary_call.cc b/src/cpp/client/client_unary_call.cc
index b6bd81d93f..d68d7a9242 100644
--- a/src/cpp/client/client_unary_call.cc
+++ b/src/cpp/client/client_unary_call.cc
@@ -53,21 +53,18 @@ Status BlockingUnaryCall(ChannelInterface *channel, const RpcMethod &method,
buf.AddSendInitialMetadata(context);
buf.AddSendMessage(request);
buf.AddRecvInitialMetadata(&context->recv_initial_metadata_);
- bool got_message;
- buf.AddRecvMessage(result, &got_message);
+ buf.AddRecvMessage(result);
buf.AddClientSendClose();
buf.AddClientRecvStatus(&context->trailing_metadata_, &status);
call.PerformOps(&buf);
- GPR_ASSERT(cq.Pluck(&buf) && (got_message || !status.IsOk()));
+ GPR_ASSERT(cq.Pluck(&buf) && (buf.got_message || !status.IsOk()));
return status;
}
class ClientAsyncRequest final : public CallOpBuffer {
public:
- bool got_message = false;
void FinalizeResult(void** tag, bool* status) override {
CallOpBuffer::FinalizeResult(tag, status);
- *status &= got_message;
delete this;
}
};
@@ -83,7 +80,7 @@ void AsyncUnaryCall(ChannelInterface *channel, const RpcMethod &method,
buf->AddSendInitialMetadata(context);
buf->AddSendMessage(request);
buf->AddRecvInitialMetadata(&context->recv_initial_metadata_);
- buf->AddRecvMessage(result, &buf->got_message);
+ buf->AddRecvMessage(result);
buf->AddClientSendClose();
buf->AddClientRecvStatus(&context->trailing_metadata_, status);
call.PerformOps(buf);
diff --git a/src/cpp/common/call.cc b/src/cpp/common/call.cc
index d706ec45e5..fe8859de94 100644
--- a/src/cpp/common/call.cc
+++ b/src/cpp/common/call.cc
@@ -57,7 +57,7 @@ void CallOpBuffer::Reset(void* next_return_tag) {
}
recv_message_ = nullptr;
- got_message_ = nullptr;
+ got_message = false;
if (recv_message_buf_) {
grpc_byte_buffer_destroy(recv_message_buf_);
recv_message_buf_ = nullptr;
@@ -142,9 +142,8 @@ void CallOpBuffer::AddSendMessage(const google::protobuf::Message& message) {
send_message_ = &message;
}
-void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message, bool* got_message) {
+void CallOpBuffer::AddRecvMessage(google::protobuf::Message *message) {
recv_message_ = message;
- got_message_ = got_message;
}
void CallOpBuffer::AddClientSendClose() {
@@ -256,12 +255,14 @@ void CallOpBuffer::FinalizeResult(void **tag, bool *status) {
// Parse received message if any.
if (recv_message_) {
if (recv_message_buf_) {
- *got_message_ = true;
+ got_message = true;
*status = DeserializeProto(recv_message_buf_, recv_message_);
grpc_byte_buffer_destroy(recv_message_buf_);
recv_message_buf_ = nullptr;
} else {
- *got_message_ = false;
+ // Read failed
+ got_message = false;
+ *status = false;
}
}
// Parse received status.
diff --git a/test/cpp/end2end/async_end2end_test.cc b/test/cpp/end2end/async_end2end_test.cc
index 52fb80e8db..b85aabf09e 100644
--- a/test/cpp/end2end/async_end2end_test.cc
+++ b/test/cpp/end2end/async_end2end_test.cc
@@ -64,9 +64,21 @@ namespace testing {
namespace {
+void* tag(int i) {
+ return (void*)(gpr_intptr)i;
+}
+
+void verify_ok(CompletionQueue* cq, int i, bool expect_ok) {
+ bool ok;
+ void* got_tag;
+ EXPECT_TRUE(cq->Next(&got_tag, &ok));
+ EXPECT_EQ(expect_ok, ok);
+ EXPECT_EQ(tag(i), got_tag);
+}
+
class End2endTest : public ::testing::Test {
protected:
- End2endTest() : service_(&cq_) {}
+ End2endTest() : service_(&srv_cq_) {}
void SetUp() override {
int port = grpc_pick_unused_port_or_die();
@@ -86,20 +98,30 @@ class End2endTest : public ::testing::Test {
stub_.reset(grpc::cpp::test::util::TestService::NewStub(channel));
}
- CompletionQueue cq_;
+ void server_ok(int i) {
+ verify_ok(&srv_cq_, i, true);
+ }
+ void client_ok(int i) {
+ verify_ok(&cli_cq_, i , true);
+ }
+ void server_fail(int i) {
+ verify_ok(&srv_cq_, i, false);
+ }
+ void client_fail(int i) {
+ verify_ok(&cli_cq_, i, false);
+ }
+
+ CompletionQueue cli_cq_;
+ CompletionQueue srv_cq_;
std::unique_ptr<grpc::cpp::test::util::TestService::Stub> stub_;
std::unique_ptr<Server> server_;
grpc::cpp::test::util::TestService::AsyncService service_;
std::ostringstream server_address_;
};
-void* tag(int i) {
- return (void*)(gpr_intptr)i;
-}
-
TEST_F(End2endTest, SimpleRpc) {
ResetStub();
-
+
EchoRequest send_request;
EchoRequest recv_request;
EchoResponse send_response;
@@ -110,34 +132,128 @@ TEST_F(End2endTest, SimpleRpc) {
grpc::ServerAsyncResponseWriter<EchoResponse> response_writer(&srv_ctx);
send_request.set_message("Hello");
- stub_->Echo(&cli_ctx, send_request, &recv_response, &recv_status, &cq_, tag(1));
+ stub_->Echo(
+ &cli_ctx, send_request, &recv_response, &recv_status, &cli_cq_, tag(1));
- service_.RequestEcho(&srv_ctx, &recv_request, &response_writer, &cq_, tag(2));
+ service_.RequestEcho(
+ &srv_ctx, &recv_request, &response_writer, &srv_cq_, tag(2));
- void *got_tag;
- bool ok;
- EXPECT_TRUE(cq_.Next(&got_tag, &ok));
- EXPECT_TRUE(ok);
- EXPECT_EQ(got_tag, tag(2));
- EXPECT_EQ(recv_request.message(), "Hello");
+ server_ok(2);
+ EXPECT_EQ(send_request.message(), recv_request.message());
send_response.set_message(recv_request.message());
response_writer.Finish(send_response, Status::OK, tag(3));
- EXPECT_TRUE(cq_.Next(&got_tag, &ok));
- EXPECT_TRUE(ok);
- if (got_tag == tag(3)) {
- EXPECT_TRUE(cq_.Next(&got_tag, &ok));
- EXPECT_TRUE(ok);
- EXPECT_EQ(got_tag, tag(1));
- } else {
- EXPECT_EQ(got_tag, tag(1));
- EXPECT_TRUE(cq_.Next(&got_tag, &ok));
- EXPECT_TRUE(ok);
- EXPECT_EQ(got_tag, tag(3));
- }
+ server_ok(3);
+
+ client_ok(1);
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.IsOk());
+}
+
+TEST_F(End2endTest, SimpleClientStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReader<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message("Hello");
+ ClientAsyncWriter<EchoRequest>* cli_stream =
+ stub_->RequestStream(&cli_ctx, &recv_response, &cli_cq_, tag(1));
+
+ service_.RequestRequestStream(
+ &srv_ctx, &srv_stream, &srv_cq_, tag(2));
+
+ server_ok(2);
+ client_ok(1);
+
+ cli_stream->Write(send_request, tag(3));
+ client_ok(3);
+
+ srv_stream.Read(&recv_request, tag(4));
+ server_ok(4);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ cli_stream->Write(send_request, tag(5));
+ client_ok(5);
+
+ srv_stream.Read(&recv_request, tag(6));
+ server_ok(6);
+
+ EXPECT_EQ(send_request.message(), recv_request.message());
+ cli_stream->WritesDone(tag(7));
+ client_ok(7);
+
+ srv_stream.Read(&recv_request, tag(8));
+ server_fail(8);
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Finish(send_response, Status::OK, tag(9));
+ server_ok(9);
+
+ cli_stream->Finish(&recv_status, tag(10));
+ client_ok(10);
+
+ EXPECT_EQ(send_response.message(), recv_response.message());
+ EXPECT_TRUE(recv_status.IsOk());
+}
+
+TEST_F(End2endTest, SimpleBidiStreaming) {
+ ResetStub();
+
+ EchoRequest send_request;
+ EchoRequest recv_request;
+ EchoResponse send_response;
+ EchoResponse recv_response;
+ Status recv_status;
+ ClientContext cli_ctx;
+ ServerContext srv_ctx;
+ ServerAsyncReaderWriter<EchoResponse, EchoRequest> srv_stream(&srv_ctx);
+
+ send_request.set_message("Hello");
+ ClientAsyncReaderWriter<EchoRequest, EchoResponse>* cli_stream =
+ stub_->BidiStream(&cli_ctx, &cli_cq_, tag(1));
+
+ service_.RequestBidiStream(
+ &srv_ctx, &srv_stream, &srv_cq_, tag(2));
+
+ server_ok(2);
+ client_ok(1);
+
+ cli_stream->Write(send_request, tag(3));
+ client_ok(3);
+
+ srv_stream.Read(&recv_request, tag(4));
+ server_ok(4);
+ EXPECT_EQ(send_request.message(), recv_request.message());
+
+ send_response.set_message(recv_request.message());
+ srv_stream.Write(send_response, tag(5));
+ server_ok(5);
+
+ cli_stream->Read(&recv_response, tag(6));
+ client_ok(6);
+ EXPECT_EQ(send_response.message(), recv_response.message());
+
+ cli_stream->WritesDone(tag(7));
+ client_ok(7);
+
+ srv_stream.Read(&recv_request, tag(8));
+ server_fail(8);
+
+ srv_stream.Finish(Status::OK, tag(9));
+ server_ok(9);
+
+ cli_stream->Finish(&recv_status, tag(10));
+ client_ok(10);
- EXPECT_EQ(recv_response.message(), "Hello");
EXPECT_TRUE(recv_status.IsOk());
}