Diffstat (limited to 'src')
22 files changed, 653 insertions, 124 deletions
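Note on the completion-queue changes below: the pluck queue's gpr_refcount is replaced by a plain gpr_atm counter, and both cq_begin_op paths are gated through a new increment-if-nonzero helper (atm_inc_if_nonzero). The following is a minimal sketch of that pattern only, using standard C11 atomics rather than gpr_atm, with the names inc_if_nonzero and counter chosen here for illustration:

    #include <stdatomic.h>
    #include <stdbool.h>

    /* Sketch of the increment-if-nonzero idiom: reserve a pending event
     * unless the counter has already dropped to zero (shutdown finished). */
    static bool inc_if_nonzero(atomic_long *counter) {
      long count = atomic_load_explicit(counter, memory_order_relaxed);
      while (count != 0) {
        /* A CAS (rather than a blind fetch_add) preserves the contract:
         * never move the counter away from zero. On CAS failure `count`
         * is reloaded with the current value and we retry. */
        if (atomic_compare_exchange_weak_explicit(counter, &count, count + 1,
                                                  memory_order_relaxed,
                                                  memory_order_relaxed)) {
          return true;
        }
      }
      return false;
    }

Once begin-op is expressed this way, shutdown can simply drop the initial count with a fetch_add(-1) and finish when the counter reaches zero, which is exactly what cq_shutdown_next and cq_shutdown_pluck do in the diff.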
diff --git a/src/core/ext/census/tracing.c b/src/core/ext/census/tracing.c index 543a73c5ad..823c681abf 100644 --- a/src/core/ext/census/tracing.c +++ b/src/core/ext/census/tracing.c @@ -21,7 +21,6 @@ #include <grpc/census.h> #include <grpc/support/alloc.h> #include <grpc/support/log.h> -#include <openssl/rand.h> #include "src/core/ext/census/mlog.h" void trace_start_span(const trace_span_context *span_ctxt, diff --git a/src/core/ext/transport/chttp2/transport/chttp2_transport.c b/src/core/ext/transport/chttp2/transport/chttp2_transport.c index aabe7b4a8e..8976686082 100644 --- a/src/core/ext/transport/chttp2/transport/chttp2_transport.c +++ b/src/core/ext/transport/chttp2/transport/chttp2_transport.c @@ -1788,9 +1788,8 @@ void grpc_chttp2_maybe_complete_recv_trailing_metadata(grpc_exec_ctx *exec_ctx, bool pending_data = s->pending_byte_stream || s->unprocessed_incoming_frames_buffer.length > 0; if (s->stream_compression_recv_enabled && s->read_closed && - s->frame_storage.length > 0 && - s->unprocessed_incoming_frames_buffer.length == 0 && !pending_data && - !s->seen_error && s->recv_trailing_metadata_finished != NULL) { + s->frame_storage.length > 0 && !pending_data && !s->seen_error && + s->recv_trailing_metadata_finished != NULL) { /* Maybe some SYNC_FLUSH data is left in frame_storage. Consume them and * maybe decompress the next 5 bytes in the stream. */ bool end_of_context; @@ -1817,7 +1816,6 @@ void grpc_chttp2_maybe_complete_recv_trailing_metadata(grpc_exec_ctx *exec_ctx, } } if (s->read_closed && s->frame_storage.length == 0 && - s->unprocessed_incoming_frames_buffer.length == 0 && (!pending_data || s->seen_error) && s->recv_trailing_metadata_finished != NULL) { grpc_chttp2_incoming_metadata_buffer_publish( diff --git a/src/core/lib/iomgr/exec_ctx.c b/src/core/lib/iomgr/exec_ctx.c index 833170ceed..41c69add17 100644 --- a/src/core/lib/iomgr/exec_ctx.c +++ b/src/core/lib/iomgr/exec_ctx.c @@ -51,33 +51,6 @@ bool grpc_exec_ctx_has_work(grpc_exec_ctx *exec_ctx) { !grpc_closure_list_empty(exec_ctx->closure_list); } -bool grpc_exec_ctx_flush(grpc_exec_ctx *exec_ctx) { - bool did_something = 0; - GPR_TIMER_BEGIN("grpc_exec_ctx_flush", 0); - for (;;) { - if (!grpc_closure_list_empty(exec_ctx->closure_list)) { - grpc_closure *c = exec_ctx->closure_list.head; - exec_ctx->closure_list.head = exec_ctx->closure_list.tail = NULL; - while (c != NULL) { - grpc_closure *next = c->next_data.next; - grpc_error *error = c->error_data.error; - did_something = true; -#ifndef NDEBUG - c->scheduled = false; -#endif - c->cb(exec_ctx, c->cb_arg, error); - GRPC_ERROR_UNREF(error); - c = next; - } - } else if (!grpc_combiner_continue_exec_ctx(exec_ctx)) { - break; - } - } - GPR_ASSERT(exec_ctx->active_combiner == NULL); - GPR_TIMER_END("grpc_exec_ctx_flush", 0); - return did_something; -} - void grpc_exec_ctx_finish(grpc_exec_ctx *exec_ctx) { exec_ctx->flags |= GRPC_EXEC_CTX_FLAG_IS_FINISHED; grpc_exec_ctx_flush(exec_ctx); @@ -103,6 +76,29 @@ static void exec_ctx_run(grpc_exec_ctx *exec_ctx, grpc_closure *closure, GRPC_ERROR_UNREF(error); } +bool grpc_exec_ctx_flush(grpc_exec_ctx *exec_ctx) { + bool did_something = 0; + GPR_TIMER_BEGIN("grpc_exec_ctx_flush", 0); + for (;;) { + if (!grpc_closure_list_empty(exec_ctx->closure_list)) { + grpc_closure *c = exec_ctx->closure_list.head; + exec_ctx->closure_list.head = exec_ctx->closure_list.tail = NULL; + while (c != NULL) { + grpc_closure *next = c->next_data.next; + grpc_error *error = c->error_data.error; + did_something = true; + exec_ctx_run(exec_ctx, c, 
error); + c = next; + } + } else if (!grpc_combiner_continue_exec_ctx(exec_ctx)) { + break; + } + } + GPR_ASSERT(exec_ctx->active_combiner == NULL); + GPR_TIMER_END("grpc_exec_ctx_flush", 0); + return did_something; +} + static void exec_ctx_sched(grpc_exec_ctx *exec_ctx, grpc_closure *closure, grpc_error *error) { grpc_closure_list_append(&exec_ctx->closure_list, closure, error); diff --git a/src/core/lib/security/transport/security_handshaker.c b/src/core/lib/security/transport/security_handshaker.c index b9da6e16b2..fc9c9f980f 100644 --- a/src/core/lib/security/transport/security_handshaker.c +++ b/src/core/lib/security/transport/security_handshaker.c @@ -261,7 +261,7 @@ static grpc_error *do_handshaker_next_locked( grpc_exec_ctx *exec_ctx, security_handshaker *h, const unsigned char *bytes_received, size_t bytes_received_size) { // Invoke TSI handshaker. - unsigned char *bytes_to_send = NULL; + const unsigned char *bytes_to_send = NULL; size_t bytes_to_send_size = 0; tsi_handshaker_result *handshaker_result = NULL; tsi_result result = tsi_handshaker_next( diff --git a/src/core/lib/surface/completion_queue.c b/src/core/lib/surface/completion_queue.c index 3d82a32e82..c20cfbc740 100644 --- a/src/core/lib/surface/completion_queue.c +++ b/src/core/lib/surface/completion_queue.c @@ -235,7 +235,8 @@ typedef struct cq_next_data { /* Number of outstanding events (+1 if not shut down) */ gpr_atm pending_events; - int shutdown_called; + /** 0 initially. 1 once we initiated shutdown */ + bool shutdown_called; } cq_next_data; typedef struct cq_pluck_data { @@ -244,15 +245,20 @@ typedef struct cq_pluck_data { grpc_cq_completion *completed_tail; /** Number of pending events (+1 if we're not shutdown) */ - gpr_refcount pending_events; + gpr_atm pending_events; /** Counter of how many things have ever been queued on this completion queue useful for avoiding locks to check the queue */ gpr_atm things_queued_ever; - /** 0 initially, 1 once we've begun shutting down */ + /** 0 initially. 1 once we completed shutting */ + /* TODO: (sreek) This is not needed since (shutdown == 1) if and only if + * (pending_events == 0). So consider removing this in future and use + * pending_events */ gpr_atm shutdown; - int shutdown_called; + + /** 0 initially. 
1 once we initiated shutdown */ + bool shutdown_called; int num_pluckers; plucker pluckers[GRPC_MAX_COMPLETION_QUEUE_PLUCKERS]; @@ -436,7 +442,7 @@ grpc_completion_queue *grpc_completion_queue_create_internal( static void cq_init_next(void *ptr) { cq_next_data *cqd = ptr; - /* Initial ref is dropped by grpc_completion_queue_shutdown */ + /* Initial count is dropped by grpc_completion_queue_shutdown */ gpr_atm_no_barrier_store(&cqd->pending_events, 1); cqd->shutdown_called = false; gpr_atm_no_barrier_store(&cqd->things_queued_ever, 0); @@ -451,12 +457,12 @@ static void cq_destroy_next(void *ptr) { static void cq_init_pluck(void *ptr) { cq_pluck_data *cqd = ptr; - /* Initial ref is dropped by grpc_completion_queue_shutdown */ - gpr_ref_init(&cqd->pending_events, 1); + /* Initial count is dropped by grpc_completion_queue_shutdown */ + gpr_atm_no_barrier_store(&cqd->pending_events, 1); cqd->completed_tail = &cqd->completed_head; cqd->completed_head.next = (uintptr_t)cqd->completed_tail; gpr_atm_no_barrier_store(&cqd->shutdown, 0); - cqd->shutdown_called = 0; + cqd->shutdown_called = false; cqd->num_pluckers = 0; gpr_atm_no_barrier_store(&cqd->things_queued_ever, 0); } @@ -549,24 +555,32 @@ static void cq_check_tag(grpc_completion_queue *cq, void *tag, bool lock_cq) { static void cq_check_tag(grpc_completion_queue *cq, void *tag, bool lock_cq) {} #endif -static bool cq_begin_op_for_next(grpc_completion_queue *cq, void *tag) { - cq_next_data *cqd = DATA_FROM_CQ(cq); +/* Atomically increments a counter only if the counter is not zero. Returns + * true if the increment was successful; false if the counter is zero */ +static bool atm_inc_if_nonzero(gpr_atm *counter) { while (true) { - gpr_atm count = gpr_atm_no_barrier_load(&cqd->pending_events); + gpr_atm count = gpr_atm_no_barrier_load(counter); + /* If zero, we are done. If not, we must do a CAS (instead of an atomic + * increment) to maintain the contract: do not increment the counter if it + * is zero.
*/ if (count == 0) { return false; - } else if (gpr_atm_no_barrier_cas(&cqd->pending_events, count, count + 1)) { + } else if (gpr_atm_no_barrier_cas(counter, count, count + 1)) { break; } } + return true; } +static bool cq_begin_op_for_next(grpc_completion_queue *cq, void *tag) { + cq_next_data *cqd = DATA_FROM_CQ(cq); + return atm_inc_if_nonzero(&cqd->pending_events); +} + static bool cq_begin_op_for_pluck(grpc_completion_queue *cq, void *tag) { cq_pluck_data *cqd = DATA_FROM_CQ(cq); - GPR_ASSERT(!cqd->shutdown_called); - gpr_ref(&cqd->pending_events); - return true; + return atm_inc_if_nonzero(&cqd->pending_events); } bool grpc_cq_begin_op(grpc_completion_queue *cq, void *tag) { @@ -704,8 +718,10 @@ static void cq_end_op_for_pluck(grpc_exec_ctx *exec_ctx, ((uintptr_t)storage) | (1u & (uintptr_t)cqd->completed_tail->next); cqd->completed_tail = storage; - int shutdown = gpr_unref(&cqd->pending_events); - if (!shutdown) { + if (gpr_atm_full_fetch_add(&cqd->pending_events, -1) == 1) { + cq_finish_shutdown_pluck(exec_ctx, cq); + gpr_mu_unlock(cq->mu); + } else { grpc_pollset_worker *pluck_worker = NULL; for (int i = 0; i < cqd->num_pluckers; i++) { if (cqd->pluckers[i].tag == tag) { @@ -725,9 +741,6 @@ static void cq_end_op_for_pluck(grpc_exec_ctx *exec_ctx, GRPC_ERROR_UNREF(kick_error); } - } else { - cq_finish_shutdown_pluck(exec_ctx, cq); - gpr_mu_unlock(cq->mu); } GPR_TIMER_END("cq_end_op_for_pluck", 0); @@ -952,6 +965,12 @@ static void cq_shutdown_next(grpc_exec_ctx *exec_ctx, grpc_completion_queue *cq) { cq_next_data *cqd = DATA_FROM_CQ(cq); + /* Need an extra ref for cq here because: + * We call cq_finish_shutdown_next() below, that would call pollset shutdown. + * Pollset shutdown decrements the cq ref count which can potentially destroy + * the cq (if that happens to be the last ref). + * Creating an extra ref here prevents the cq from getting destroyed while + * this function is still active */ GRPC_CQ_INTERNAL_REF(cq, "shutting_down"); gpr_mu_lock(cq->mu); if (cqd->shutdown_called) { @@ -960,7 +979,7 @@ static void cq_shutdown_next(grpc_exec_ctx *exec_ctx, GPR_TIMER_END("grpc_completion_queue_shutdown", 0); return; } - cqd->shutdown_called = 1; + cqd->shutdown_called = true; if (gpr_atm_full_fetch_add(&cqd->pending_events, -1) == 1) { cq_finish_shutdown_next(exec_ctx, cq); } @@ -1172,21 +1191,32 @@ static void cq_finish_shutdown_pluck(grpc_exec_ctx *exec_ctx, &cq->pollset_shutdown_done); } +/* NOTE: This function is almost exactly identical to cq_shutdown_next() but + * merging them is a bit tricky and probably not worth it */ static void cq_shutdown_pluck(grpc_exec_ctx *exec_ctx, grpc_completion_queue *cq) { cq_pluck_data *cqd = DATA_FROM_CQ(cq); + /* Need an extra ref for cq here because: + * We call cq_finish_shutdown_pluck() below, that would call pollset shutdown. + * Pollset shutdown decrements the cq ref count which can potentially destroy + * the cq (if that happens to be the last ref). 
+ * Creating an extra ref here prevents the cq from getting destroyed while + * this function is still active */ + GRPC_CQ_INTERNAL_REF(cq, "shutting_down (pluck cq)"); gpr_mu_lock(cq->mu); if (cqd->shutdown_called) { gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(exec_ctx, cq, "shutting_down (pluck cq)"); GPR_TIMER_END("grpc_completion_queue_shutdown", 0); return; } - cqd->shutdown_called = 1; - if (gpr_unref(&cqd->pending_events)) { + cqd->shutdown_called = true; + if (gpr_atm_full_fetch_add(&cqd->pending_events, -1) == 1) { cq_finish_shutdown_pluck(exec_ctx, cq); } gpr_mu_unlock(cq->mu); + GRPC_CQ_INTERNAL_UNREF(exec_ctx, cq, "shutting_down (pluck cq)"); } /* Shutdown simply drops a ref that we reserved at creation time; if we drop diff --git a/src/core/tsi/fake_transport_security.c b/src/core/tsi/fake_transport_security.c index 810447313c..967126ecee 100644 --- a/src/core/tsi/fake_transport_security.c +++ b/src/core/tsi/fake_transport_security.c @@ -407,8 +407,10 @@ static void fake_handshaker_result_destroy(tsi_handshaker_result *self) { static const tsi_handshaker_result_vtable handshaker_result_vtable = { fake_handshaker_result_extract_peer, + NULL, /* create_zero_copy_grpc_protector */ fake_handshaker_result_create_frame_protector, - fake_handshaker_result_get_unused_bytes, fake_handshaker_result_destroy, + fake_handshaker_result_get_unused_bytes, + fake_handshaker_result_destroy, }; static tsi_result fake_handshaker_result_create( @@ -530,7 +532,7 @@ static void fake_handshaker_destroy(tsi_handshaker *self) { static tsi_result fake_handshaker_next( tsi_handshaker *self, const unsigned char *received_bytes, - size_t received_bytes_size, unsigned char **bytes_to_send, + size_t received_bytes_size, const unsigned char **bytes_to_send, size_t *bytes_to_send_size, tsi_handshaker_result **handshaker_result, tsi_handshaker_on_next_done_cb cb, void *user_data) { /* Sanity check the arguments. 
*/ diff --git a/src/core/tsi/transport_security.c b/src/core/tsi/transport_security.c index 2b1f4310c1..76213072a3 100644 --- a/src/core/tsi/transport_security.c +++ b/src/core/tsi/transport_security.c @@ -74,14 +74,12 @@ tsi_result tsi_frame_protector_protect(tsi_frame_protector *self, size_t *unprotected_bytes_size, unsigned char *protected_output_frames, size_t *protected_output_frames_size) { - if (self == NULL || unprotected_bytes == NULL || + if (self == NULL || self->vtable == NULL || unprotected_bytes == NULL || unprotected_bytes_size == NULL || protected_output_frames == NULL || protected_output_frames_size == NULL) { return TSI_INVALID_ARGUMENT; } - if (self->vtable == NULL || self->vtable->protect == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->protect == NULL) return TSI_UNIMPLEMENTED; return self->vtable->protect(self, unprotected_bytes, unprotected_bytes_size, protected_output_frames, protected_output_frames_size); @@ -90,13 +88,11 @@ tsi_result tsi_frame_protector_protect(tsi_frame_protector *self, tsi_result tsi_frame_protector_protect_flush( tsi_frame_protector *self, unsigned char *protected_output_frames, size_t *protected_output_frames_size, size_t *still_pending_size) { - if (self == NULL || protected_output_frames == NULL || + if (self == NULL || self->vtable == NULL || protected_output_frames == NULL || protected_output_frames_size == NULL || still_pending_size == NULL) { return TSI_INVALID_ARGUMENT; } - if (self->vtable == NULL || self->vtable->protect_flush == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->protect_flush == NULL) return TSI_UNIMPLEMENTED; return self->vtable->protect_flush(self, protected_output_frames, protected_output_frames_size, still_pending_size); @@ -106,14 +102,12 @@ tsi_result tsi_frame_protector_unprotect( tsi_frame_protector *self, const unsigned char *protected_frames_bytes, size_t *protected_frames_bytes_size, unsigned char *unprotected_bytes, size_t *unprotected_bytes_size) { - if (self == NULL || protected_frames_bytes == NULL || + if (self == NULL || self->vtable == NULL || protected_frames_bytes == NULL || protected_frames_bytes_size == NULL || unprotected_bytes == NULL || unprotected_bytes_size == NULL) { return TSI_INVALID_ARGUMENT; } - if (self->vtable == NULL || self->vtable->unprotect == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->unprotect == NULL) return TSI_UNIMPLEMENTED; return self->vtable->unprotect(self, protected_frames_bytes, protected_frames_bytes_size, unprotected_bytes, unprotected_bytes_size); @@ -131,48 +125,44 @@ void tsi_frame_protector_destroy(tsi_frame_protector *self) { tsi_result tsi_handshaker_get_bytes_to_send_to_peer(tsi_handshaker *self, unsigned char *bytes, size_t *bytes_size) { - if (self == NULL || bytes == NULL || bytes_size == NULL) { + if (self == NULL || self->vtable == NULL || bytes == NULL || + bytes_size == NULL) { return TSI_INVALID_ARGUMENT; } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; - if (self->vtable == NULL || self->vtable->get_bytes_to_send_to_peer == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->get_bytes_to_send_to_peer == NULL) return TSI_UNIMPLEMENTED; return self->vtable->get_bytes_to_send_to_peer(self, bytes, bytes_size); } tsi_result tsi_handshaker_process_bytes_from_peer(tsi_handshaker *self, const unsigned char *bytes, size_t *bytes_size) { - if (self == NULL || bytes == NULL || bytes_size == NULL) { + if (self == NULL || self->vtable == NULL || bytes == NULL || + bytes_size == NULL) { return 
TSI_INVALID_ARGUMENT; } if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; - if (self->vtable == NULL || self->vtable->process_bytes_from_peer == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->process_bytes_from_peer == NULL) return TSI_UNIMPLEMENTED; return self->vtable->process_bytes_from_peer(self, bytes, bytes_size); } tsi_result tsi_handshaker_get_result(tsi_handshaker *self) { - if (self == NULL) return TSI_INVALID_ARGUMENT; + if (self == NULL || self->vtable == NULL) return TSI_INVALID_ARGUMENT; if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; - if (self->vtable == NULL || self->vtable->get_result == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->get_result == NULL) return TSI_UNIMPLEMENTED; return self->vtable->get_result(self); } tsi_result tsi_handshaker_extract_peer(tsi_handshaker *self, tsi_peer *peer) { - if (self == NULL || peer == NULL) return TSI_INVALID_ARGUMENT; + if (self == NULL || self->vtable == NULL || peer == NULL) { + return TSI_INVALID_ARGUMENT; + } memset(peer, 0, sizeof(tsi_peer)); if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; if (tsi_handshaker_get_result(self) != TSI_OK) { return TSI_FAILED_PRECONDITION; } - if (self->vtable == NULL || self->vtable->extract_peer == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->extract_peer == NULL) return TSI_UNIMPLEMENTED; return self->vtable->extract_peer(self, peer); } @@ -180,14 +170,12 @@ tsi_result tsi_handshaker_create_frame_protector( tsi_handshaker *self, size_t *max_protected_frame_size, tsi_frame_protector **protector) { tsi_result result; - if (self == NULL || protector == NULL) return TSI_INVALID_ARGUMENT; - if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; - if (tsi_handshaker_get_result(self) != TSI_OK) { - return TSI_FAILED_PRECONDITION; - } - if (self->vtable == NULL || self->vtable->create_frame_protector == NULL) { - return TSI_UNIMPLEMENTED; + if (self == NULL || self->vtable == NULL || protector == NULL) { + return TSI_INVALID_ARGUMENT; } + if (self->frame_protector_created) return TSI_FAILED_PRECONDITION; + if (tsi_handshaker_get_result(self) != TSI_OK) return TSI_FAILED_PRECONDITION; + if (self->vtable->create_frame_protector == NULL) return TSI_UNIMPLEMENTED; result = self->vtable->create_frame_protector(self, max_protected_frame_size, protector); if (result == TSI_OK) { @@ -198,14 +186,12 @@ tsi_result tsi_handshaker_create_frame_protector( tsi_result tsi_handshaker_next( tsi_handshaker *self, const unsigned char *received_bytes, - size_t received_bytes_size, unsigned char **bytes_to_send, + size_t received_bytes_size, const unsigned char **bytes_to_send, size_t *bytes_to_send_size, tsi_handshaker_result **handshaker_result, tsi_handshaker_on_next_done_cb cb, void *user_data) { - if (self == NULL) return TSI_INVALID_ARGUMENT; + if (self == NULL || self->vtable == NULL) return TSI_INVALID_ARGUMENT; if (self->handshaker_result_created) return TSI_FAILED_PRECONDITION; - if (self->vtable == NULL || self->vtable->next == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->next == NULL) return TSI_UNIMPLEMENTED; return self->vtable->next(self, received_bytes, received_bytes_size, bytes_to_send, bytes_to_send_size, handshaker_result, cb, user_data); @@ -220,21 +206,21 @@ void tsi_handshaker_destroy(tsi_handshaker *self) { tsi_result tsi_handshaker_result_extract_peer(const tsi_handshaker_result *self, tsi_peer *peer) { - if (self == NULL || peer == NULL) return TSI_INVALID_ARGUMENT; - 
memset(peer, 0, sizeof(tsi_peer)); - if (self->vtable == NULL || self->vtable->extract_peer == NULL) { - return TSI_UNIMPLEMENTED; + if (self == NULL || self->vtable == NULL || peer == NULL) { + return TSI_INVALID_ARGUMENT; } + memset(peer, 0, sizeof(tsi_peer)); + if (self->vtable->extract_peer == NULL) return TSI_UNIMPLEMENTED; return self->vtable->extract_peer(self, peer); } tsi_result tsi_handshaker_result_create_frame_protector( const tsi_handshaker_result *self, size_t *max_protected_frame_size, tsi_frame_protector **protector) { - if (self == NULL || protector == NULL) return TSI_INVALID_ARGUMENT; - if (self->vtable == NULL || self->vtable->create_frame_protector == NULL) { - return TSI_UNIMPLEMENTED; + if (self == NULL || self->vtable == NULL || protector == NULL) { + return TSI_INVALID_ARGUMENT; } + if (self->vtable->create_frame_protector == NULL) return TSI_UNIMPLEMENTED; return self->vtable->create_frame_protector(self, max_protected_frame_size, protector); } @@ -242,12 +228,11 @@ tsi_result tsi_handshaker_result_create_frame_protector( tsi_result tsi_handshaker_result_get_unused_bytes( const tsi_handshaker_result *self, const unsigned char **bytes, size_t *bytes_size) { - if (self == NULL || bytes == NULL || bytes_size == NULL) { + if (self == NULL || self->vtable == NULL || bytes == NULL || + bytes_size == NULL) { return TSI_INVALID_ARGUMENT; } - if (self->vtable == NULL || self->vtable->get_unused_bytes == NULL) { - return TSI_UNIMPLEMENTED; - } + if (self->vtable->get_unused_bytes == NULL) return TSI_UNIMPLEMENTED; return self->vtable->get_unused_bytes(self, bytes, bytes_size); } diff --git a/src/core/tsi/transport_security.h b/src/core/tsi/transport_security.h index 2c7db6bca9..b0d7039850 100644 --- a/src/core/tsi/transport_security.h +++ b/src/core/tsi/transport_security.h @@ -70,7 +70,8 @@ typedef struct { tsi_frame_protector **protector); void (*destroy)(tsi_handshaker *self); tsi_result (*next)(tsi_handshaker *self, const unsigned char *received_bytes, - size_t received_bytes_size, unsigned char **bytes_to_send, + size_t received_bytes_size, + const unsigned char **bytes_to_send, size_t *bytes_to_send_size, tsi_handshaker_result **handshaker_result, tsi_handshaker_on_next_done_cb cb, void *user_data); @@ -86,6 +87,10 @@ struct tsi_handshaker { See transport_security_interface.h for documentation. */ typedef struct { tsi_result (*extract_peer)(const tsi_handshaker_result *self, tsi_peer *peer); + tsi_result (*create_zero_copy_grpc_protector)( + const tsi_handshaker_result *self, + size_t *max_output_protected_frame_size, + tsi_zero_copy_grpc_protector **protector); tsi_result (*create_frame_protector)(const tsi_handshaker_result *self, size_t *max_output_protected_frame_size, tsi_frame_protector **protector); diff --git a/src/core/tsi/transport_security_adapter.c b/src/core/tsi/transport_security_adapter.c index b6dc660c47..1c2a57b3bd 100644 --- a/src/core/tsi/transport_security_adapter.c +++ b/src/core/tsi/transport_security_adapter.c @@ -66,8 +66,11 @@ static void adapter_result_destroy(tsi_handshaker_result *self) { } static const tsi_handshaker_result_vtable result_vtable = { - adapter_result_extract_peer, adapter_result_create_frame_protector, - adapter_result_get_unused_bytes, adapter_result_destroy, + adapter_result_extract_peer, + NULL, /* create_zero_copy_grpc_protector */ + adapter_result_create_frame_protector, + adapter_result_get_unused_bytes, + adapter_result_destroy, }; /* Ownership of wrapped tsi_handshaker is transferred to the result object. 
*/ @@ -140,7 +143,7 @@ static void adapter_destroy(tsi_handshaker *self) { static tsi_result adapter_next( tsi_handshaker *self, const unsigned char *received_bytes, - size_t received_bytes_size, unsigned char **bytes_to_send, + size_t received_bytes_size, const unsigned char **bytes_to_send, size_t *bytes_to_send_size, tsi_handshaker_result **handshaker_result, tsi_handshaker_on_next_done_cb cb, void *user_data) { /* Input sanity check. */ diff --git a/src/core/tsi/transport_security_grpc.c b/src/core/tsi/transport_security_grpc.c new file mode 100644 index 0000000000..5bcfdfa61f --- /dev/null +++ b/src/core/tsi/transport_security_grpc.c @@ -0,0 +1,64 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "src/core/tsi/transport_security_grpc.h" + +/* This method creates a tsi_zero_copy_grpc_protector object. */ +tsi_result tsi_handshaker_result_create_zero_copy_grpc_protector( + const tsi_handshaker_result *self, size_t *max_output_protected_frame_size, + tsi_zero_copy_grpc_protector **protector) { + if (self == NULL || self->vtable == NULL || protector == NULL) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->create_zero_copy_grpc_protector == NULL) { + return TSI_UNIMPLEMENTED; + } + return self->vtable->create_zero_copy_grpc_protector( + self, max_output_protected_frame_size, protector); +} + +/* --- tsi_zero_copy_grpc_protector common implementation. --- + + Calls specific implementation after state/input validation. */ + +tsi_result tsi_zero_copy_grpc_protector_protect( + tsi_zero_copy_grpc_protector *self, grpc_slice_buffer *unprotected_slices, + grpc_slice_buffer *protected_slices) { + if (self == NULL || self->vtable == NULL || unprotected_slices == NULL || + protected_slices == NULL) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->protect == NULL) return TSI_UNIMPLEMENTED; + return self->vtable->protect(self, unprotected_slices, protected_slices); +} + +tsi_result tsi_zero_copy_grpc_protector_unprotect( + tsi_zero_copy_grpc_protector *self, grpc_slice_buffer *protected_slices, + grpc_slice_buffer *unprotected_slices) { + if (self == NULL || self->vtable == NULL || protected_slices == NULL || + unprotected_slices == NULL) { + return TSI_INVALID_ARGUMENT; + } + if (self->vtable->unprotect == NULL) return TSI_UNIMPLEMENTED; + return self->vtable->unprotect(self, protected_slices, unprotected_slices); +} + +void tsi_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector *self) { + if (self == NULL) return; + self->vtable->destroy(self); +} diff --git a/src/core/tsi/transport_security_grpc.h b/src/core/tsi/transport_security_grpc.h new file mode 100644 index 0000000000..5ab5297cc4 --- /dev/null +++ b/src/core/tsi/transport_security_grpc.h @@ -0,0 +1,80 @@ +/* + * + * Copyright 2017 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#ifndef GRPC_CORE_TSI_TRANSPORT_SECURITY_GRPC_H +#define GRPC_CORE_TSI_TRANSPORT_SECURITY_GRPC_H + +#include <grpc/slice_buffer.h> +#include "src/core/tsi/transport_security.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/* This method creates a tsi_zero_copy_grpc_protector object. It returns TSI_OK + assuming there is no fatal error. + The caller is responsible for destroying the protector. */ +tsi_result tsi_handshaker_result_create_zero_copy_grpc_protector( + const tsi_handshaker_result *self, size_t *max_output_protected_frame_size, + tsi_zero_copy_grpc_protector **protector); + +/* -- tsi_zero_copy_grpc_protector object -- */ + +/* Outputs protected frames. + - unprotected_slices is the unprotected data to be protected. + - protected_slices is the protected output frames. One or more frames + may be produced in this protect function. + - This method returns TSI_OK in case of success or a specific error code in + case of failure. */ +tsi_result tsi_zero_copy_grpc_protector_protect( + tsi_zero_copy_grpc_protector *self, grpc_slice_buffer *unprotected_slices, + grpc_slice_buffer *protected_slices); + +/* Outputs unprotected bytes. + - protected_slices is the bytes of protected frames. + - unprotected_slices is the unprotected output data. + - This method returns TSI_OK in case of success. Success includes cases where + there is not enough data to output in which case unprotected_slices has 0 + bytes. */ +tsi_result tsi_zero_copy_grpc_protector_unprotect( + tsi_zero_copy_grpc_protector *self, grpc_slice_buffer *protected_slices, + grpc_slice_buffer *unprotected_slices); + +/* Destroys the tsi_zero_copy_grpc_protector object. */ +void tsi_zero_copy_grpc_protector_destroy(tsi_zero_copy_grpc_protector *self); + +/* Base for tsi_zero_copy_grpc_protector implementations. */ +typedef struct { + tsi_result (*protect)(tsi_zero_copy_grpc_protector *self, + grpc_slice_buffer *unprotected_slices, + grpc_slice_buffer *protected_slices); + tsi_result (*unprotect)(tsi_zero_copy_grpc_protector *self, + grpc_slice_buffer *protected_slices, + grpc_slice_buffer *unprotected_slices); + void (*destroy)(tsi_zero_copy_grpc_protector *self); +} tsi_zero_copy_grpc_protector_vtable; + +struct tsi_zero_copy_grpc_protector { + const tsi_zero_copy_grpc_protector_vtable *vtable; +}; + +#ifdef __cplusplus +} +#endif + +#endif /* GRPC_CORE_TSI_TRANSPORT_SECURITY_GRPC_H */ diff --git a/src/core/tsi/transport_security_interface.h b/src/core/tsi/transport_security_interface.h index 39ba8addc4..80c426bbdb 100644 --- a/src/core/tsi/transport_security_interface.h +++ b/src/core/tsi/transport_security_interface.h @@ -62,6 +62,15 @@ const char *tsi_result_to_string(tsi_result result); extern grpc_tracer_flag tsi_tracing_enabled; +/* -- tsi_zero_copy_grpc_protector object -- + + This object protects and unprotects grpc slice buffers with zero or minimized + memory copy once the handshake is done. Implementations of this object must be + thread compatible. This object depends on grpc and the details of this object + are defined in transport_security_grpc.h.
*/ + +typedef struct tsi_zero_copy_grpc_protector tsi_zero_copy_grpc_protector; + /* --- tsi_frame_protector object --- This object protects and unprotects buffers once the handshake is done. @@ -429,7 +438,7 @@ typedef void (*tsi_handshaker_on_next_done_cb)( tsi_handshaker object. */ tsi_result tsi_handshaker_next( tsi_handshaker *self, const unsigned char *received_bytes, - size_t received_bytes_size, unsigned char **bytes_to_send, + size_t received_bytes_size, const unsigned char **bytes_to_send, size_t *bytes_to_send_size, tsi_handshaker_result **handshaker_result, tsi_handshaker_on_next_done_cb cb, void *user_data); diff --git a/src/node/ext/call.cc b/src/node/ext/call.cc index 71e6904008..26095a78f9 100644 --- a/src/node/ext/call.cc +++ b/src/node/ext/call.cc @@ -260,7 +260,10 @@ class SendClientCloseOp : public Op { class SendServerStatusOp : public Op { public: - SendServerStatusOp() { grpc_metadata_array_init(&status_metadata); } + SendServerStatusOp() { + details = grpc_empty_slice(); + grpc_metadata_array_init(&status_metadata); + } ~SendServerStatusOp() { grpc_slice_unref(details); DestroyMetadataArray(&status_metadata); @@ -381,7 +384,10 @@ class ReadMessageOp : public Op { class ClientStatusOp : public Op { public: - ClientStatusOp() { grpc_metadata_array_init(&metadata_array); } + ClientStatusOp() { + grpc_metadata_array_init(&metadata_array); + status_details = grpc_empty_slice(); + } ~ClientStatusOp() { grpc_metadata_array_destroy(&metadata_array); diff --git a/src/node/test/call_test.js b/src/node/test/call_test.js index aebd298e32..b5246c4f31 100644 --- a/src/node/test/call_test.js +++ b/src/node/test/call_test.js @@ -188,6 +188,103 @@ describe('call', function() { }, TypeError); }); }); + describe('startBatch with message', function() { + it('should fail with null argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_MESSAGE] = null; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail with numeric argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_MESSAGE] = 5; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail with string argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_MESSAGE] = 'value'; + call.startBatch(batch, function(){}); + }, TypeError); + }); + }); + describe('startBatch with status', function() { + it('should fail without a code', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + details: 'details string', + metadata: {} + }; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail without details', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + code: 0, + metadata: {} + }; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail without metadata', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + code: 0, + details: 'details string' + }; + call.startBatch(batch, 
function(){}); + }, TypeError); + }); + it('should fail with incorrectly typed code argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + code: 'code string', + details: 'details string', + metadata: {} + }; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail with incorrectly typed details argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + code: 0, + details: 5, + metadata: {} + }; + call.startBatch(batch, function(){}); + }, TypeError); + }); + it('should fail with incorrectly typed metadata argument', function() { + var call = new grpc.Call(channel, 'method', getDeadline(1)); + assert.throws(function() { + var batch = {}; + batch[grpc.opType.SEND_STATUS_FROM_SERVER] = { + code: 0, + details: 'details string', + metadata: 'abc' + }; + call.startBatch(batch, function(){}); + }, TypeError); + }); + }); describe('cancel', function() { it('should succeed', function() { var call = new grpc.Call(channel, 'method', getDeadline(1)); diff --git a/src/objective-c/README.md b/src/objective-c/README.md index 3624475b9c..e76ee173ea 100644 --- a/src/objective-c/README.md +++ b/src/objective-c/README.md @@ -112,7 +112,7 @@ the sample Podspec above. For example, you could use: ```ruby s.prepare_command = <<-CMD ... - #{src}/*.proto #{src}/**/*.proto + `find . -name *.proto -print | xargs` CMD ... ms.source_files = "#{dir}/*.pbobjc.{h,m}", "#{dir}/**/*.pbobjc.{h,m}" diff --git a/src/python/grpcio/grpc_core_dependencies.py b/src/python/grpcio/grpc_core_dependencies.py index e52d43e81d..dc4d28f95b 100644 --- a/src/python/grpcio/grpc_core_dependencies.py +++ b/src/python/grpcio/grpc_core_dependencies.py @@ -246,6 +246,7 @@ CORE_SOURCE_FILES = [ 'src/core/tsi/fake_transport_security.c', 'src/core/tsi/gts_transport_security.c', 'src/core/tsi/ssl_transport_security.c', + 'src/core/tsi/transport_security_grpc.c', 'src/core/tsi/transport_security.c', 'src/core/tsi/transport_security_adapter.c', 'src/core/ext/transport/chttp2/server/chttp2_server.c', diff --git a/src/ruby/ext/grpc/rb_call.c b/src/ruby/ext/grpc/rb_call.c index b99954883f..74f189e1e0 100644 --- a/src/ruby/ext/grpc/rb_call.c +++ b/src/ruby/ext/grpc/rb_call.c @@ -179,6 +179,38 @@ static VALUE grpc_rb_call_cancel(VALUE self) { return Qnil; } +/* TODO: expose this as part of the surface API if needed. + * This is meant for internal usage by the "write thread" of grpc-ruby + * client-side bidi calls. It provides a way for the background write-thread + * to propogate failures to the main read-thread and give the user an error + * message. */ +static VALUE grpc_rb_call_cancel_with_status(VALUE self, VALUE status_code, + VALUE details) { + grpc_rb_call *call = NULL; + grpc_call_error err; + if (RTYPEDDATA_DATA(self) == NULL) { + // This call has been closed + return Qnil; + } + + if (TYPE(details) != T_STRING || TYPE(status_code) != T_FIXNUM) { + rb_raise(rb_eTypeError, + "Bad parameter type error for cancel with status. 
Want Fixnum, " + "String."); + return Qnil; + } + + TypedData_Get_Struct(self, grpc_rb_call, &grpc_call_data_type, call); + err = grpc_call_cancel_with_status(call->wrapped, NUM2LONG(status_code), + StringValueCStr(details), NULL); + if (err != GRPC_CALL_OK) { + rb_raise(grpc_rb_eCallError, "cancel with status failed: %s (code=%d)", + grpc_call_error_detail_of(err), err); + } + + return Qnil; +} + /* Releases the c-level resources associated with a call Once a call has been closed, no further requests can be processed. @@ -949,6 +981,8 @@ void Init_grpc_call() { /* Add ruby analogues of the Call methods. */ rb_define_method(grpc_rb_cCall, "run_batch", grpc_rb_call_run_batch, 1); rb_define_method(grpc_rb_cCall, "cancel", grpc_rb_call_cancel, 0); + rb_define_method(grpc_rb_cCall, "cancel_with_status", + grpc_rb_call_cancel_with_status, 2); rb_define_method(grpc_rb_cCall, "close", grpc_rb_call_close, 0); rb_define_method(grpc_rb_cCall, "peer", grpc_rb_call_get_peer, 0); rb_define_method(grpc_rb_cCall, "peer_cert", grpc_rb_call_get_peer_cert, 0); diff --git a/src/ruby/lib/grpc/generic/bidi_call.rb b/src/ruby/lib/grpc/generic/bidi_call.rb index 9e125cd986..c2239d0178 100644 --- a/src/ruby/lib/grpc/generic/bidi_call.rb +++ b/src/ruby/lib/grpc/generic/bidi_call.rb @@ -153,7 +153,12 @@ module GRPC rescue StandardError => e GRPC.logger.warn('bidi-write-loop: failed') GRPC.logger.warn(e) - raise e + if is_client + @call.cancel_with_status(GRPC::Core::StatusCodes::UNKNOWN, + "GRPC bidi call error: #{e.inspect}") + else + raise e + end ensure set_output_stream_done.call if is_client end @@ -180,8 +185,8 @@ module GRPC batch_result = @call.run_batch(RECV_STATUS_ON_CLIENT => nil) @call.status = batch_result.status @call.trailing_metadata = @call.status.metadata if @call.status - batch_result.check_status GRPC.logger.debug("bidi-read-loop: done status #{@call.status}") + batch_result.check_status end GRPC.logger.debug('bidi-read-loop: done reading!') diff --git a/src/ruby/spec/call_spec.rb b/src/ruby/spec/call_spec.rb index 473ff4a8bd..1cc0500242 100644 --- a/src/ruby/spec/call_spec.rb +++ b/src/ruby/spec/call_spec.rb @@ -137,6 +137,39 @@ describe GRPC::Core::Call do end end + describe '#cancel' do + it 'completes ok' do + call = make_test_call + expect { call.cancel }.not_to raise_error + end + + it 'completes ok when the call is closed' do + call = make_test_call + call.close + expect { call.cancel }.not_to raise_error + end + end + + describe '#cancel_with_status' do + it 'completes ok' do + call = make_test_call + expect do + call.cancel_with_status(0, 'test status') + end.not_to raise_error + expect do + call.cancel_with_status(0, nil) + end.to raise_error(TypeError) + end + + it 'completes ok when the call is closed' do + call = make_test_call + call.close + expect do + call.cancel_with_status(0, 'test status') + end.not_to raise_error + end + end + def make_test_call @ch.create_call(nil, nil, 'dummy_method', nil, deadline) end diff --git a/src/ruby/spec/client_server_spec.rb b/src/ruby/spec/client_server_spec.rb index b48b4179ce..1a9b47e2c3 100644 --- a/src/ruby/spec/client_server_spec.rb +++ b/src/ruby/spec/client_server_spec.rb @@ -226,6 +226,62 @@ shared_examples 'basic GRPC message delivery is OK' do svr_batch = server_call.run_batch(server_ops) expect(svr_batch.send_close).to be true end + + def client_cancel_test(cancel_proc, expected_code, + expected_details) + call = new_client_call + server_call = nil + + server_thread = Thread.new do + server_call = server_allows_client_to_proceed + end 
+ + client_ops = { + CallOps::SEND_INITIAL_METADATA => {}, + CallOps::RECV_INITIAL_METADATA => nil + } + batch_result = call.run_batch(client_ops) + expect(batch_result.send_metadata).to be true + expect(batch_result.metadata).to eq({}) + + cancel_proc.call(call) + + server_thread.join + server_ops = { + CallOps::RECV_CLOSE_ON_SERVER => nil + } + svr_batch = server_call.run_batch(server_ops) + expect(svr_batch.send_close).to be true + + client_ops = { + CallOps::RECV_STATUS_ON_CLIENT => {} + } + batch_result = call.run_batch(client_ops) + + expect(batch_result.status.code).to be expected_code + expect(batch_result.status.details).to eq expected_details + end + + it 'clients can cancel a call on the server' do + expected_code = StatusCodes::CANCELLED + expected_details = 'Cancelled' + cancel_proc = proc { |call| call.cancel } + client_cancel_test(cancel_proc, expected_code, expected_details) + end + + it 'cancel_with_status unknown status' do + code = StatusCodes::UNKNOWN + details = 'test unknown reason' + cancel_proc = proc { |call| call.cancel_with_status(code, details) } + client_cancel_test(cancel_proc, code, details) + end + + it 'cancel_with_status failed precondition status' do + code = StatusCodes::FAILED_PRECONDITION + details = 'test failed precondition reason' + cancel_proc = proc { |call| call.cancel_with_status(code, details) } + client_cancel_test(cancel_proc, code, details) + end end shared_examples 'GRPC metadata delivery works OK' do diff --git a/src/ruby/spec/generic/client_stub_spec.rb b/src/ruby/spec/generic/client_stub_spec.rb index e1e7a535fb..9539e56c0f 100644 --- a/src/ruby/spec/generic/client_stub_spec.rb +++ b/src/ruby/spec/generic/client_stub_spec.rb @@ -472,7 +472,7 @@ describe 'ClientStub' do host = "localhost:#{server_port}" stub = GRPC::ClientStub.new(host, :this_channel_is_insecure) expect do - get_responses(stub) + get_responses(stub).collect { |r| r } end.to raise_error(ArgumentError, /Header values must be of type string or array/) end @@ -641,11 +641,101 @@ describe 'ClientStub' do expect(e.collect { |r| r }).to eq(@sent_msgs) th.join end + + # Prompted by grpc/github #10526 + describe 'surfacing of errors when sending requests' do + def run_server_bidi_send_one_then_read_indefinitely + @server.start + recvd_rpc = @server.request_call + recvd_call = recvd_rpc.call + server_call = GRPC::ActiveCall.new( + recvd_call, noop, noop, INFINITE_FUTURE, + metadata_received: true, started: false) + server_call.send_initial_metadata + server_call.remote_send('server response') + loop do + m = server_call.remote_read + break if m.nil?
+ end + # can't fail since initial metadata already sent + server_call.send_status(@pass, 'OK', true) + end + + def verify_error_from_write_thread(stub, requests_to_push, + request_queue, expected_description) + # TODO: an improvement might be to raise the original exception from + # bidi call write loops instead of only cancelling the call + failing_marshal_proc = proc do |req| + fail req if req.is_a?(StandardError) + req + end + begin + e = get_responses(stub, marshal_proc: failing_marshal_proc) + first_response = e.next + expect(first_response).to eq('server response') + requests_to_push.each { |req| request_queue.push(req) } + e.collect { |r| r } + rescue GRPC::Unknown => e + exception = e + end + expect(exception.message.include?(expected_description)).to be(true) + end + + # Provides an Enumerable view of a Queue + class BidiErrorTestingEnumerateForeverQueue + def initialize(queue) + @queue = queue + end + + def each + loop do + msg = @queue.pop + yield msg + end + end + end + + def run_error_in_client_request_stream_test(requests_to_push, + expected_error_message) + # start a server that waits on a read indefinitely - it should + # see a cancellation and be able to break out + th = Thread.new { run_server_bidi_send_one_then_read_indefinitely } + stub = GRPC::ClientStub.new(@host, :this_channel_is_insecure) + + request_queue = Queue.new + @sent_msgs = BidiErrorTestingEnumerateForeverQueue.new(request_queue) + + verify_error_from_write_thread(stub, + requests_to_push, + request_queue, + expected_error_message) + # the write loop error should cancel the call and end the + # server's request stream + th.join + end + + it 'non-GRPC errors from the write loop surface when raised ' \ + 'at the start of a request stream' do + expected_error_message = 'expect error on first request' + requests_to_push = [StandardError.new(expected_error_message)] + run_error_in_client_request_stream_test(requests_to_push, + expected_error_message) + end + + it 'non-GRPC errors from the write loop surface when raised ' \ + 'during the middle of a request stream' do + expected_error_message = 'expect error on last request' + requests_to_push = %w( one two ) + requests_to_push << StandardError.new(expected_error_message) + run_error_in_client_request_stream_test(requests_to_push, + expected_error_message) + end + end end describe 'without a call operation' do - def get_responses(stub, deadline: nil) - e = stub.bidi_streamer(@method, @sent_msgs, noop, noop, + def get_responses(stub, deadline: nil, marshal_proc: noop) + e = stub.bidi_streamer(@method, @sent_msgs, marshal_proc, noop, metadata: @metadata, deadline: deadline) expect(e).to be_a(Enumerator) e @@ -658,8 +748,9 @@ describe 'ClientStub' do after(:each) do @op.wait # make sure wait doesn't hang end - def get_responses(stub, run_start_call_first: false, deadline: nil) - @op = stub.bidi_streamer(@method, @sent_msgs, noop, noop, + def get_responses(stub, run_start_call_first: false, deadline: nil, + marshal_proc: noop) + @op = stub.bidi_streamer(@method, @sent_msgs, marshal_proc, noop, return_op: true, metadata: @metadata, deadline: deadline) expect(@op).to be_a(GRPC::ActiveCall::Operation) diff --git a/src/ruby/spec/generic/rpc_server_spec.rb b/src/ruby/spec/generic/rpc_server_spec.rb index e4fe594e22..b887eaaf4e 100644 --- a/src/ruby/spec/generic/rpc_server_spec.rb +++ b/src/ruby/spec/generic/rpc_server_spec.rb @@ -178,6 +178,18 @@ end CheckCallAfterFinishedServiceStub = CheckCallAfterFinishedService.rpc_stub_class +# A service with a bidi streaming
method. +class BidiService + include GRPC::GenericService + rpc :server_sends_bad_input, stream(EchoMsg), stream(EchoMsg) + + def server_sends_bad_input(_, _) + 'bad response. (not an enumerable, client sees an error)' + end +end + +BidiStub = BidiService.rpc_stub_class + describe GRPC::RpcServer do RpcServer = GRPC::RpcServer StatusCodes = GRPC::Core::StatusCodes @@ -520,6 +532,29 @@ describe GRPC::RpcServer do t.join expect(one_failed_as_unavailable).to be(true) end + + it 'should send a status UNKNOWN with a relevant message when the ' \ + 'server response stream is not an enumerable' do + @srv.handle(BidiService) + t = Thread.new { @srv.run } + @srv.wait_till_running + stub = BidiStub.new(@host, :this_channel_is_insecure, **client_opts) + responses = stub.server_sends_bad_input([]) + exception = nil + begin + responses.each { |r| r } + rescue GRPC::Unknown => e + exception = e + end + # Erroneous responses sent from the server handler should cause an + # exception on the client with relevant info. + expected_details = 'NoMethodError: undefined method `each\' for '\ + '"bad response. (not an enumerable, client sees an error)"' + + expect(exception.inspect.include?(expected_details)).to be true + @srv.stop + t.join + end end context 'with connect metadata' do |
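Closing note on the new zero-copy TSI surface (transport_security_grpc.h above): the sketch below shows how a caller might exercise it. The function name protect_one_message and the handshaker result `result` are hypothetical; `result` is assumed to come from an implementation that fills in create_zero_copy_grpc_protector (the fake and adapter vtables in this diff deliberately leave that slot NULL and would return TSI_UNIMPLEMENTED), and passing NULL for max_output_protected_frame_size is assumed to request the implementation default, mirroring tsi_handshaker_result_create_frame_protector.

    #include <grpc/slice.h>
    #include <grpc/slice_buffer.h>
    #include "src/core/tsi/transport_security_grpc.h"

    /* Sketch: protect one message with the zero-copy protector. */
    static tsi_result protect_one_message(const tsi_handshaker_result *result) {
      tsi_zero_copy_grpc_protector *protector = NULL;
      tsi_result rv = tsi_handshaker_result_create_zero_copy_grpc_protector(
          result, NULL /* assumed: implementation-default frame size */,
          &protector);
      if (rv != TSI_OK) return rv; /* e.g. TSI_UNIMPLEMENTED for fake/adapter */

      grpc_slice_buffer plain, wire;
      grpc_slice_buffer_init(&plain);
      grpc_slice_buffer_init(&wire);
      grpc_slice_buffer_add(&plain, grpc_slice_from_copied_string("hello"));

      /* On success `wire` holds one or more protected frames; the API works
       * on slice buffers so implementations can avoid copying payload data. */
      rv = tsi_zero_copy_grpc_protector_protect(protector, &plain, &wire);

      grpc_slice_buffer_destroy(&plain);
      grpc_slice_buffer_destroy(&wire);
      tsi_zero_copy_grpc_protector_destroy(protector);
      return rv;
    }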