Diffstat (limited to 'src/cpp')
-rw-r--r-- | src/cpp/client/channel.cc              | 135
-rw-r--r-- | src/cpp/client/client_context.cc       |   8
-rw-r--r-- | src/cpp/server/async_server_context.cc |  15
-rw-r--r-- | src/cpp/server/rpc_service_method.h    |  90
-rw-r--r-- | src/cpp/server/server_credentials.cc   |  62
-rw-r--r-- | src/cpp/server/server_rpc_handler.cc   |  86
-rw-r--r-- | src/cpp/stream/stream_context.cc       | 252
-rw-r--r-- | src/cpp/stream/stream_context.h        |  50
8 files changed, 362 insertions, 336 deletions
diff --git a/src/cpp/client/channel.cc b/src/cpp/client/channel.cc
index 7a7529104f..3792869d83 100644
--- a/src/cpp/client/channel.cc
+++ b/src/cpp/client/channel.cc
@@ -59,67 +59,19 @@ Channel::Channel(const grpc::string& target) : target_(target) {
 Channel::~Channel() { grpc_channel_destroy(c_channel_); }
 
 namespace {
-// Poll one event from the compeletion queue. Return false when an error
-// occured or the polled type is not expected. If a finished event has been
-// polled, set finished and set status if it has not been set.
-bool NextEvent(grpc_completion_queue* cq, grpc_completion_type expected_type,
-               bool* finished, bool* status_set, Status* status,
-               google::protobuf::Message* result) {
-  // We rely on the c layer to enforce deadline and thus do not use deadline
-  // here.
-  grpc_event* ev = grpc_completion_queue_next(cq, gpr_inf_future);
-  if (!ev) {
-    return false;
-  }
-  bool ret = ev->type == expected_type;
-  switch (ev->type) {
-    case GRPC_INVOKE_ACCEPTED:
-      ret = ret && (ev->data.invoke_accepted == GRPC_OP_OK);
-      break;
-    case GRPC_READ:
-      ret = ret && (ev->data.read != nullptr);
-      if (ret && !DeserializeProto(ev->data.read, result)) {
-        *status_set = true;
-        *status =
-            Status(StatusCode::DATA_LOSS, "Failed to parse response proto.");
-        ret = false;
-      }
-      break;
-    case GRPC_WRITE_ACCEPTED:
-      ret = ret && (ev->data.write_accepted == GRPC_OP_OK);
-      break;
-    case GRPC_FINISH_ACCEPTED:
-      ret = ret && (ev->data.finish_accepted == GRPC_OP_OK);
-      break;
-    case GRPC_CLIENT_METADATA_READ:
-      break;
-    case GRPC_FINISHED:
-      *finished = true;
-      if (!*status_set) {
-        *status_set = true;
-        StatusCode error_code = static_cast<StatusCode>(ev->data.finished.code);
-        grpc::string details(
-            ev->data.finished.details ? ev->data.finished.details : "");
-        *status = Status(error_code, details);
-      }
-      break;
-    default:
-      gpr_log(GPR_ERROR, "Dropping unhandled event with type %d", ev->type);
-      break;
-  }
-  grpc_event_finish(ev);
-  return ret;
-}
-
-// If finished is not true, get final status by polling until a finished
-// event is obtained.
-void GetFinalStatus(grpc_completion_queue* cq, bool status_set, bool finished,
+// Pluck the finished event and fill in status when it is not nullptr.
+void GetFinalStatus(grpc_completion_queue* cq, void* finished_tag,
                     Status* status) {
-  while (!finished) {
-    NextEvent(cq, GRPC_FINISHED, &finished, &status_set, status, nullptr);
+  grpc_event* ev =
+      grpc_completion_queue_pluck(cq, finished_tag, gpr_inf_future);
+  if (status) {
+    StatusCode error_code = static_cast<StatusCode>(ev->data.finished.code);
+    grpc::string details(ev->data.finished.details ? ev->data.finished.details
+                                                   : "");
+    *status = Status(error_code, details);
   }
+  grpc_event_finish(ev);
 }
-
 }  // namespace
 
 // TODO(yangg) more error handling
@@ -128,8 +80,6 @@ Status Channel::StartBlockingRpc(const RpcMethod& method,
                                  const google::protobuf::Message& request,
                                  google::protobuf::Message* result) {
   Status status;
-  bool status_set = false;
-  bool finished = false;
   gpr_timespec absolute_deadline;
   AbsoluteDeadlineTimepoint2Timespec(context->absolute_deadline(),
                                      &absolute_deadline);
@@ -137,59 +87,68 @@ Status Channel::StartBlockingRpc(const RpcMethod& method,
   // FIXME(yangg)
       "localhost", absolute_deadline);
   context->set_call(call);
+  grpc_event* ev;
+  void* finished_tag = reinterpret_cast<char*>(call);
+  void* invoke_tag = reinterpret_cast<char*>(call) + 1;
+  void* metadata_read_tag = reinterpret_cast<char*>(call) + 2;
+  void* write_tag = reinterpret_cast<char*>(call) + 3;
+  void* halfclose_tag = reinterpret_cast<char*>(call) + 4;
+  void* read_tag = reinterpret_cast<char*>(call) + 5;
+
   grpc_completion_queue* cq = grpc_completion_queue_create();
   context->set_cq(cq);
   // add_metadata from context
   //
   // invoke
-  GPR_ASSERT(grpc_call_start_invoke(call, cq, call, call, call,
+  GPR_ASSERT(grpc_call_start_invoke(call, cq, invoke_tag, metadata_read_tag,
+                                    finished_tag,
                                     GRPC_WRITE_BUFFER_HINT) == GRPC_CALL_OK);
-  if (!NextEvent(cq, GRPC_INVOKE_ACCEPTED, &status_set, &finished, &status,
-                 nullptr)) {
-    GetFinalStatus(cq, finished, status_set, &status);
-    return status;
-  }
+  ev = grpc_completion_queue_pluck(cq, invoke_tag, gpr_inf_future);
+  grpc_event_finish(ev);
   // write request
   grpc_byte_buffer* write_buffer = nullptr;
   bool success = SerializeProto(request, &write_buffer);
   if (!success) {
     grpc_call_cancel(call);
-    status_set = true;
     status =
         Status(StatusCode::DATA_LOSS, "Failed to serialize request proto.");
-    GetFinalStatus(cq, finished, status_set, &status);
+    GetFinalStatus(cq, finished_tag, nullptr);
     return status;
   }
-  GPR_ASSERT(grpc_call_start_write(call, write_buffer, call,
+  GPR_ASSERT(grpc_call_start_write(call, write_buffer, write_tag,
                                    GRPC_WRITE_BUFFER_HINT) == GRPC_CALL_OK);
   grpc_byte_buffer_destroy(write_buffer);
-  if (!NextEvent(cq, GRPC_WRITE_ACCEPTED, &finished, &status_set, &status,
-                 nullptr)) {
-    GetFinalStatus(cq, finished, status_set, &status);
+  ev = grpc_completion_queue_pluck(cq, write_tag, gpr_inf_future);
+
+  success = ev->data.write_accepted == GRPC_OP_OK;
+  grpc_event_finish(ev);
+  if (!success) {
+    GetFinalStatus(cq, finished_tag, &status);
     return status;
   }
   // writes done
-  GPR_ASSERT(grpc_call_writes_done(call, call) == GRPC_CALL_OK);
-  if (!NextEvent(cq, GRPC_FINISH_ACCEPTED, &finished, &status_set, &status,
-                 nullptr)) {
-    GetFinalStatus(cq, finished, status_set, &status);
-    return status;
-  }
+  GPR_ASSERT(grpc_call_writes_done(call, halfclose_tag) == GRPC_CALL_OK);
+  ev = grpc_completion_queue_pluck(cq, halfclose_tag, gpr_inf_future);
+  grpc_event_finish(ev);
   // start read metadata
   //
-  if (!NextEvent(cq, GRPC_CLIENT_METADATA_READ, &finished, &status_set, &status,
-                 nullptr)) {
-    GetFinalStatus(cq, finished, status_set, &status);
-    return status;
-  }
+  ev = grpc_completion_queue_pluck(cq, metadata_read_tag, gpr_inf_future);
+  grpc_event_finish(ev);
   // start read
-  GPR_ASSERT(grpc_call_start_read(call, call) == GRPC_CALL_OK);
-  if (!NextEvent(cq, GRPC_READ, &finished, &status_set, &status, result)) {
-    GetFinalStatus(cq, finished, status_set, &status);
-    return status;
+  GPR_ASSERT(grpc_call_start_read(call, read_tag) == GRPC_CALL_OK);
+  ev = grpc_completion_queue_pluck(cq, read_tag, gpr_inf_future);
+  if (ev->data.read) {
+    if (!DeserializeProto(ev->data.read, result)) {
+      grpc_event_finish(ev);
+      status = Status(StatusCode::DATA_LOSS, "Failed to parse response proto.");
+      GetFinalStatus(cq, finished_tag, nullptr);
+      return status;
+    }
   }
+  grpc_event_finish(ev);
+
   // wait status
-  GetFinalStatus(cq, finished, status_set, &status);
+  GetFinalStatus(cq, finished_tag, &status);
   return status;
 }
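The change above replaces the shared NextEvent() poll-and-dispatch helper with pluck-by-tag: each pending operation gets a unique tag (successive byte offsets from the call pointer), and the caller blocks on exactly that completion with grpc_completion_queue_pluck. A minimal sketch of the idiom against the same historical C API; op_tag and wait_for_tag are illustrative helper names, not part of this patch:

    #include <grpc/grpc.h>

    // Derive distinct per-operation tags from a base pointer. The char* cast
    // yields single-byte offsets, so tags from different calls cannot collide
    // as long as each call object spans at least as many bytes as there are tags.
    static void* op_tag(grpc_call* call, int op_index) {
      return reinterpret_cast<char*>(call) + op_index;
    }

    // Block until the operation identified by tag completes, then release the
    // event. Deadlines are enforced by the C layer, hence gpr_inf_future.
    static void wait_for_tag(grpc_completion_queue* cq, void* tag) {
      grpc_event* ev = grpc_completion_queue_pluck(cq, tag, gpr_inf_future);
      grpc_event_finish(ev);
    }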
diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc
index 78774a7f12..58a8ad252b 100644
--- a/src/cpp/client/client_context.cc
+++ b/src/cpp/client/client_context.cc
@@ -50,6 +50,14 @@ ClientContext::~ClientContext() {
   }
   if (cq_) {
     grpc_completion_queue_shutdown(cq_);
+    // Drain cq_.
+    grpc_event* ev;
+    grpc_completion_type t;
+    do {
+      ev = grpc_completion_queue_next(cq_, gpr_inf_future);
+      t = ev->type;
+      grpc_event_finish(ev);
+    } while (t != GRPC_QUEUE_SHUTDOWN);
     grpc_completion_queue_destroy(cq_);
   }
 }
diff --git a/src/cpp/server/async_server_context.cc b/src/cpp/server/async_server_context.cc
index b231f4b0cf..0a9c07f403 100644
--- a/src/cpp/server/async_server_context.cc
+++ b/src/cpp/server/async_server_context.cc
@@ -75,18 +75,11 @@ bool AsyncServerContext::StartWrite(const google::protobuf::Message& response,
   return err == GRPC_CALL_OK;
 }
 
-namespace {
-grpc_status TranslateStatus(const Status& status) {
-  grpc_status c_status;
-  // TODO(yangg)
-  c_status.code = GRPC_STATUS_OK;
-  c_status.details = nullptr;
-  return c_status;
-}
-}  // namespace
-
 bool AsyncServerContext::StartWriteStatus(const Status& status) {
-  grpc_status c_status = TranslateStatus(status);
+  grpc_status c_status = {static_cast<grpc_status_code>(status.code()),
+                          status.details().empty()
+                              ? nullptr
+                              : const_cast<char*>(status.details().c_str())};
   grpc_call_error err = grpc_call_start_write_status(call_, c_status, this);
   return err == GRPC_CALL_OK;
 }
diff --git a/src/cpp/server/rpc_service_method.h b/src/cpp/server/rpc_service_method.h
index ac2badda71..425545fd22 100644
--- a/src/cpp/server/rpc_service_method.h
+++ b/src/cpp/server/rpc_service_method.h
@@ -42,8 +42,10 @@
 #include "src/cpp/rpc_method.h"
 #include <google/protobuf/message.h>
 #include <grpc++/status.h>
+#include <grpc++/stream.h>
 
 namespace grpc {
+class StreamContextInterface;
 
 // TODO(rocking): we might need to split this file into multiple ones.
 
@@ -53,23 +55,27 @@ class MethodHandler {
   virtual ~MethodHandler() {}
   struct HandlerParameter {
     HandlerParameter(const google::protobuf::Message* req,
                      google::protobuf::Message* resp)
-        : request(req), response(resp) {}
+        : request(req), response(resp), stream_context(nullptr) {}
+    HandlerParameter(const google::protobuf::Message* req,
+                     google::protobuf::Message* resp,
+                     StreamContextInterface* context)
+        : request(req), response(resp), stream_context(context) {}
     const google::protobuf::Message* request;
     google::protobuf::Message* response;
+    StreamContextInterface* stream_context;
   };
-  virtual ::grpc::Status RunHandler(const HandlerParameter& param) = 0;
+  virtual Status RunHandler(const HandlerParameter& param) = 0;
 };
 
 // A wrapper class of an application provided rpc method handler.
 template <class ServiceType, class RequestType, class ResponseType>
 class RpcMethodHandler : public MethodHandler {
  public:
-  RpcMethodHandler(std::function<::grpc::Status(
-                       ServiceType*, const RequestType*, ResponseType*)> func,
+  RpcMethodHandler(std::function<Status(ServiceType*, const RequestType*,
+                                        ResponseType*)> func,
                    ServiceType* service)
       : func_(func), service_(service) {}
 
-  ::grpc::Status RunHandler(const HandlerParameter& param) final {
+  Status RunHandler(const HandlerParameter& param) final {
     // Invoke application function, cast proto messages to their actual types.
     return func_(service_, dynamic_cast<const RequestType*>(param.request),
                  dynamic_cast<ResponseType*>(param.response));
@@ -77,20 +83,84 @@ class RpcMethodHandler : public MethodHandler {
 
  private:
   // Application provided rpc handler function.
-  std::function<::grpc::Status(ServiceType*, const RequestType*, ResponseType*)>
-      func_;
+  std::function<Status(ServiceType*, const RequestType*, ResponseType*)> func_;
   // The class the above handler function lives in.
   ServiceType* service_;
 };
 
+// A wrapper class of an application provided client streaming handler.
+template <class ServiceType, class RequestType, class ResponseType>
+class ClientStreamingHandler : public MethodHandler {
+ public:
+  ClientStreamingHandler(
+      std::function<Status(ServiceType*, ServerReader<RequestType>*,
+                           ResponseType*)> func,
+      ServiceType* service)
+      : func_(func), service_(service) {}
+
+  Status RunHandler(const HandlerParameter& param) final {
+    ServerReader<RequestType> reader(param.stream_context);
+    return func_(service_, &reader,
+                 dynamic_cast<ResponseType*>(param.response));
+  }
+
+ private:
+  std::function<Status(ServiceType*, ServerReader<RequestType>*, ResponseType*)>
+      func_;
+  ServiceType* service_;
+};
+
+// A wrapper class of an application provided server streaming handler.
+template <class ServiceType, class RequestType, class ResponseType>
+class ServerStreamingHandler : public MethodHandler {
+ public:
+  ServerStreamingHandler(
+      std::function<Status(ServiceType*, const RequestType*,
+                           ServerWriter<ResponseType>*)> func,
+      ServiceType* service)
+      : func_(func), service_(service) {}
+
+  Status RunHandler(const HandlerParameter& param) final {
+    ServerWriter<ResponseType> writer(param.stream_context);
+    return func_(service_, dynamic_cast<const RequestType*>(param.request),
+                 &writer);
+  }
+
+ private:
+  std::function<Status(ServiceType*, const RequestType*,
+                       ServerWriter<ResponseType>*)> func_;
+  ServiceType* service_;
+};
+
+// A wrapper class of an application provided bidi-streaming handler.
+template <class ServiceType, class RequestType, class ResponseType>
+class BidiStreamingHandler : public MethodHandler {
+ public:
+  BidiStreamingHandler(
+      std::function<Status(
+          ServiceType*, ServerReaderWriter<ResponseType, RequestType>*)> func,
+      ServiceType* service)
+      : func_(func), service_(service) {}
+
+  Status RunHandler(const HandlerParameter& param) final {
+    ServerReaderWriter<ResponseType, RequestType> stream(param.stream_context);
+    return func_(service_, &stream);
+  }
+
+ private:
+  std::function<Status(ServiceType*,
+                       ServerReaderWriter<ResponseType, RequestType>*)> func_;
+  ServiceType* service_;
+};
+
 // Server side rpc method class
 class RpcServiceMethod : public RpcMethod {
  public:
   // Takes ownership of the handler and two prototype objects.
-  RpcServiceMethod(const char* name, MethodHandler* handler,
-                   google::protobuf::Message* request_prototype,
+  RpcServiceMethod(const char* name, RpcMethod::RpcType type,
+                   MethodHandler* handler,
+                   google::protobuf::Message* request_prototype,
                    google::protobuf::Message* response_prototype)
-      : RpcMethod(name),
+      : RpcMethod(name, type),
         handler_(handler),
         request_prototype_(request_prototype),
         response_prototype_(response_prototype) {}
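Taken together, rpc_service_method.h now lets a service be registered with its streaming type plus a matching handler wrapper. A hedged sketch of what registration code could look like; RouteGuideImpl, Point, RouteSummary, and RecordRoute are hypothetical service and message names, not part of this diff:

    // Build the server-side method entry for a client-streaming method.
    // RpcServiceMethod takes ownership of the handler and both prototypes.
    RpcServiceMethod* MakeRecordRouteMethod(RouteGuideImpl* service) {
      return new RpcServiceMethod(
          "RecordRoute", RpcMethod::CLIENT_STREAMING,
          new ClientStreamingHandler<RouteGuideImpl, Point, RouteSummary>(
              [](RouteGuideImpl* s, ServerReader<Point>* reader,
                 RouteSummary* response) {
                return s->RecordRoute(reader, response);
              },
              service),
          new Point, new RouteSummary);
    }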
diff --git a/src/cpp/server/server_credentials.cc b/src/cpp/server/server_credentials.cc
new file mode 100644
index 0000000000..d23a09f3c1
--- /dev/null
+++ b/src/cpp/server/server_credentials.cc
@@ -0,0 +1,62 @@
+/*
+ *
+ * Copyright 2014, Google Inc.
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are
+ * met:
+ *
+ *     * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ *     * Redistributions in binary form must reproduce the above
+ * copyright notice, this list of conditions and the following disclaimer
+ * in the documentation and/or other materials provided with the
+ * distribution.
+ *     * Neither the name of Google Inc. nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ */
+
+
+#include <grpc/grpc_security.h>
+
+#include <grpc++/server_credentials.h>
+
+namespace grpc {
+
+ServerCredentials::ServerCredentials(grpc_server_credentials* c_creds)
+    : creds_(c_creds) {}
+
+ServerCredentials::~ServerCredentials() {
+  grpc_server_credentials_release(creds_);
+}
+
+grpc_server_credentials* ServerCredentials::GetRawCreds() { return creds_; }
+
+std::shared_ptr<ServerCredentials> ServerCredentialsFactory::SslCredentials(
+    const SslServerCredentialsOptions& options) {
+  grpc_server_credentials* c_creds = grpc_ssl_server_credentials_create(
+      reinterpret_cast<const unsigned char*>(options.pem_root_certs.c_str()),
+      options.pem_root_certs.size(),
+      reinterpret_cast<const unsigned char*>(options.pem_private_key.c_str()),
+      options.pem_private_key.size(),
+      reinterpret_cast<const unsigned char*>(options.pem_cert_chain.c_str()),
+      options.pem_cert_chain.size());
+  return std::shared_ptr<ServerCredentials>(new ServerCredentials(c_creds));
+}
+
+}  // namespace grpc
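Using the new factory takes three PEM strings. A sketch assuming the key material is already in memory; ReadFile is a hypothetical helper, not part of this diff:

    std::shared_ptr<ServerCredentials> MakeSslServerCredentials() {
      SslServerCredentialsOptions options;
      // ReadFile is a hypothetical helper that slurps a PEM file into a string.
      // An empty pem_root_certs string is passed through as a zero-length buffer.
      options.pem_root_certs = ReadFile("ca.pem");
      options.pem_private_key = ReadFile("server.key");
      options.pem_cert_chain = ReadFile("server.crt");
      return ServerCredentialsFactory::SslCredentials(options);
    }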
diff --git a/src/cpp/server/server_rpc_handler.cc b/src/cpp/server/server_rpc_handler.cc
index 2d5a081deb..4c8d0cd04e 100644
--- a/src/cpp/server/server_rpc_handler.cc
+++ b/src/cpp/server/server_rpc_handler.cc
@@ -35,20 +35,16 @@
 #include <grpc/support/log.h>
 
 #include "src/cpp/server/rpc_service_method.h"
+#include "src/cpp/stream/stream_context.h"
 #include <grpc++/async_server_context.h>
 
 namespace grpc {
 
 ServerRpcHandler::ServerRpcHandler(AsyncServerContext* server_context,
                                    RpcServiceMethod* method)
-    : server_context_(server_context),
-      method_(method) {
-}
+    : server_context_(server_context), method_(method) {}
 
 void ServerRpcHandler::StartRpc() {
-  // Start the rpc on this dedicated completion queue.
-  server_context_->Accept(cq_.cq());
-
   if (method_ == nullptr) {
     // Method not supported, finish the rpc with error.
     // TODO(rocking): do we need to call read to consume the request?
@@ -56,30 +52,54 @@ void ServerRpcHandler::StartRpc() {
     return;
   }
 
-  // Allocate request and response.
-  std::unique_ptr<google::protobuf::Message> request(method_->AllocateRequestProto());
-  std::unique_ptr<google::protobuf::Message> response(method_->AllocateResponseProto());
-
-  // Read request
-  server_context_->StartRead(request.get());
-  auto type = WaitForNextEvent();
-  GPR_ASSERT(type == CompletionQueue::SERVER_READ_OK);
-
-  // Run the application's rpc handler
-  MethodHandler* handler = method_->handler();
-  Status status = handler->RunHandler(
-      MethodHandler::HandlerParameter(request.get(), response.get()));
-
-  if (status.IsOk()) {
-    // Send the response if we get an ok status.
-    server_context_->StartWrite(*response, 0);
-    type = WaitForNextEvent();
-    if (type != CompletionQueue::SERVER_WRITE_OK) {
-      status = Status(StatusCode::INTERNAL, "Error writing response.");
+  if (method_->method_type() == RpcMethod::NORMAL_RPC) {
+    // Start the rpc on this dedicated completion queue.
+    server_context_->Accept(cq_.cq());
+
+    // Allocate request and response.
+    std::unique_ptr<google::protobuf::Message> request(method_->AllocateRequestProto());
+    std::unique_ptr<google::protobuf::Message> response(method_->AllocateResponseProto());
+
+    // Read request
+    server_context_->StartRead(request.get());
+    auto type = WaitForNextEvent();
+    GPR_ASSERT(type == CompletionQueue::SERVER_READ_OK);
+
+    // Run the application's rpc handler
+    MethodHandler* handler = method_->handler();
+    Status status = handler->RunHandler(
+        MethodHandler::HandlerParameter(request.get(), response.get()));
+
+    if (status.IsOk()) {
+      // Send the response if we get an ok status.
+      server_context_->StartWrite(*response, 0);
+      type = WaitForNextEvent();
+      if (type != CompletionQueue::SERVER_WRITE_OK) {
+        status = Status(StatusCode::INTERNAL, "Error writing response.");
+      }
     }
-  }
 
-  FinishRpc(status);
+    FinishRpc(status);
+  } else {
+    // Allocate request and response.
+    // TODO(yangg) maybe not allocate both when not needed?
+    std::unique_ptr<google::protobuf::Message> request(method_->AllocateRequestProto());
+    std::unique_ptr<google::protobuf::Message> response(method_->AllocateResponseProto());
+
+    StreamContext stream_context(*method_, server_context_->call(), cq_.cq(),
+                                 request.get(), response.get());
+
+    // Run the application's rpc handler
+    MethodHandler* handler = method_->handler();
+    Status status = handler->RunHandler(MethodHandler::HandlerParameter(
+        request.get(), response.get(), &stream_context));
+    if (status.IsOk() &&
+        method_->method_type() == RpcMethod::CLIENT_STREAMING) {
+      stream_context.Write(response.get(), false);
+    }
+    // TODO(yangg) Do we need to consider the status in stream_context?
+    FinishRpc(status);
+  }
 }
 
 CompletionQueue::CompletionType ServerRpcHandler::WaitForNextEvent() {
@@ -94,11 +114,15 @@ CompletionQueue::CompletionType ServerRpcHandler::WaitForNextEvent() {
 
 void ServerRpcHandler::FinishRpc(const Status& status) {
   server_context_->StartWriteStatus(status);
-  CompletionQueue::CompletionType type = WaitForNextEvent();
-  // TODO(rocking): do we care about this return type?
+  CompletionQueue::CompletionType type;
+  // HALFCLOSE_OK and RPC_END events come in either order.
+  type = WaitForNextEvent();
+  GPR_ASSERT(type == CompletionQueue::HALFCLOSE_OK ||
+             type == CompletionQueue::RPC_END);
   type = WaitForNextEvent();
-  GPR_ASSERT(type == CompletionQueue::RPC_END);
+  GPR_ASSERT(type == CompletionQueue::HALFCLOSE_OK ||
+             type == CompletionQueue::RPC_END);
 
   cq_.Shutdown();
   type = WaitForNextEvent();
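In the streaming branch above, RunHandler receives the StreamContext through HandlerParameter, and the application sees only the ServerReader/ServerWriter wrappers. A sketch of the handler body this dispatch expects, continuing the hypothetical RecordRoute example from rpc_service_method.h (set_point_count is an assumed generated setter; Status::OK follows the static-member naming that Status::Cancelled uses in stream_context.cc):

    // Client-streaming handler: drain the reader, then fill in the single
    // response, which ServerRpcHandler writes once RunHandler returns Ok.
    Status RouteGuideImpl::RecordRoute(ServerReader<Point>* reader,
                                       RouteSummary* summary) {
      Point point;
      int point_count = 0;
      while (reader->Read(&point)) {  // returns false once the client half-closes
        ++point_count;
      }
      summary->set_point_count(point_count);
      return Status::OK;
    }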
diff --git a/src/cpp/stream/stream_context.cc b/src/cpp/stream/stream_context.cc
index 07e771f7e1..706e90c481 100644
--- a/src/cpp/stream/stream_context.cc
+++ b/src/cpp/stream/stream_context.cc
@@ -33,7 +33,6 @@
 
 #include "src/cpp/stream/stream_context.h"
 
-#include <grpc/grpc.h>
 #include <grpc/support/log.h>
 #include "src/cpp/rpc_method.h"
 #include "src/cpp/proto/proto_utils.h"
@@ -50,227 +49,146 @@ StreamContext::StreamContext(const RpcMethod& method, ClientContext* context,
                              google::protobuf::Message* result)
     : is_client_(true),
       method_(&method),
-      context_(context),
-      request_(request),
+      call_(context->call()),
+      cq_(context->cq()),
+      request_(const_cast<google::protobuf::Message*>(request)),
       result_(result),
-      invoke_ev_(nullptr),
-      read_ev_(nullptr),
-      write_ev_(nullptr),
-      reading_(false),
-      writing_(false),
-      got_read_(false),
-      got_write_(false),
       peer_halfclosed_(false),
-      self_halfclosed_(false),
-      stream_finished_(false),
-      waiting_(false) {
+      self_halfclosed_(false) {
   GPR_ASSERT(method_->method_type() != RpcMethod::RpcType::NORMAL_RPC);
 }
 
-StreamContext::~StreamContext() { cq_poller_.join(); }
-
-void StreamContext::PollingLoop() {
-  grpc_event* ev = nullptr;
-  gpr_timespec absolute_deadline;
-  AbsoluteDeadlineTimepoint2Timespec(context_->absolute_deadline(),
-                                     &absolute_deadline);
-  std::condition_variable* cv_to_notify = nullptr;
-  std::unique_lock<std::mutex> lock(mu_, std::defer_lock);
-  while (1) {
-    cv_to_notify = nullptr;
-    lock.lock();
-    if (stream_finished_ && !reading_ && !writing_) {
-      return;
-    }
-    lock.unlock();
-    ev = grpc_completion_queue_next(context_->cq(), absolute_deadline);
-    lock.lock();
-    if (!ev) {
-      stream_finished_ = true;
-      final_status_ = Status(StatusCode::DEADLINE_EXCEEDED);
-      std::condition_variable* cvs[3] = {reading_ ? &read_cv_ : nullptr,
-                                         writing_ ? &write_cv_ : nullptr,
-                                         waiting_ ? &finish_cv_ : nullptr};
-      got_read_ = reading_;
-      got_write_ = writing_;
-      read_ev_ = nullptr;
-      write_ev_ = nullptr;
-      lock.unlock();
-      for (int i = 0; i < 3; i++) {
-        if (cvs[i]) cvs[i]->notify_one();
-      }
-      break;
-    }
-    switch (ev->type) {
-      case GRPC_READ:
-        GPR_ASSERT(reading_);
-        got_read_ = true;
-        read_ev_ = ev;
-        cv_to_notify = &read_cv_;
-        reading_ = false;
-        break;
-      case GRPC_FINISH_ACCEPTED:
-      case GRPC_WRITE_ACCEPTED:
-        got_write_ = true;
-        write_ev_ = ev;
-        cv_to_notify = &write_cv_;
-        writing_ = false;
-        break;
-      case GRPC_FINISHED: {
-        grpc::string error_details(
-            ev->data.finished.details ? ev->data.finished.details : "");
-        final_status_ = Status(static_cast<StatusCode>(ev->data.finished.code),
-                               error_details);
-        grpc_event_finish(ev);
-        stream_finished_ = true;
-        if (waiting_) {
-          cv_to_notify = &finish_cv_;
-        }
-        break;
-      }
-      case GRPC_INVOKE_ACCEPTED:
-        invoke_ev_ = ev;
-        cv_to_notify = &invoke_cv_;
-        break;
-      case GRPC_CLIENT_METADATA_READ:
-        grpc_event_finish(ev);
-        break;
-      default:
-        grpc_event_finish(ev);
-        // not handling other types now
-        gpr_log(GPR_ERROR, "unknown event type");
-        abort();
-    }
-    lock.unlock();
-    if (cv_to_notify) {
-      cv_to_notify->notify_one();
-    }
-  }
+// Server only ctor
+StreamContext::StreamContext(const RpcMethod& method, grpc_call* call,
+                             grpc_completion_queue* cq,
+                             google::protobuf::Message* request,
+                             google::protobuf::Message* result)
+    : is_client_(false),
+      method_(&method),
+      call_(call),
+      cq_(cq),
+      request_(request),
+      result_(result),
+      peer_halfclosed_(false),
+      self_halfclosed_(false) {
+  GPR_ASSERT(method_->method_type() != RpcMethod::RpcType::NORMAL_RPC);
 }
 
-void StreamContext::Start(bool buffered) {
-  // TODO(yangg) handle metadata send path
-  int flag = buffered ? GRPC_WRITE_BUFFER_HINT : 0;
-  grpc_call_error error = grpc_call_start_invoke(
-      context_->call(), context_->cq(), this, this, this, flag);
-  GPR_ASSERT(GRPC_CALL_OK == error);
-  // kicks off the poller thread
-  cq_poller_ = std::thread(&StreamContext::PollingLoop, this);
-  std::unique_lock<std::mutex> lock(mu_);
-  while (!invoke_ev_) {
-    invoke_cv_.wait(lock);
-  }
-  lock.unlock();
-  GPR_ASSERT(invoke_ev_->data.invoke_accepted == GRPC_OP_OK);
-  grpc_event_finish(invoke_ev_);
-}
+StreamContext::~StreamContext() {}
 
-namespace {
-// Wait for got_event with event_cv protected by mu, return event.
-grpc_event* WaitForEvent(bool* got_event, std::condition_variable* event_cv,
-                         std::mutex* mu, grpc_event** event) {
-  std::unique_lock<std::mutex> lock(*mu);
-  while (*got_event == false) {
-    event_cv->wait(lock);
+void StreamContext::Start(bool buffered) {
+  if (is_client_) {
+    // TODO(yangg) handle metadata send path
+    int flag = buffered ? GRPC_WRITE_BUFFER_HINT : 0;
+    grpc_call_error error = grpc_call_start_invoke(call(), cq(), invoke_tag(),
+                                                   client_metadata_read_tag(),
+                                                   finished_tag(), flag);
+    GPR_ASSERT(GRPC_CALL_OK == error);
+    grpc_event* invoke_ev =
+        grpc_completion_queue_pluck(cq(), invoke_tag(), gpr_inf_future);
+    grpc_event_finish(invoke_ev);
+  } else {
+    // TODO(yangg) metadata needs to be added before accept
+    // TODO(yangg) correctly set flag to accept
+    grpc_call_error error = grpc_call_accept(call(), cq(), finished_tag(), 0);
+    GPR_ASSERT(GRPC_CALL_OK == error);
   }
-  *got_event = false;
-  return *event;
 }
-}  // namespace
 
 bool StreamContext::Read(google::protobuf::Message* msg) {
-  std::unique_lock<std::mutex> lock(mu_);
-  if (stream_finished_) {
-    peer_halfclosed_ = true;
-    return false;
-  }
-  reading_ = true;
-  lock.unlock();
-
-  grpc_call_error err = grpc_call_start_read(context_->call(), this);
+  // TODO(yangg) check peer_halfclosed_ here for possible early return.
+  grpc_call_error err = grpc_call_start_read(call(), read_tag());
   GPR_ASSERT(err == GRPC_CALL_OK);
-
-  grpc_event* ev = WaitForEvent(&got_read_, &read_cv_, &mu_, &read_ev_);
-  if (!ev) {
-    return false;
-  }
-  GPR_ASSERT(ev->type == GRPC_READ);
+  grpc_event* read_ev =
+      grpc_completion_queue_pluck(cq(), read_tag(), gpr_inf_future);
+  GPR_ASSERT(read_ev->type == GRPC_READ);
   bool ret = true;
-  if (ev->data.read) {
-    if (!DeserializeProto(ev->data.read, msg)) {
-      ret = false;  // parse error
-      // TODO(yangg) cancel the stream.
+  if (read_ev->data.read) {
+    if (!DeserializeProto(read_ev->data.read, msg)) {
+      ret = false;
+      FinishStream(
+          Status(StatusCode::DATA_LOSS, "Failed to parse incoming proto"),
+          true);
    }
   } else {
     ret = false;
     peer_halfclosed_ = true;
   }
-  grpc_event_finish(ev);
+  grpc_event_finish(read_ev);
   return ret;
 }
 
 bool StreamContext::Write(const google::protobuf::Message* msg, bool is_last) {
+  // TODO(yangg) check self_halfclosed_ for possible early return.
   bool ret = true;
   grpc_event* ev = nullptr;
 
-  std::unique_lock<std::mutex> lock(mu_);
-  if (stream_finished_) {
-    self_halfclosed_ = true;
-    return false;
-  }
-  writing_ = true;
-  lock.unlock();
-
   if (msg) {
     grpc_byte_buffer* out_buf = nullptr;
     if (!SerializeProto(*msg, &out_buf)) {
       FinishStream(Status(StatusCode::INVALID_ARGUMENT,
-                          "Failed to serialize request proto"),
+                          "Failed to serialize outgoing proto"),
                    true);
       return false;
     }
     int flag = is_last ? GRPC_WRITE_BUFFER_HINT : 0;
     grpc_call_error err =
-        grpc_call_start_write(context_->call(), out_buf, this, flag);
+        grpc_call_start_write(call(), out_buf, write_tag(), flag);
     grpc_byte_buffer_destroy(out_buf);
     GPR_ASSERT(err == GRPC_CALL_OK);
-    ev = WaitForEvent(&got_write_, &write_cv_, &mu_, &write_ev_);
-    if (!ev) {
-      return false;
-    }
+    ev = grpc_completion_queue_pluck(cq(), write_tag(), gpr_inf_future);
     GPR_ASSERT(ev->type == GRPC_WRITE_ACCEPTED);
     ret = ev->data.write_accepted == GRPC_OP_OK;
     grpc_event_finish(ev);
   }
 
-  if (is_last) {
-    grpc_call_error err = grpc_call_writes_done(context_->call(), this);
+  if (ret && is_last) {
+    grpc_call_error err = grpc_call_writes_done(call(), halfclose_tag());
     GPR_ASSERT(err == GRPC_CALL_OK);
-    ev = WaitForEvent(&got_write_, &write_cv_, &mu_, &write_ev_);
-    if (!ev) {
-      return false;
-    }
+    ev = grpc_completion_queue_pluck(cq(), halfclose_tag(), gpr_inf_future);
     GPR_ASSERT(ev->type == GRPC_FINISH_ACCEPTED);
     grpc_event_finish(ev);
+    self_halfclosed_ = true;
+  } else if (!ret) {  // Stream broken
+    self_halfclosed_ = true;
+    peer_halfclosed_ = true;
   }
+
   return ret;
 }
 
 const Status& StreamContext::Wait() {
-  std::unique_lock<std::mutex> lock(mu_);
-  // TODO(yangg) if not halfclosed cancel the stream
-  GPR_ASSERT(self_halfclosed_);
-  GPR_ASSERT(peer_halfclosed_);
-  GPR_ASSERT(!waiting_);
-  waiting_ = true;
-  while (!stream_finished_) {
-    finish_cv_.wait(lock);
+  // TODO(yangg) properly support metadata
+  grpc_event* metadata_ev = grpc_completion_queue_pluck(
+      cq(), client_metadata_read_tag(), gpr_inf_future);
+  grpc_event_finish(metadata_ev);
+  // TODO(yangg) protect states by a mutex, including other places.
+  if (!self_halfclosed_ || !peer_halfclosed_) {
+    FinishStream(Status::Cancelled, true);
+  } else {
+    grpc_event* finish_ev =
+        grpc_completion_queue_pluck(cq(), finished_tag(), gpr_inf_future);
+    GPR_ASSERT(finish_ev->type == GRPC_FINISHED);
+    std::string error_details(finish_ev->data.finished.details
+                                  ? finish_ev->data.finished.details
+                                  : "");
+    final_status_ = Status(
+        static_cast<StatusCode>(finish_ev->data.finished.code), error_details);
+    grpc_event_finish(finish_ev);
   }
   return final_status_;
 }
 
-void StreamContext::FinishStream(const Status& status, bool send) { return; }
+void StreamContext::FinishStream(const Status& status, bool send) {
+  if (send) {
+    grpc_call_cancel(call());
+  }
+  grpc_event* finish_ev =
+      grpc_completion_queue_pluck(cq(), finished_tag(), gpr_inf_future);
+  GPR_ASSERT(finish_ev->type == GRPC_FINISHED);
+  grpc_event_finish(finish_ev);
+  final_status_ = status;
+}
 
 }  // namespace grpc
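On the client side, the intended sequence over the rewritten StreamContext is Start(), a series of Write() calls with is_last true on the final one to half-close, Read() until it returns false, then Wait() for the trailing status. Note that Wait() cancels the call when either side has not half-closed, so a driver has to exhaust the reads first. A hedged sketch; RunClientStream and the message types are illustrative names only:

    // Drive a hypothetical client-streaming call: N requests, one response.
    Status RunClientStream(StreamContext* ctx, const std::vector<Point>& points,
                           RouteSummary* summary) {
      ctx->Start(false);  // plucks the invoke-accepted event
      for (const Point& p : points) {
        if (!ctx->Write(&p, false)) break;  // broken stream; Wait() reports it
      }
      ctx->Write(nullptr, true);  // writes done: sets self_halfclosed_
      while (ctx->Read(summary)) {
        // The trailing Read() returns false and sets peer_halfclosed_.
      }
      return ctx->Wait();  // plucks the metadata, then the finished event
    }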
diff --git a/src/cpp/stream/stream_context.h b/src/cpp/stream/stream_context.h
index b7f462f323..6c31095042 100644
--- a/src/cpp/stream/stream_context.h
+++ b/src/cpp/stream/stream_context.h
@@ -34,10 +34,7 @@
 #ifndef __GRPCPP_INTERNAL_STREAM_STREAM_CONTEXT_H__
 #define __GRPCPP_INTERNAL_STREAM_STREAM_CONTEXT_H__
 
-#include <condition_variable>
-#include <mutex>
-#include <thread>
-
+#include <grpc/grpc.h>
 #include <grpc++/status.h>
 #include <grpc++/stream_context_interface.h>
 
@@ -47,8 +44,6 @@ class Message;
 }
 }
 
-struct grpc_event;
-
 namespace grpc {
 class ClientContext;
 class RpcMethod;
@@ -57,6 +52,9 @@ class StreamContext : public StreamContextInterface {
  public:
   StreamContext(const RpcMethod& method, ClientContext* context,
                 const google::protobuf::Message* request,
                 google::protobuf::Message* result);
+  StreamContext(const RpcMethod& method, grpc_call* call,
+                grpc_completion_queue* cq, google::protobuf::Message* request,
+                google::protobuf::Message* result);
   ~StreamContext();
   // Start the stream, if there is a final write following immediately, set
   // buffered so that the messages can be sent in batch.
@@ -66,37 +64,31 @@ class StreamContext : public StreamContextInterface {
   const Status& Wait() override;
   void FinishStream(const Status& status, bool send) override;
 
-  const google::protobuf::Message* request() override { return request_; }
+  google::protobuf::Message* request() override { return request_; }
 
   google::protobuf::Message* response() override { return result_; }
 
 private:
-  void PollingLoop();
-  bool BlockingStart();
+  // Unique tags for plucking events from the C layer. The this pointer is
+  // cast to char* to create single-byte steps between tags. This implicitly
+  // relies on StreamContext being large enough to contain all the tag offsets.
+  void* finished_tag() { return reinterpret_cast<char*>(this); }
+  void* read_tag() { return reinterpret_cast<char*>(this) + 1; }
+  void* write_tag() { return reinterpret_cast<char*>(this) + 2; }
+  void* halfclose_tag() { return reinterpret_cast<char*>(this) + 3; }
+  void* invoke_tag() { return reinterpret_cast<char*>(this) + 4; }
+  void* client_metadata_read_tag() { return reinterpret_cast<char*>(this) + 5; }
+  grpc_call* call() { return call_; }
+  grpc_completion_queue* cq() { return cq_; }
+
   bool is_client_;
   const RpcMethod* method_;  // not owned
-  ClientContext* context_;   // now owned
-  const google::protobuf::Message* request_;  // not owned
-  google::protobuf::Message* result_;         // not owned
+  grpc_call* call_;            // not owned
+  grpc_completion_queue* cq_;  // not owned
+  google::protobuf::Message* request_;  // first request, not owned
+  google::protobuf::Message* result_;   // last response, not owned
 
-  std::thread cq_poller_;
-  std::mutex mu_;
-  std::condition_variable invoke_cv_;
-  std::condition_variable read_cv_;
-  std::condition_variable write_cv_;
-  std::condition_variable finish_cv_;
-  grpc_event* invoke_ev_;
-  // TODO(yangg) make these two into queues to support concurrent reads and
-  // writes
-  grpc_event* read_ev_;
-  grpc_event* write_ev_;
-  bool reading_;
-  bool writing_;
-  bool got_read_;
-  bool got_write_;
   bool peer_halfclosed_;
   bool self_halfclosed_;
-  bool stream_finished_;
-  bool waiting_;
 
   Status final_status_;
 };
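The comment on the tag helpers leans on sizeof(StreamContext) being at least six bytes, so that tags carved out of one object can never alias another object's tags. If that invariant ever needed enforcing, a compile-time guard would make it explicit; this static_assert is illustrative only, not part of the patch:

    // Guard the pointer-offset tag scheme: six tags are carved out of the
    // object's own footprint, so the object must span at least six bytes.
    static_assert(sizeof(StreamContext) >= 6,
                  "pointer-offset tags would alias adjacent objects");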