diff options
author | 2015-08-31 11:36:34 -0700 | |
---|---|---|
committer | 2015-08-31 11:36:34 -0700 | |
commit | ca1bbd3c8c548d81cc6e5a212d346e6647993973 (patch) | |
tree | 7a934159ae9d3729ea9252d0468df18f02530634 /src | |
parent | fb72e1e487edfd7127e0d415e90db55145e37a04 (diff) | |
parent | 50aa6f682e405c7c71a4e7e30b33f03b9d44a972 (diff) |
Merge branch 'master' of github.com:grpc/grpc into cpp_docs
Diffstat (limited to 'src')
86 files changed, 3124 insertions, 598 deletions
diff --git a/src/compiler/python_generator.cc b/src/compiler/python_generator.cc index 72c457ac6b..fe2b9fad99 100644 --- a/src/compiler/python_generator.cc +++ b/src/compiler/python_generator.cc @@ -148,8 +148,8 @@ class IndentScope { // END FORMATTING BOILERPLATE // //////////////////////////////// -bool PrintServicer(const ServiceDescriptor* service, - Printer* out) { +bool PrintAlphaServicer(const ServiceDescriptor* service, + Printer* out) { grpc::string doc = "<fill me in later!>"; map<grpc::string, grpc::string> dict = ListToDict({ "Service", service->name(), @@ -176,7 +176,7 @@ bool PrintServicer(const ServiceDescriptor* service, return true; } -bool PrintServer(const ServiceDescriptor* service, Printer* out) { +bool PrintAlphaServer(const ServiceDescriptor* service, Printer* out) { grpc::string doc = "<fill me in later!>"; map<grpc::string, grpc::string> dict = ListToDict({ "Service", service->name(), @@ -204,8 +204,8 @@ bool PrintServer(const ServiceDescriptor* service, Printer* out) { return true; } -bool PrintStub(const ServiceDescriptor* service, - Printer* out) { +bool PrintAlphaStub(const ServiceDescriptor* service, + Printer* out) { grpc::string doc = "<fill me in later!>"; map<grpc::string, grpc::string> dict = ListToDict({ "Service", service->name(), @@ -268,8 +268,8 @@ bool GetModuleAndMessagePath(const Descriptor* type, return true; } -bool PrintServerFactory(const grpc::string& package_qualified_service_name, - const ServiceDescriptor* service, Printer* out) { +bool PrintAlphaServerFactory(const grpc::string& package_qualified_service_name, + const ServiceDescriptor* service, Printer* out) { out->Print("def early_adopter_create_$Service$_server(servicer, port, " "private_key=None, certificate_chain=None):\n", "Service", service->name()); @@ -320,7 +320,7 @@ bool PrintServerFactory(const grpc::string& package_qualified_service_name, input_message_modules_and_classes.find(method_name); auto output_message_module_and_class = output_message_modules_and_classes.find(method_name); - out->Print("\"$Method$\": utilities.$Constructor$(\n", "Method", + out->Print("\"$Method$\": alpha_utilities.$Constructor$(\n", "Method", method_name, "Constructor", name_and_description_constructor->second); { @@ -348,8 +348,8 @@ bool PrintServerFactory(const grpc::string& package_qualified_service_name, return true; } -bool PrintStubFactory(const grpc::string& package_qualified_service_name, - const ServiceDescriptor* service, Printer* out) { +bool PrintAlphaStubFactory(const grpc::string& package_qualified_service_name, + const ServiceDescriptor* service, Printer* out) { map<grpc::string, grpc::string> dict = ListToDict({ "Service", service->name(), }); @@ -404,7 +404,7 @@ bool PrintStubFactory(const grpc::string& package_qualified_service_name, input_message_modules_and_classes.find(method_name); auto output_message_module_and_class = output_message_modules_and_classes.find(method_name); - out->Print("\"$Method$\": utilities.$Constructor$(\n", "Method", + out->Print("\"$Method$\": alpha_utilities.$Constructor$(\n", "Method", method_name, "Constructor", name_and_description_constructor->second); { @@ -434,12 +434,280 @@ bool PrintStubFactory(const grpc::string& package_qualified_service_name, return true; } +bool PrintBetaServicer(const ServiceDescriptor* service, + Printer* out) { + grpc::string doc = "<fill me in later!>"; + map<grpc::string, grpc::string> dict = ListToDict({ + "Service", service->name(), + "Documentation", doc, + }); + out->Print("\n"); + out->Print(dict, "class Beta$Service$Servicer(object):\n"); + { + IndentScope raii_class_indent(out); + out->Print(dict, "\"\"\"$Documentation$\"\"\"\n"); + out->Print("__metaclass__ = abc.ABCMeta\n"); + for (int i = 0; i < service->method_count(); ++i) { + auto meth = service->method(i); + grpc::string arg_name = meth->client_streaming() ? + "request_iterator" : "request"; + out->Print("@abc.abstractmethod\n"); + out->Print("def $Method$(self, $ArgName$, context):\n", + "Method", meth->name(), "ArgName", arg_name); + { + IndentScope raii_method_indent(out); + out->Print("raise NotImplementedError()\n"); + } + } + } + return true; +} + +bool PrintBetaStub(const ServiceDescriptor* service, + Printer* out) { + grpc::string doc = "The interface to which stubs will conform."; + map<grpc::string, grpc::string> dict = ListToDict({ + "Service", service->name(), + "Documentation", doc, + }); + out->Print("\n"); + out->Print(dict, "class Beta$Service$Stub(object):\n"); + { + IndentScope raii_class_indent(out); + out->Print(dict, "\"\"\"$Documentation$\"\"\"\n"); + out->Print("__metaclass__ = abc.ABCMeta\n"); + for (int i = 0; i < service->method_count(); ++i) { + const MethodDescriptor* meth = service->method(i); + grpc::string arg_name = meth->client_streaming() ? + "request_iterator" : "request"; + auto methdict = ListToDict({"Method", meth->name(), "ArgName", arg_name}); + out->Print("@abc.abstractmethod\n"); + out->Print(methdict, "def $Method$(self, $ArgName$, timeout):\n"); + { + IndentScope raii_method_indent(out); + out->Print("raise NotImplementedError()\n"); + } + if (!meth->server_streaming()) { + out->Print(methdict, "$Method$.future = None\n"); + } + } + } + return true; +} + +bool PrintBetaServerFactory(const grpc::string& package_qualified_service_name, + const ServiceDescriptor* service, Printer* out) { + out->Print("\n"); + out->Print("def beta_create_$Service$_server(servicer, pool=None, " + "pool_size=None, default_timeout=None, maximum_timeout=None):\n", + "Service", service->name()); + { + IndentScope raii_create_server_indent(out); + map<grpc::string, grpc::string> method_implementation_constructors; + map<grpc::string, pair<grpc::string, grpc::string>> + input_message_modules_and_classes; + map<grpc::string, pair<grpc::string, grpc::string>> + output_message_modules_and_classes; + for (int i = 0; i < service->method_count(); ++i) { + const MethodDescriptor* method = service->method(i); + const grpc::string method_implementation_constructor = + grpc::string(method->client_streaming() ? "stream_" : "unary_") + + grpc::string(method->server_streaming() ? "stream_" : "unary_") + + "inline"; + pair<grpc::string, grpc::string> input_message_module_and_class; + if (!GetModuleAndMessagePath(method->input_type(), + &input_message_module_and_class)) { + return false; + } + pair<grpc::string, grpc::string> output_message_module_and_class; + if (!GetModuleAndMessagePath(method->output_type(), + &output_message_module_and_class)) { + return false; + } + // Import the modules that define the messages used in RPCs. + out->Print("import $Module$\n", "Module", + input_message_module_and_class.first); + out->Print("import $Module$\n", "Module", + output_message_module_and_class.first); + method_implementation_constructors.insert( + make_pair(method->name(), method_implementation_constructor)); + input_message_modules_and_classes.insert( + make_pair(method->name(), input_message_module_and_class)); + output_message_modules_and_classes.insert( + make_pair(method->name(), output_message_module_and_class)); + } + out->Print("request_deserializers = {\n"); + for (auto name_and_input_module_class_pair = + input_message_modules_and_classes.begin(); + name_and_input_module_class_pair != + input_message_modules_and_classes.end(); + name_and_input_module_class_pair++) { + IndentScope raii_indent(out); + out->Print("(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$InputTypeModule$.$InputTypeClass$.FromString,\n", + "PackageQualifiedServiceName", package_qualified_service_name, + "MethodName", name_and_input_module_class_pair->first, + "InputTypeModule", + name_and_input_module_class_pair->second.first, + "InputTypeClass", + name_and_input_module_class_pair->second.second); + } + out->Print("}\n"); + out->Print("response_serializers = {\n"); + for (auto name_and_output_module_class_pair = + output_message_modules_and_classes.begin(); + name_and_output_module_class_pair != + output_message_modules_and_classes.end(); + name_and_output_module_class_pair++) { + IndentScope raii_indent(out); + out->Print("(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$OutputTypeModule$.$OutputTypeClass$.SerializeToString,\n", + "PackageQualifiedServiceName", package_qualified_service_name, + "MethodName", name_and_output_module_class_pair->first, + "OutputTypeModule", + name_and_output_module_class_pair->second.first, + "OutputTypeClass", + name_and_output_module_class_pair->second.second); + } + out->Print("}\n"); + out->Print("method_implementations = {\n"); + for (auto name_and_implementation_constructor = + method_implementation_constructors.begin(); + name_and_implementation_constructor != + method_implementation_constructors.end(); + name_and_implementation_constructor++) { + IndentScope raii_descriptions_indent(out); + const grpc::string method_name = + name_and_implementation_constructor->first; + out->Print("(\'$PackageQualifiedServiceName$\', \'$Method$\'): " + "face_utilities.$Constructor$(servicer.$Method$),\n", + "PackageQualifiedServiceName", package_qualified_service_name, + "Method", name_and_implementation_constructor->first, + "Constructor", name_and_implementation_constructor->second); + } + out->Print("}\n"); + out->Print("server_options = beta.server_options(" + "request_deserializers=request_deserializers, " + "response_serializers=response_serializers, " + "thread_pool=pool, thread_pool_size=pool_size, " + "default_timeout=default_timeout, " + "maximum_timeout=maximum_timeout)\n"); + out->Print("return beta.server(method_implementations, " + "options=server_options)\n"); + } + return true; +} + +bool PrintBetaStubFactory(const grpc::string& package_qualified_service_name, + const ServiceDescriptor* service, Printer* out) { + map<grpc::string, grpc::string> dict = ListToDict({ + "Service", service->name(), + }); + out->Print("\n"); + out->Print(dict, "def beta_create_$Service$_stub(channel, host=None," + " metadata_transformer=None, pool=None, pool_size=None):\n"); + { + IndentScope raii_create_server_indent(out); + map<grpc::string, grpc::string> method_cardinalities; + map<grpc::string, pair<grpc::string, grpc::string>> + input_message_modules_and_classes; + map<grpc::string, pair<grpc::string, grpc::string>> + output_message_modules_and_classes; + for (int i = 0; i < service->method_count(); ++i) { + const MethodDescriptor* method = service->method(i); + const grpc::string method_cardinality = + grpc::string(method->client_streaming() ? "STREAM" : "UNARY") + + "_" + + grpc::string(method->server_streaming() ? "STREAM" : "UNARY"); + pair<grpc::string, grpc::string> input_message_module_and_class; + if (!GetModuleAndMessagePath(method->input_type(), + &input_message_module_and_class)) { + return false; + } + pair<grpc::string, grpc::string> output_message_module_and_class; + if (!GetModuleAndMessagePath(method->output_type(), + &output_message_module_and_class)) { + return false; + } + // Import the modules that define the messages used in RPCs. + out->Print("import $Module$\n", "Module", + input_message_module_and_class.first); + out->Print("import $Module$\n", "Module", + output_message_module_and_class.first); + method_cardinalities.insert( + make_pair(method->name(), method_cardinality)); + input_message_modules_and_classes.insert( + make_pair(method->name(), input_message_module_and_class)); + output_message_modules_and_classes.insert( + make_pair(method->name(), output_message_module_and_class)); + } + out->Print("request_serializers = {\n"); + for (auto name_and_input_module_class_pair = + input_message_modules_and_classes.begin(); + name_and_input_module_class_pair != + input_message_modules_and_classes.end(); + name_and_input_module_class_pair++) { + IndentScope raii_indent(out); + out->Print("(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$InputTypeModule$.$InputTypeClass$.SerializeToString,\n", + "PackageQualifiedServiceName", package_qualified_service_name, + "MethodName", name_and_input_module_class_pair->first, + "InputTypeModule", + name_and_input_module_class_pair->second.first, + "InputTypeClass", + name_and_input_module_class_pair->second.second); + } + out->Print("}\n"); + out->Print("response_deserializers = {\n"); + for (auto name_and_output_module_class_pair = + output_message_modules_and_classes.begin(); + name_and_output_module_class_pair != + output_message_modules_and_classes.end(); + name_and_output_module_class_pair++) { + IndentScope raii_indent(out); + out->Print("(\'$PackageQualifiedServiceName$\', \'$MethodName$\'): " + "$OutputTypeModule$.$OutputTypeClass$.FromString,\n", + "PackageQualifiedServiceName", package_qualified_service_name, + "MethodName", name_and_output_module_class_pair->first, + "OutputTypeModule", + name_and_output_module_class_pair->second.first, + "OutputTypeClass", + name_and_output_module_class_pair->second.second); + } + out->Print("}\n"); + out->Print("cardinalities = {\n"); + for (auto name_and_cardinality = method_cardinalities.begin(); + name_and_cardinality != method_cardinalities.end(); + name_and_cardinality++) { + IndentScope raii_descriptions_indent(out); + out->Print("\'$Method$\': cardinality.Cardinality.$Cardinality$,\n", + "Method", name_and_cardinality->first, + "Cardinality", name_and_cardinality->second); + } + out->Print("}\n"); + out->Print("stub_options = beta.stub_options(" + "host=host, metadata_transformer=metadata_transformer, " + "request_serializers=request_serializers, " + "response_deserializers=response_deserializers, " + "thread_pool=pool, thread_pool_size=pool_size)\n"); + out->Print( + "return beta.dynamic_stub(channel, \'$PackageQualifiedServiceName$\', " + "cardinalities, options=stub_options)\n", + "PackageQualifiedServiceName", package_qualified_service_name); + } + return true; +} + bool PrintPreamble(const FileDescriptor* file, const GeneratorConfiguration& config, Printer* out) { out->Print("import abc\n"); + out->Print("from $Package$ import beta\n", + "Package", config.beta_package_root); out->Print("from $Package$ import implementations\n", - "Package", config.implementations_package_root); - out->Print("from grpc.framework.alpha import utilities\n"); + "Package", config.early_adopter_package_root); + out->Print("from grpc.framework.alpha import utilities as alpha_utilities\n"); + out->Print("from grpc.framework.common import cardinality\n"); + out->Print("from grpc.framework.interfaces.face import utilities as face_utilities\n"); return true; } @@ -462,11 +730,15 @@ pair<bool, grpc::string> GetServices(const FileDescriptor* file, for (int i = 0; i < file->service_count(); ++i) { auto service = file->service(i); auto package_qualified_service_name = package + service->name(); - if (!(PrintServicer(service, &out) && - PrintServer(service, &out) && - PrintStub(service, &out) && - PrintServerFactory(package_qualified_service_name, service, &out) && - PrintStubFactory(package_qualified_service_name, service, &out))) { + if (!(PrintAlphaServicer(service, &out) && + PrintAlphaServer(service, &out) && + PrintAlphaStub(service, &out) && + PrintAlphaServerFactory(package_qualified_service_name, service, &out) && + PrintAlphaStubFactory(package_qualified_service_name, service, &out) && + PrintBetaServicer(service, &out) && + PrintBetaStub(service, &out) && + PrintBetaServerFactory(package_qualified_service_name, service, &out) && + PrintBetaStubFactory(package_qualified_service_name, service, &out))) { return make_pair(false, ""); } } diff --git a/src/compiler/python_generator.h b/src/compiler/python_generator.h index b47f3c1243..44ed4b3f98 100644 --- a/src/compiler/python_generator.h +++ b/src/compiler/python_generator.h @@ -43,7 +43,8 @@ namespace grpc_python_generator { // Data pertaining to configuration of the generator with respect to anything // that may be used internally at Google. struct GeneratorConfiguration { - grpc::string implementations_package_root; + grpc::string early_adopter_package_root; + grpc::string beta_package_root; }; class PythonGrpcGenerator : public grpc::protobuf::compiler::CodeGenerator { diff --git a/src/compiler/python_plugin.cc b/src/compiler/python_plugin.cc index d1f49442da..c7cef54900 100644 --- a/src/compiler/python_plugin.cc +++ b/src/compiler/python_plugin.cc @@ -38,7 +38,8 @@ int main(int argc, char* argv[]) { grpc_python_generator::GeneratorConfiguration config; - config.implementations_package_root = "grpc.early_adopter"; + config.early_adopter_package_root = "grpc.early_adopter"; + config.beta_package_root = "grpc.beta"; grpc_python_generator::PythonGrpcGenerator generator(config); return grpc::protobuf::compiler::PluginMain(argc, argv, &generator); } diff --git a/src/core/census/context.c b/src/core/census/context.c index df238ec98c..cab58b653c 100644 --- a/src/core/census/context.c +++ b/src/core/census/context.c @@ -44,16 +44,3 @@ size_t census_context_serialize(const census_context *context, char *buffer, /* TODO(aveitch): implement serialization */ return 0; } - -int census_context_deserialize(const char *buffer, census_context **context) { - int ret = 0; - if (buffer != NULL) { - /* TODO(aveitch): implement deserialization. */ - ret = 1; - } - *context = gpr_malloc(sizeof(census_context)); - memset(*context, 0, sizeof(census_context)); - return ret; -} - -void census_context_destroy(census_context *context) { gpr_free(context); } diff --git a/src/core/census/grpc_context.c b/src/core/census/grpc_context.c index 11f1eb3d5d..429f3ec9db 100644 --- a/src/core/census/grpc_context.c +++ b/src/core/census/grpc_context.c @@ -35,24 +35,11 @@ #include <grpc/grpc.h> #include "src/core/surface/call.h" -static void grpc_census_context_destroy(void *context) { - census_context_destroy((census_context *)context); -} - void grpc_census_call_set_context(grpc_call *call, census_context *context) { if (census_enabled() == CENSUS_FEATURE_NONE) { return; } - if (context == NULL) { - if (grpc_call_is_client(call)) { - census_context *context_ptr; - census_context_deserialize(NULL, &context_ptr); - grpc_call_context_set(call, GRPC_CONTEXT_TRACING, context_ptr, - grpc_census_context_destroy); - } else { - /* TODO(aveitch): server side context code to be implemented. */ - } - } else { + if (context != NULL) { grpc_call_context_set(call, GRPC_CONTEXT_TRACING, context, NULL); } } diff --git a/src/core/census/operation.c b/src/core/census/operation.c new file mode 100644 index 0000000000..118eb0a47a --- /dev/null +++ b/src/core/census/operation.c @@ -0,0 +1,63 @@ +/* + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <grpc/census.h> + +/* TODO(aveitch): These are all placeholder implementations. */ + +census_timestamp census_start_rpc_op_timestamp(void) { + census_timestamp ct; + /* TODO(aveitch): assumes gpr_timespec implementation of census_timestamp. */ + ct.ts = gpr_now(GPR_CLOCK_MONOTONIC); + return ct; +} + +census_context *census_start_client_rpc_op( + const census_context *context, gpr_int64 rpc_name_id, + const census_rpc_name_info *rpc_name_info, const char *peer, int trace_mask, + const census_timestamp *start_time) { + return NULL; +} + +census_context *census_start_server_rpc_op( + const char *buffer, gpr_int64 rpc_name_id, + const census_rpc_name_info *rpc_name_info, const char *peer, int trace_mask, + census_timestamp *start_time) { + return NULL; +} + +census_context *census_start_op(census_context *context, const char *family, + const char *name, int trace_mask) { + return NULL; +} + +void census_end_op(census_context *context, int status) {} diff --git a/src/core/census/tracing.c b/src/core/census/tracing.c new file mode 100644 index 0000000000..ae38773c0a --- /dev/null +++ b/src/core/census/tracing.c @@ -0,0 +1,45 @@ +/* + * + * Copyright 2015, Google Inc. + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are + * met: + * + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above + * copyright notice, this list of conditions and the following disclaimer + * in the documentation and/or other materials provided with the + * distribution. + * * Neither the name of Google Inc. nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include <grpc/census.h> + +/* TODO(aveitch): These are all placeholder implementations. */ + +int census_trace_mask(const census_context *context) { + return CENSUS_TRACE_MASK_NONE; +} + +void census_set_trace_mask(int trace_mask) {} + +void census_trace_print(census_context *context, gpr_uint32 type, + const char *buffer, size_t n) {} diff --git a/src/core/security/credentials.c b/src/core/security/credentials.c index 362d5f4b6f..a764413300 100644 --- a/src/core/security/credentials.c +++ b/src/core/security/credentials.c @@ -87,7 +87,10 @@ grpc_credentials *grpc_credentials_ref(grpc_credentials *creds) { void grpc_credentials_unref(grpc_credentials *creds) { if (creds == NULL) return; - if (gpr_unref(&creds->refcount)) creds->vtable->destroy(creds); + if (gpr_unref(&creds->refcount)) { + creds->vtable->destruct(creds); + gpr_free(creds); + } } void grpc_credentials_release(grpc_credentials *creds) { @@ -135,9 +138,26 @@ grpc_security_status grpc_credentials_create_security_connector( creds, target, args, request_metadata_creds, sc, new_args); } -void grpc_server_credentials_release(grpc_server_credentials *creds) { +grpc_server_credentials *grpc_server_credentials_ref( + grpc_server_credentials *creds) { + if (creds == NULL) return NULL; + gpr_ref(&creds->refcount); + return creds; +} + +void grpc_server_credentials_unref(grpc_server_credentials *creds) { if (creds == NULL) return; - creds->vtable->destroy(creds); + if (gpr_unref(&creds->refcount)) { + creds->vtable->destruct(creds); + if (creds->processor.destroy != NULL && creds->processor.state != NULL) { + creds->processor.destroy(creds->processor.state); + } + gpr_free(creds); + } +} + +void grpc_server_credentials_release(grpc_server_credentials *creds) { + grpc_server_credentials_unref(creds); } grpc_security_status grpc_server_credentials_create_security_connector( @@ -152,20 +172,22 @@ grpc_security_status grpc_server_credentials_create_security_connector( void grpc_server_credentials_set_auth_metadata_processor( grpc_server_credentials *creds, grpc_auth_metadata_processor processor) { if (creds == NULL) return; + if (creds->processor.destroy != NULL && creds->processor.state != NULL) { + creds->processor.destroy(creds->processor.state); + } creds->processor = processor; } /* -- Ssl credentials. -- */ -static void ssl_destroy(grpc_credentials *creds) { +static void ssl_destruct(grpc_credentials *creds) { grpc_ssl_credentials *c = (grpc_ssl_credentials *)creds; if (c->config.pem_root_certs != NULL) gpr_free(c->config.pem_root_certs); if (c->config.pem_private_key != NULL) gpr_free(c->config.pem_private_key); if (c->config.pem_cert_chain != NULL) gpr_free(c->config.pem_cert_chain); - gpr_free(creds); } -static void ssl_server_destroy(grpc_server_credentials *creds) { +static void ssl_server_destruct(grpc_server_credentials *creds) { grpc_ssl_server_credentials *c = (grpc_ssl_server_credentials *)creds; size_t i; for (i = 0; i < c->config.num_key_cert_pairs; i++) { @@ -185,7 +207,6 @@ static void ssl_server_destroy(grpc_server_credentials *creds) { gpr_free(c->config.pem_cert_chains_sizes); } if (c->config.pem_root_certs != NULL) gpr_free(c->config.pem_root_certs); - gpr_free(creds); } static int ssl_has_request_metadata(const grpc_credentials *creds) { return 0; } @@ -231,11 +252,11 @@ static grpc_security_status ssl_server_create_security_connector( } static grpc_credentials_vtable ssl_vtable = { - ssl_destroy, ssl_has_request_metadata, ssl_has_request_metadata_only, NULL, + ssl_destruct, ssl_has_request_metadata, ssl_has_request_metadata_only, NULL, ssl_create_security_connector}; static grpc_server_credentials_vtable ssl_server_vtable = { - ssl_server_destroy, ssl_server_create_security_connector}; + ssl_server_destruct, ssl_server_create_security_connector}; static void ssl_copy_key_material(const char *input, unsigned char **output, size_t *output_size) { @@ -316,9 +337,9 @@ grpc_server_credentials *grpc_ssl_server_credentials_create( grpc_ssl_server_credentials *c = gpr_malloc(sizeof(grpc_ssl_server_credentials)); GPR_ASSERT(reserved == NULL); - memset(c, 0, sizeof(grpc_ssl_credentials)); memset(c, 0, sizeof(grpc_ssl_server_credentials)); c->base.type = GRPC_CREDENTIALS_TYPE_SSL; + gpr_ref_init(&c->base.refcount, 1); c->base.vtable = &ssl_server_vtable; ssl_build_server_config(pem_root_certs, pem_key_cert_pairs, num_key_cert_pairs, force_client_auth, &c->config); @@ -339,13 +360,12 @@ static void jwt_reset_cache(grpc_service_account_jwt_access_credentials *c) { c->cached.jwt_expiration = gpr_inf_past(GPR_CLOCK_REALTIME); } -static void jwt_destroy(grpc_credentials *creds) { +static void jwt_destruct(grpc_credentials *creds) { grpc_service_account_jwt_access_credentials *c = (grpc_service_account_jwt_access_credentials *)creds; grpc_auth_json_key_destruct(&c->key); jwt_reset_cache(c); gpr_mu_destroy(&c->cache_mu); - gpr_free(c); } static int jwt_has_request_metadata(const grpc_credentials *creds) { return 1; } @@ -410,7 +430,7 @@ static void jwt_get_request_metadata(grpc_credentials *creds, } static grpc_credentials_vtable jwt_vtable = { - jwt_destroy, jwt_has_request_metadata, jwt_has_request_metadata_only, + jwt_destruct, jwt_has_request_metadata, jwt_has_request_metadata_only, jwt_get_request_metadata, NULL}; grpc_credentials * @@ -442,13 +462,12 @@ grpc_credentials *grpc_service_account_jwt_access_credentials_create( /* -- Oauth2TokenFetcher credentials -- */ -static void oauth2_token_fetcher_destroy(grpc_credentials *creds) { +static void oauth2_token_fetcher_destruct(grpc_credentials *creds) { grpc_oauth2_token_fetcher_credentials *c = (grpc_oauth2_token_fetcher_credentials *)creds; grpc_credentials_md_store_unref(c->access_token_md); gpr_mu_destroy(&c->mu); grpc_httpcli_context_destroy(&c->httpcli_context); - gpr_free(c); } static int oauth2_token_fetcher_has_request_metadata( @@ -618,10 +637,10 @@ static void init_oauth2_token_fetcher(grpc_oauth2_token_fetcher_credentials *c, grpc_httpcli_context_init(&c->httpcli_context); } -/* -- ComputeEngine credentials. -- */ +/* -- GoogleComputeEngine credentials. -- */ static grpc_credentials_vtable compute_engine_vtable = { - oauth2_token_fetcher_destroy, oauth2_token_fetcher_has_request_metadata, + oauth2_token_fetcher_destruct, oauth2_token_fetcher_has_request_metadata, oauth2_token_fetcher_has_request_metadata_only, oauth2_token_fetcher_get_request_metadata, NULL}; @@ -640,7 +659,8 @@ static void compute_engine_fetch_oauth2( metadata_req); } -grpc_credentials *grpc_compute_engine_credentials_create(void *reserved) { +grpc_credentials *grpc_google_compute_engine_credentials_create( + void *reserved) { grpc_oauth2_token_fetcher_credentials *c = gpr_malloc(sizeof(grpc_oauth2_token_fetcher_credentials)); GPR_ASSERT(reserved == NULL); @@ -649,87 +669,17 @@ grpc_credentials *grpc_compute_engine_credentials_create(void *reserved) { return &c->base; } -/* -- ServiceAccount credentials. -- */ - -static void service_account_destroy(grpc_credentials *creds) { - grpc_service_account_credentials *c = - (grpc_service_account_credentials *)creds; - if (c->scope != NULL) gpr_free(c->scope); - grpc_auth_json_key_destruct(&c->key); - oauth2_token_fetcher_destroy(&c->base.base); -} - -static grpc_credentials_vtable service_account_vtable = { - service_account_destroy, oauth2_token_fetcher_has_request_metadata, - oauth2_token_fetcher_has_request_metadata_only, - oauth2_token_fetcher_get_request_metadata, NULL}; +/* -- GoogleRefreshToken credentials. -- */ -static void service_account_fetch_oauth2( - grpc_credentials_metadata_request *metadata_req, - grpc_httpcli_context *httpcli_context, grpc_pollset *pollset, - grpc_httpcli_response_cb response_cb, gpr_timespec deadline) { - grpc_service_account_credentials *c = - (grpc_service_account_credentials *)metadata_req->creds; - grpc_httpcli_header header = {"Content-Type", - "application/x-www-form-urlencoded"}; - grpc_httpcli_request request; - char *body = NULL; - char *jwt = grpc_jwt_encode_and_sign(&c->key, GRPC_JWT_OAUTH2_AUDIENCE, - c->token_lifetime, c->scope); - if (jwt == NULL) { - grpc_httpcli_response response; - memset(&response, 0, sizeof(grpc_httpcli_response)); - response.status = 400; /* Invalid request. */ - gpr_log(GPR_ERROR, "Could not create signed jwt."); - /* Do not even send the request, just call the response callback. */ - response_cb(metadata_req, &response); - return; - } - gpr_asprintf(&body, "%s%s", GRPC_SERVICE_ACCOUNT_POST_BODY_PREFIX, jwt); - memset(&request, 0, sizeof(grpc_httpcli_request)); - request.host = GRPC_GOOGLE_OAUTH2_SERVICE_HOST; - request.path = GRPC_GOOGLE_OAUTH2_SERVICE_TOKEN_PATH; - request.hdr_count = 1; - request.hdrs = &header; - request.handshaker = &grpc_httpcli_ssl; - grpc_httpcli_post(httpcli_context, pollset, &request, body, strlen(body), - deadline, response_cb, metadata_req); - gpr_free(body); - gpr_free(jwt); -} - -grpc_credentials *grpc_service_account_credentials_create( - const char *json_key, const char *scope, gpr_timespec token_lifetime, - void *reserved) { - grpc_service_account_credentials *c; - grpc_auth_json_key key = grpc_auth_json_key_create_from_string(json_key); - GPR_ASSERT(reserved == NULL); - if (scope == NULL || (strlen(scope) == 0) || - !grpc_auth_json_key_is_valid(&key)) { - gpr_log(GPR_ERROR, - "Invalid input for service account credentials creation"); - return NULL; - } - c = gpr_malloc(sizeof(grpc_service_account_credentials)); - memset(c, 0, sizeof(grpc_service_account_credentials)); - init_oauth2_token_fetcher(&c->base, service_account_fetch_oauth2); - c->base.base.vtable = &service_account_vtable; - c->scope = gpr_strdup(scope); - c->key = key; - c->token_lifetime = token_lifetime; - return &c->base.base; -} - -/* -- RefreshToken credentials. -- */ - -static void refresh_token_destroy(grpc_credentials *creds) { - grpc_refresh_token_credentials *c = (grpc_refresh_token_credentials *)creds; +static void refresh_token_destruct(grpc_credentials *creds) { + grpc_google_refresh_token_credentials *c = + (grpc_google_refresh_token_credentials *)creds; grpc_auth_refresh_token_destruct(&c->refresh_token); - oauth2_token_fetcher_destroy(&c->base.base); + oauth2_token_fetcher_destruct(&c->base.base); } static grpc_credentials_vtable refresh_token_vtable = { - refresh_token_destroy, oauth2_token_fetcher_has_request_metadata, + refresh_token_destruct, oauth2_token_fetcher_has_request_metadata, oauth2_token_fetcher_has_request_metadata_only, oauth2_token_fetcher_get_request_metadata, NULL}; @@ -737,8 +687,8 @@ static void refresh_token_fetch_oauth2( grpc_credentials_metadata_request *metadata_req, grpc_httpcli_context *httpcli_context, grpc_pollset *pollset, grpc_httpcli_response_cb response_cb, gpr_timespec deadline) { - grpc_refresh_token_credentials *c = - (grpc_refresh_token_credentials *)metadata_req->creds; + grpc_google_refresh_token_credentials *c = + (grpc_google_refresh_token_credentials *)metadata_req->creds; grpc_httpcli_header header = {"Content-Type", "application/x-www-form-urlencoded"}; grpc_httpcli_request request; @@ -757,22 +707,23 @@ static void refresh_token_fetch_oauth2( gpr_free(body); } -grpc_credentials *grpc_refresh_token_credentials_create_from_auth_refresh_token( +grpc_credentials * +grpc_refresh_token_credentials_create_from_auth_refresh_token( grpc_auth_refresh_token refresh_token) { - grpc_refresh_token_credentials *c; + grpc_google_refresh_token_credentials *c; if (!grpc_auth_refresh_token_is_valid(&refresh_token)) { gpr_log(GPR_ERROR, "Invalid input for refresh token credentials creation"); return NULL; } - c = gpr_malloc(sizeof(grpc_refresh_token_credentials)); - memset(c, 0, sizeof(grpc_refresh_token_credentials)); + c = gpr_malloc(sizeof(grpc_google_refresh_token_credentials)); + memset(c, 0, sizeof(grpc_google_refresh_token_credentials)); init_oauth2_token_fetcher(&c->base, refresh_token_fetch_oauth2); c->base.base.vtable = &refresh_token_vtable; c->refresh_token = refresh_token; return &c->base.base; } -grpc_credentials *grpc_refresh_token_credentials_create( +grpc_credentials *grpc_google_refresh_token_credentials_create( const char *json_refresh_token, void *reserved) { GPR_ASSERT(reserved == NULL); return grpc_refresh_token_credentials_create_from_auth_refresh_token( @@ -781,10 +732,9 @@ grpc_credentials *grpc_refresh_token_credentials_create( /* -- Metadata-only credentials. -- */ -static void md_only_test_destroy(grpc_credentials *creds) { +static void md_only_test_destruct(grpc_credentials *creds) { grpc_md_only_test_credentials *c = (grpc_md_only_test_credentials *)creds; grpc_credentials_md_store_unref(c->md_store); - gpr_free(c); } static int md_only_test_has_request_metadata(const grpc_credentials *creds) { @@ -825,7 +775,7 @@ static void md_only_test_get_request_metadata(grpc_credentials *creds, } static grpc_credentials_vtable md_only_test_vtable = { - md_only_test_destroy, md_only_test_has_request_metadata, + md_only_test_destruct, md_only_test_has_request_metadata, md_only_test_has_request_metadata_only, md_only_test_get_request_metadata, NULL}; @@ -846,10 +796,9 @@ grpc_credentials *grpc_md_only_test_credentials_create(const char *md_key, /* -- Oauth2 Access Token credentials. -- */ -static void access_token_destroy(grpc_credentials *creds) { +static void access_token_destruct(grpc_credentials *creds) { grpc_access_token_credentials *c = (grpc_access_token_credentials *)creds; grpc_credentials_md_store_unref(c->access_token_md); - gpr_free(c); } static int access_token_has_request_metadata(const grpc_credentials *creds) { @@ -871,7 +820,7 @@ static void access_token_get_request_metadata(grpc_credentials *creds, } static grpc_credentials_vtable access_token_vtable = { - access_token_destroy, access_token_has_request_metadata, + access_token_destruct, access_token_has_request_metadata, access_token_has_request_metadata_only, access_token_get_request_metadata, NULL}; @@ -895,14 +844,14 @@ grpc_credentials *grpc_access_token_credentials_create(const char *access_token, /* -- Fake transport security credentials. -- */ -static void fake_transport_security_credentials_destroy( +static void fake_transport_security_credentials_destruct( grpc_credentials *creds) { - gpr_free(creds); + /* Nothing to do here. */ } -static void fake_transport_security_server_credentials_destroy( +static void fake_transport_security_server_credentials_destruct( grpc_server_credentials *creds) { - gpr_free(creds); + /* Nothing to do here. */ } static int fake_transport_security_has_request_metadata( @@ -931,14 +880,14 @@ fake_transport_security_server_create_security_connector( } static grpc_credentials_vtable fake_transport_security_credentials_vtable = { - fake_transport_security_credentials_destroy, + fake_transport_security_credentials_destruct, fake_transport_security_has_request_metadata, fake_transport_security_has_request_metadata_only, NULL, fake_transport_security_create_security_connector}; static grpc_server_credentials_vtable fake_transport_security_server_credentials_vtable = { - fake_transport_security_server_credentials_destroy, + fake_transport_security_server_credentials_destruct, fake_transport_security_server_create_security_connector}; grpc_credentials *grpc_fake_transport_security_credentials_create(void) { @@ -955,6 +904,7 @@ grpc_server_credentials *grpc_fake_transport_security_server_credentials_create( grpc_server_credentials *c = gpr_malloc(sizeof(grpc_server_credentials)); memset(c, 0, sizeof(grpc_server_credentials)); c->type = GRPC_CREDENTIALS_TYPE_FAKE_TRANSPORT_SECURITY; + gpr_ref_init(&c->refcount, 1); c->vtable = &fake_transport_security_server_credentials_vtable; return c; } @@ -971,14 +921,13 @@ typedef struct { grpc_credentials_metadata_cb cb; } grpc_composite_credentials_metadata_context; -static void composite_destroy(grpc_credentials *creds) { +static void composite_destruct(grpc_credentials *creds) { grpc_composite_credentials *c = (grpc_composite_credentials *)creds; size_t i; for (i = 0; i < c->inner.num_creds; i++) { grpc_credentials_unref(c->inner.creds_array[i]); } gpr_free(c->inner.creds_array); - gpr_free(creds); } static int composite_has_request_metadata(const grpc_credentials *creds) { @@ -1094,7 +1043,7 @@ static grpc_security_status composite_create_security_connector( } static grpc_credentials_vtable composite_credentials_vtable = { - composite_destroy, composite_has_request_metadata, + composite_destruct, composite_has_request_metadata, composite_has_request_metadata_only, composite_get_request_metadata, composite_create_security_connector}; @@ -1193,10 +1142,9 @@ grpc_credentials *grpc_credentials_contains_type( /* -- IAM credentials. -- */ -static void iam_destroy(grpc_credentials *creds) { - grpc_iam_credentials *c = (grpc_iam_credentials *)creds; +static void iam_destruct(grpc_credentials *creds) { + grpc_google_iam_credentials *c = (grpc_google_iam_credentials *)creds; grpc_credentials_md_store_unref(c->iam_md); - gpr_free(c); } static int iam_has_request_metadata(const grpc_credentials *creds) { return 1; } @@ -1210,24 +1158,23 @@ static void iam_get_request_metadata(grpc_credentials *creds, const char *service_url, grpc_credentials_metadata_cb cb, void *user_data) { - grpc_iam_credentials *c = (grpc_iam_credentials *)creds; + grpc_google_iam_credentials *c = (grpc_google_iam_credentials *)creds; cb(user_data, c->iam_md->entries, c->iam_md->num_entries, GRPC_CREDENTIALS_OK); } static grpc_credentials_vtable iam_vtable = { - iam_destroy, iam_has_request_metadata, iam_has_request_metadata_only, + iam_destruct, iam_has_request_metadata, iam_has_request_metadata_only, iam_get_request_metadata, NULL}; -grpc_credentials *grpc_iam_credentials_create(const char *token, - const char *authority_selector, - void *reserved) { - grpc_iam_credentials *c; +grpc_credentials *grpc_google_iam_credentials_create( + const char *token, const char *authority_selector, void *reserved) { + grpc_google_iam_credentials *c; GPR_ASSERT(reserved == NULL); GPR_ASSERT(token != NULL); GPR_ASSERT(authority_selector != NULL); - c = gpr_malloc(sizeof(grpc_iam_credentials)); - memset(c, 0, sizeof(grpc_iam_credentials)); + c = gpr_malloc(sizeof(grpc_google_iam_credentials)); + memset(c, 0, sizeof(grpc_google_iam_credentials)); c->base.type = GRPC_CREDENTIALS_TYPE_IAM; c->base.vtable = &iam_vtable; gpr_ref_init(&c->base.refcount, 1); diff --git a/src/core/security/credentials.h b/src/core/security/credentials.h index 29cd1ac87f..8e4fed7615 100644 --- a/src/core/security/credentials.h +++ b/src/core/security/credentials.h @@ -129,7 +129,7 @@ typedef void (*grpc_credentials_metadata_cb)(void *user_data, grpc_credentials_status status); typedef struct { - void (*destroy)(grpc_credentials *c); + void (*destruct)(grpc_credentials *c); int (*has_request_metadata)(const grpc_credentials *c); int (*has_request_metadata_only)(const grpc_credentials *c); void (*get_request_metadata)(grpc_credentials *c, grpc_pollset *pollset, @@ -210,20 +210,28 @@ grpc_credentials *grpc_refresh_token_credentials_create_from_auth_refresh_token( /* --- grpc_server_credentials. --- */ typedef struct { - void (*destroy)(grpc_server_credentials *c); + void (*destruct)(grpc_server_credentials *c); grpc_security_status (*create_security_connector)( grpc_server_credentials *c, grpc_security_connector **sc); } grpc_server_credentials_vtable; + +/* TODO(jboeuf): Add a refcount. */ struct grpc_server_credentials { const grpc_server_credentials_vtable *vtable; const char *type; + gpr_refcount refcount; grpc_auth_metadata_processor processor; }; grpc_security_status grpc_server_credentials_create_security_connector( grpc_server_credentials *creds, grpc_security_connector **sc); +grpc_server_credentials *grpc_server_credentials_ref( + grpc_server_credentials *creds); + +void grpc_server_credentials_unref(grpc_server_credentials *creds); + /* -- Ssl credentials. -- */ typedef struct { @@ -277,21 +285,12 @@ typedef struct { grpc_fetch_oauth2_func fetch_func; } grpc_oauth2_token_fetcher_credentials; -/* -- ServiceAccount credentials. -- */ - -typedef struct { - grpc_oauth2_token_fetcher_credentials base; - grpc_auth_json_key key; - char *scope; - gpr_timespec token_lifetime; -} grpc_service_account_credentials; - -/* -- RefreshToken credentials. -- */ +/* -- GoogleRefreshToken credentials. -- */ typedef struct { grpc_oauth2_token_fetcher_credentials base; grpc_auth_refresh_token refresh_token; -} grpc_refresh_token_credentials; +} grpc_google_refresh_token_credentials; /* -- Oauth2 Access Token credentials. -- */ @@ -308,12 +307,12 @@ typedef struct { int is_async; } grpc_md_only_test_credentials; -/* -- IAM credentials. -- */ +/* -- GoogleIAM credentials. -- */ typedef struct { grpc_credentials base; grpc_credentials_md_store *iam_md; -} grpc_iam_credentials; +} grpc_google_iam_credentials; /* -- Composite credentials. -- */ diff --git a/src/core/security/google_default_credentials.c b/src/core/security/google_default_credentials.c index f9aa5187ce..874dd59e84 100644 --- a/src/core/security/google_default_credentials.c +++ b/src/core/security/google_default_credentials.c @@ -194,7 +194,7 @@ grpc_credentials *grpc_google_default_credentials_create(void) { int need_compute_engine_creds = is_stack_running_on_compute_engine(); compute_engine_detection_done = 1; if (need_compute_engine_creds) { - result = grpc_compute_engine_credentials_create(NULL); + result = grpc_google_compute_engine_credentials_create(NULL); } } diff --git a/src/core/security/security_context.c b/src/core/security/security_context.c index c1b434f302..95d80ba122 100644 --- a/src/core/security/security_context.c +++ b/src/core/security/security_context.c @@ -42,19 +42,6 @@ #include <grpc/support/log.h> #include <grpc/support/string_util.h> -/* --- grpc_process_auth_metadata_func --- */ - -static grpc_auth_metadata_processor server_processor = {NULL, NULL}; - -grpc_auth_metadata_processor grpc_server_get_auth_metadata_processor(void) { - return server_processor; -} - -void grpc_server_register_auth_metadata_processor( - grpc_auth_metadata_processor processor) { - server_processor = processor; -} - /* --- grpc_call --- */ grpc_call_error grpc_call_set_credentials(grpc_call *call, diff --git a/src/core/security/server_auth_filter.c b/src/core/security/server_auth_filter.c index 6e831431fa..b767f85498 100644 --- a/src/core/security/server_auth_filter.c +++ b/src/core/security/server_auth_filter.c @@ -50,6 +50,7 @@ typedef struct call_data { handling it. */ grpc_iomgr_closure auth_on_recv; grpc_transport_stream_op transport_op; + grpc_metadata_array md; const grpc_metadata *consumed_md; size_t num_consumed_md; grpc_stream_op *md_op; @@ -90,13 +91,17 @@ static grpc_mdelem *remove_consumed_md(void *user_data, grpc_mdelem *md) { call_data *calld = elem->call_data; size_t i; for (i = 0; i < calld->num_consumed_md; i++) { + const grpc_metadata *consumed_md = &calld->consumed_md[i]; /* Maybe we could do a pointer comparison but we do not have any guarantee that the metadata processor used the same pointers for consumed_md in the callback. */ - if (memcmp(GPR_SLICE_START_PTR(md->key->slice), calld->consumed_md[i].key, + if (GPR_SLICE_LENGTH(md->key->slice) != strlen(consumed_md->key) || + GPR_SLICE_LENGTH(md->value->slice) != consumed_md->value_length) { + continue; + } + if (memcmp(GPR_SLICE_START_PTR(md->key->slice), consumed_md->key, GPR_SLICE_LENGTH(md->key->slice)) == 0 && - memcmp(GPR_SLICE_START_PTR(md->value->slice), - calld->consumed_md[i].value, + memcmp(GPR_SLICE_START_PTR(md->value->slice), consumed_md->value, GPR_SLICE_LENGTH(md->value->slice)) == 0) { return NULL; /* Delete. */ } @@ -134,6 +139,7 @@ static void on_md_processing_done( grpc_transport_stream_op_add_close(&calld->transport_op, status, &message); grpc_call_next_op(elem, &calld->transport_op); } + grpc_metadata_array_destroy(&calld->md); } static void auth_on_recv(void *user_data, int success) { @@ -145,17 +151,15 @@ static void auth_on_recv(void *user_data, int success) { size_t nops = calld->recv_ops->nops; grpc_stream_op *ops = calld->recv_ops->ops; for (i = 0; i < nops; i++) { - grpc_metadata_array md_array; grpc_stream_op *op = &ops[i]; if (op->type != GRPC_OP_METADATA || calld->got_client_metadata) continue; calld->got_client_metadata = 1; if (chand->processor.process == NULL) continue; calld->md_op = op; - md_array = metadata_batch_to_md_array(&op->data.metadata); + calld->md = metadata_batch_to_md_array(&op->data.metadata); chand->processor.process(chand->processor.state, calld->auth_context, - md_array.metadata, md_array.count, + calld->md.metadata, calld->md.count, on_md_processing_done, elem); - grpc_metadata_array_destroy(&md_array); return; } } diff --git a/src/core/security/server_secure_chttp2.c b/src/core/security/server_secure_chttp2.c index 8d9d036d80..4749f5f516 100644 --- a/src/core/security/server_secure_chttp2.c +++ b/src/core/security/server_secure_chttp2.c @@ -61,7 +61,7 @@ typedef struct grpc_server_secure_state { grpc_server *server; grpc_tcp_server *tcp; grpc_security_connector *sc; - grpc_auth_metadata_processor processor; + grpc_server_credentials *creds; tcp_endpoint_list *handshaking_tcp_endpoints; int is_shutdown; gpr_mu mu; @@ -79,6 +79,7 @@ static void state_unref(grpc_server_secure_state *state) { gpr_mu_unlock(&state->mu); /* clean up */ GRPC_SECURITY_CONNECTOR_UNREF(state->sc, "server"); + grpc_server_credentials_unref(state->creds); gpr_free(state); } } @@ -91,7 +92,8 @@ static void setup_transport(void *statep, grpc_transport *transport, grpc_channel_args *args_copy; grpc_arg args_to_add[2]; args_to_add[0] = grpc_security_connector_to_arg(state->sc); - args_to_add[1] = grpc_auth_metadata_processor_to_arg(&state->processor); + args_to_add[1] = + grpc_auth_metadata_processor_to_arg(&state->creds->processor); args_copy = grpc_channel_args_copy_and_add( grpc_server_get_channel_args(state->server), args_to_add, GPR_ARRAY_SIZE(args_to_add)); @@ -262,7 +264,8 @@ int grpc_server_add_secure_http2_port(grpc_server *server, const char *addr, state->server = server; state->tcp = tcp; state->sc = sc; - state->processor = creds->processor; + state->creds = grpc_server_credentials_ref(creds); + state->handshaking_tcp_endpoints = NULL; state->is_shutdown = 0; gpr_mu_init(&state->mu); diff --git a/src/cpp/client/channel.cc b/src/cpp/client/channel.cc index 8bf2e4687e..dc8e304664 100644 --- a/src/cpp/client/channel.cc +++ b/src/cpp/client/channel.cc @@ -40,7 +40,7 @@ #include <grpc/support/slice.h> #include <grpc++/client_context.h> #include <grpc++/completion_queue.h> -#include <grpc++/credentials.h> +#include <grpc++/security/credentials.h> #include <grpc++/impl/call.h> #include <grpc++/impl/rpc_method.h> #include <grpc++/support/channel_arguments.h> diff --git a/src/cpp/client/client_context.cc b/src/cpp/client/client_context.cc index c4d7cf2e51..574656a7e9 100644 --- a/src/cpp/client/client_context.cc +++ b/src/cpp/client/client_context.cc @@ -36,7 +36,7 @@ #include <grpc/grpc.h> #include <grpc/support/alloc.h> #include <grpc/support/string_util.h> -#include <grpc++/credentials.h> +#include <grpc++/security/credentials.h> #include <grpc++/server_context.h> #include <grpc++/support/time.h> diff --git a/src/cpp/client/create_channel.cc b/src/cpp/client/create_channel.cc index 1dac960017..d2b2d30126 100644 --- a/src/cpp/client/create_channel.cc +++ b/src/cpp/client/create_channel.cc @@ -51,6 +51,7 @@ std::shared_ptr<Channel> CreateChannel( std::shared_ptr<Channel> CreateCustomChannel( const grpc::string& target, const std::shared_ptr<Credentials>& creds, const ChannelArguments& args) { + GrpcLibrary init_lib; // We need to call init in case of a bad creds. ChannelArguments cp_args = args; std::ostringstream user_agent_prefix; user_agent_prefix << "grpc-c++/" << grpc_version_string(); diff --git a/src/cpp/client/credentials.cc b/src/cpp/client/credentials.cc index e806284988..7a8149e9c7 100644 --- a/src/cpp/client/credentials.cc +++ b/src/cpp/client/credentials.cc @@ -31,7 +31,7 @@ * */ -#include <grpc++/credentials.h> +#include <grpc++/security/credentials.h> namespace grpc { diff --git a/src/cpp/client/insecure_credentials.cc b/src/cpp/client/insecure_credentials.cc index 4a4d2cb97d..c476f3ce95 100644 --- a/src/cpp/client/insecure_credentials.cc +++ b/src/cpp/client/insecure_credentials.cc @@ -31,7 +31,7 @@ * */ -#include <grpc++/credentials.h> +#include <grpc++/security/credentials.h> #include <grpc/grpc.h> #include <grpc/support/log.h> diff --git a/src/cpp/client/secure_credentials.cc b/src/cpp/client/secure_credentials.cc index e0642469b4..2260f6d33e 100644 --- a/src/cpp/client/secure_credentials.cc +++ b/src/cpp/client/secure_credentials.cc @@ -81,26 +81,10 @@ std::shared_ptr<Credentials> SslCredentials( } // Builds credentials for use when running in GCE -std::shared_ptr<Credentials> ComputeEngineCredentials() { +std::shared_ptr<Credentials> GoogleComputeEngineCredentials() { GrpcLibrary init; // To call grpc_init(). - return WrapCredentials(grpc_compute_engine_credentials_create(nullptr)); -} - -// Builds service account credentials. -std::shared_ptr<Credentials> ServiceAccountCredentials( - const grpc::string& json_key, const grpc::string& scope, - long token_lifetime_seconds) { - GrpcLibrary init; // To call grpc_init(). - if (token_lifetime_seconds <= 0) { - gpr_log(GPR_ERROR, - "Trying to create ServiceAccountCredentials " - "with non-positive lifetime"); - return WrapCredentials(nullptr); - } - gpr_timespec lifetime = - gpr_time_from_seconds(token_lifetime_seconds, GPR_TIMESPAN); - return WrapCredentials(grpc_service_account_credentials_create( - json_key.c_str(), scope.c_str(), lifetime, nullptr)); + return WrapCredentials( + grpc_google_compute_engine_credentials_create(nullptr)); } // Builds JWT credentials. @@ -119,10 +103,10 @@ std::shared_ptr<Credentials> ServiceAccountJWTAccessCredentials( } // Builds refresh token credentials. -std::shared_ptr<Credentials> RefreshTokenCredentials( +std::shared_ptr<Credentials> GoogleRefreshTokenCredentials( const grpc::string& json_refresh_token) { GrpcLibrary init; // To call grpc_init(). - return WrapCredentials(grpc_refresh_token_credentials_create( + return WrapCredentials(grpc_google_refresh_token_credentials_create( json_refresh_token.c_str(), nullptr)); } @@ -135,11 +119,11 @@ std::shared_ptr<Credentials> AccessTokenCredentials( } // Builds IAM credentials. -std::shared_ptr<Credentials> IAMCredentials( +std::shared_ptr<Credentials> GoogleIAMCredentials( const grpc::string& authorization_token, const grpc::string& authority_selector) { GrpcLibrary init; // To call grpc_init(). - return WrapCredentials(grpc_iam_credentials_create( + return WrapCredentials(grpc_google_iam_credentials_create( authorization_token.c_str(), authority_selector.c_str(), nullptr)); } diff --git a/src/cpp/client/secure_credentials.h b/src/cpp/client/secure_credentials.h index 62d3185477..8deff856c4 100644 --- a/src/cpp/client/secure_credentials.h +++ b/src/cpp/client/secure_credentials.h @@ -37,7 +37,7 @@ #include <grpc/grpc_security.h> #include <grpc++/support/config.h> -#include <grpc++/credentials.h> +#include <grpc++/security/credentials.h> namespace grpc { diff --git a/src/cpp/common/auth_property_iterator.cc b/src/cpp/common/auth_property_iterator.cc index fa6da9d7a8..a47abaf4b8 100644 --- a/src/cpp/common/auth_property_iterator.cc +++ b/src/cpp/common/auth_property_iterator.cc @@ -31,7 +31,7 @@ * */ -#include <grpc++/support/auth_context.h> +#include <grpc++/security/auth_context.h> #include <grpc/grpc_security.h> diff --git a/src/cpp/common/create_auth_context.h b/src/cpp/common/create_auth_context.h index b4962bae4e..4f3da397ba 100644 --- a/src/cpp/common/create_auth_context.h +++ b/src/cpp/common/create_auth_context.h @@ -33,7 +33,7 @@ #include <memory> #include <grpc/grpc.h> -#include <grpc++/support/auth_context.h> +#include <grpc++/security/auth_context.h> namespace grpc { diff --git a/src/cpp/common/insecure_create_auth_context.cc b/src/cpp/common/insecure_create_auth_context.cc index fe80c1a80c..b2e153229a 100644 --- a/src/cpp/common/insecure_create_auth_context.cc +++ b/src/cpp/common/insecure_create_auth_context.cc @@ -33,7 +33,7 @@ #include <memory> #include <grpc/grpc.h> -#include <grpc++/support/auth_context.h> +#include <grpc++/security/auth_context.h> namespace grpc { diff --git a/src/cpp/common/secure_auth_context.cc b/src/cpp/common/secure_auth_context.cc index b18a8537c9..8615ac8aeb 100644 --- a/src/cpp/common/secure_auth_context.cc +++ b/src/cpp/common/secure_auth_context.cc @@ -37,9 +37,13 @@ namespace grpc { -SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx) : ctx_(ctx) {} +SecureAuthContext::SecureAuthContext(grpc_auth_context* ctx, + bool take_ownership) + : ctx_(ctx), take_ownership_(take_ownership) {} -SecureAuthContext::~SecureAuthContext() { grpc_auth_context_release(ctx_); } +SecureAuthContext::~SecureAuthContext() { + if (take_ownership_) grpc_auth_context_release(ctx_); +} std::vector<grpc::string_ref> SecureAuthContext::GetPeerIdentity() const { if (!ctx_) { @@ -94,4 +98,21 @@ AuthPropertyIterator SecureAuthContext::end() const { return AuthPropertyIterator(); } +void SecureAuthContext::AddProperty(const grpc::string& key, + const grpc::string_ref& value) { + if (!ctx_) return; + grpc_auth_context_add_property(ctx_, key.c_str(), value.data(), value.size()); +} + +bool SecureAuthContext::SetPeerIdentityPropertyName(const grpc::string& name) { + if (!ctx_) return false; + return grpc_auth_context_set_peer_identity_property_name(ctx_, + name.c_str()) != 0; +} + +bool SecureAuthContext::IsPeerAuthenticated() const { + if (!ctx_) return false; + return grpc_auth_context_peer_is_authenticated(ctx_) != 0; +} + } // namespace grpc diff --git a/src/cpp/common/secure_auth_context.h b/src/cpp/common/secure_auth_context.h index 7f622b890b..c9f1dad131 100644 --- a/src/cpp/common/secure_auth_context.h +++ b/src/cpp/common/secure_auth_context.h @@ -34,7 +34,7 @@ #ifndef GRPC_INTERNAL_CPP_COMMON_SECURE_AUTH_CONTEXT_H #define GRPC_INTERNAL_CPP_COMMON_SECURE_AUTH_CONTEXT_H -#include <grpc++/support/auth_context.h> +#include <grpc++/security/auth_context.h> struct grpc_auth_context; @@ -42,10 +42,12 @@ namespace grpc { class SecureAuthContext GRPC_FINAL : public AuthContext { public: - SecureAuthContext(grpc_auth_context* ctx); + SecureAuthContext(grpc_auth_context* ctx, bool take_ownership); ~SecureAuthContext() GRPC_OVERRIDE; + bool IsPeerAuthenticated() const GRPC_OVERRIDE; + std::vector<grpc::string_ref> GetPeerIdentity() const GRPC_OVERRIDE; grpc::string GetPeerIdentityPropertyName() const GRPC_OVERRIDE; @@ -57,8 +59,15 @@ class SecureAuthContext GRPC_FINAL : public AuthContext { AuthPropertyIterator end() const GRPC_OVERRIDE; + void AddProperty(const grpc::string& key, + const grpc::string_ref& value) GRPC_OVERRIDE; + + virtual bool SetPeerIdentityPropertyName(const grpc::string& name) + GRPC_OVERRIDE; + private: grpc_auth_context* ctx_; + bool take_ownership_; }; } // namespace grpc diff --git a/src/cpp/common/secure_create_auth_context.cc b/src/cpp/common/secure_create_auth_context.cc index f13d25a1dd..40bc298b64 100644 --- a/src/cpp/common/secure_create_auth_context.cc +++ b/src/cpp/common/secure_create_auth_context.cc @@ -34,7 +34,7 @@ #include <grpc/grpc.h> #include <grpc/grpc_security.h> -#include <grpc++/support/auth_context.h> +#include <grpc++/security/auth_context.h> #include "src/cpp/common/secure_auth_context.h" namespace grpc { @@ -44,7 +44,7 @@ std::shared_ptr<const AuthContext> CreateAuthContext(grpc_call* call) { return std::shared_ptr<const AuthContext>(); } return std::shared_ptr<const AuthContext>( - new SecureAuthContext(grpc_call_auth_context(call))); + new SecureAuthContext(grpc_call_auth_context(call), true)); } } // namespace grpc diff --git a/src/cpp/server/insecure_server_credentials.cc b/src/cpp/server/insecure_server_credentials.cc index 800cd36caa..ef3cae5fd7 100644 --- a/src/cpp/server/insecure_server_credentials.cc +++ b/src/cpp/server/insecure_server_credentials.cc @@ -31,9 +31,10 @@ * */ -#include <grpc++/server_credentials.h> +#include <grpc++/security/server_credentials.h> #include <grpc/grpc.h> +#include <grpc/support/log.h> namespace grpc { namespace { @@ -43,6 +44,11 @@ class InsecureServerCredentialsImpl GRPC_FINAL : public ServerCredentials { grpc_server* server) GRPC_OVERRIDE { return grpc_server_add_insecure_http2_port(server, addr.c_str()); } + void SetAuthMetadataProcessor( + const std::shared_ptr<AuthMetadataProcessor>& processor) GRPC_OVERRIDE { + (void)processor; + GPR_ASSERT(0); // Should not be called on InsecureServerCredentials. + } }; } // namespace diff --git a/src/cpp/server/secure_server_credentials.cc b/src/cpp/server/secure_server_credentials.cc index 5bce9ca8b2..dfa9229c98 100644 --- a/src/cpp/server/secure_server_credentials.cc +++ b/src/cpp/server/secure_server_credentials.cc @@ -31,15 +31,94 @@ * */ +#include <functional> +#include <map> +#include <memory> + + +#include "src/cpp/common/secure_auth_context.h" #include "src/cpp/server/secure_server_credentials.h" +#include <grpc++/security/auth_metadata_processor.h> + namespace grpc { +void AuthMetadataProcessorAyncWrapper::Destroy(void *wrapper) { + auto* w = reinterpret_cast<AuthMetadataProcessorAyncWrapper*>(wrapper); + delete w; +} + +void AuthMetadataProcessorAyncWrapper::Process( + void* wrapper, grpc_auth_context* context, const grpc_metadata* md, + size_t num_md, grpc_process_auth_metadata_done_cb cb, void* user_data) { + auto* w = reinterpret_cast<AuthMetadataProcessorAyncWrapper*>(wrapper); + if (w->processor_ == nullptr) { + // Early exit. + cb(user_data, nullptr, 0, nullptr, 0, GRPC_STATUS_OK, nullptr); + return; + } + if (w->processor_->IsBlocking()) { + w->thread_pool_->Add( + std::bind(&AuthMetadataProcessorAyncWrapper::InvokeProcessor, w, + context, md, num_md, cb, user_data)); + } else { + // invoke directly. + w->InvokeProcessor(context, md, num_md, cb, user_data); + } +} + +void AuthMetadataProcessorAyncWrapper::InvokeProcessor( + grpc_auth_context* ctx, + const grpc_metadata* md, size_t num_md, + grpc_process_auth_metadata_done_cb cb, void* user_data) { + AuthMetadataProcessor::InputMetadata metadata; + for (size_t i = 0; i < num_md; i++) { + metadata.insert(std::make_pair( + md[i].key, grpc::string_ref(md[i].value, md[i].value_length))); + } + SecureAuthContext context(ctx, false); + AuthMetadataProcessor::OutputMetadata consumed_metadata; + AuthMetadataProcessor::OutputMetadata response_metadata; + + Status status = processor_->Process(metadata, &context, &consumed_metadata, + &response_metadata); + + std::vector<grpc_metadata> consumed_md; + for (auto it = consumed_metadata.begin(); it != consumed_metadata.end(); + ++it) { + consumed_md.push_back({it->first.c_str(), + it->second.data(), + it->second.size(), + 0, + {{nullptr, nullptr, nullptr, nullptr}}}); + } + std::vector<grpc_metadata> response_md; + for (auto it = response_metadata.begin(); it != response_metadata.end(); + ++it) { + response_md.push_back({it->first.c_str(), + it->second.data(), + it->second.size(), + 0, + {{nullptr, nullptr, nullptr, nullptr}}}); + } + cb(user_data, &consumed_md[0], consumed_md.size(), &response_md[0], + response_md.size(), static_cast<grpc_status_code>(status.error_code()), + status.error_message().c_str()); +} + int SecureServerCredentials::AddPortToServer(const grpc::string& addr, grpc_server* server) { return grpc_server_add_secure_http2_port(server, addr.c_str(), creds_); } +void SecureServerCredentials::SetAuthMetadataProcessor( + const std::shared_ptr<AuthMetadataProcessor>& processor) { + auto *wrapper = new AuthMetadataProcessorAyncWrapper(processor); + grpc_server_credentials_set_auth_metadata_processor( + creds_, {AuthMetadataProcessorAyncWrapper::Process, + AuthMetadataProcessorAyncWrapper::Destroy, wrapper}); +} + std::shared_ptr<ServerCredentials> SslServerCredentials( const SslServerCredentialsOptions& options) { std::vector<grpc_ssl_pem_key_cert_pair> pem_key_cert_pairs; diff --git a/src/cpp/server/secure_server_credentials.h b/src/cpp/server/secure_server_credentials.h index d3d37b188d..4f003c6b7e 100644 --- a/src/cpp/server/secure_server_credentials.h +++ b/src/cpp/server/secure_server_credentials.h @@ -34,12 +34,36 @@ #ifndef GRPC_INTERNAL_CPP_SERVER_SECURE_SERVER_CREDENTIALS_H #define GRPC_INTERNAL_CPP_SERVER_SECURE_SERVER_CREDENTIALS_H -#include <grpc++/server_credentials.h> +#include <memory> + +#include <grpc++/security/server_credentials.h> #include <grpc/grpc_security.h> +#include "src/cpp/server/thread_pool_interface.h" + namespace grpc { +class AuthMetadataProcessorAyncWrapper GRPC_FINAL { + public: + static void Destroy(void *wrapper); + + static void Process(void* wrapper, grpc_auth_context* context, + const grpc_metadata* md, size_t num_md, + grpc_process_auth_metadata_done_cb cb, void* user_data); + + AuthMetadataProcessorAyncWrapper( + const std::shared_ptr<AuthMetadataProcessor>& processor) + : thread_pool_(CreateDefaultThreadPool()), processor_(processor) {} + + private: + void InvokeProcessor(grpc_auth_context* context, const grpc_metadata* md, + size_t num_md, grpc_process_auth_metadata_done_cb cb, + void* user_data); + std::unique_ptr<ThreadPoolInterface> thread_pool_; + std::shared_ptr<AuthMetadataProcessor> processor_; +}; + class SecureServerCredentials GRPC_FINAL : public ServerCredentials { public: explicit SecureServerCredentials(grpc_server_credentials* creds) @@ -51,8 +75,12 @@ class SecureServerCredentials GRPC_FINAL : public ServerCredentials { int AddPortToServer(const grpc::string& addr, grpc_server* server) GRPC_OVERRIDE; + void SetAuthMetadataProcessor( + const std::shared_ptr<AuthMetadataProcessor>& processor) GRPC_OVERRIDE; + private: - grpc_server_credentials* const creds_; + grpc_server_credentials* creds_; + std::unique_ptr<AuthMetadataProcessorAyncWrapper> processor_; }; } // namespace grpc diff --git a/src/cpp/server/server.cc b/src/cpp/server/server.cc index bb83c7d887..d67205e822 100644 --- a/src/cpp/server/server.cc +++ b/src/cpp/server/server.cc @@ -43,7 +43,7 @@ #include <grpc++/impl/rpc_service_method.h> #include <grpc++/impl/service_type.h> #include <grpc++/server_context.h> -#include <grpc++/server_credentials.h> +#include <grpc++/security/server_credentials.h> #include <grpc++/support/time.h> #include "src/core/profiling/timers.h" @@ -354,7 +354,7 @@ bool Server::Start(ServerCompletionQueue** cqs, size_t num_cqs) { unknown_method_.reset(new RpcServiceMethod( "unknown", RpcMethod::BIDI_STREAMING, new UnknownMethodHandler)); // Use of emplace_back with just constructor arguments is not accepted - // here by gcc-4.4 because it can't match the anonymous nullptr with a + // here by gcc-4.4 because it can't match the anonymous nullptr with a // proper constructor implicitly. Construct the object and use push_back. sync_methods_->push_back(SyncRequest(unknown_method_.get(), nullptr)); } @@ -384,7 +384,7 @@ void Server::ShutdownInternal(gpr_timespec deadline) { // Spin, eating requests until the completion queue is completely shutdown. // If the deadline expires then cancel anything that's pending and keep // spinning forever until the work is actually drained. - // Since nothing else needs to touch state guarded by mu_, holding it + // Since nothing else needs to touch state guarded by mu_, holding it // through this loop is fine. SyncRequest* request; bool ok; diff --git a/src/cpp/server/server_credentials.cc b/src/cpp/server/server_credentials.cc index be3a7425e0..8495916178 100644 --- a/src/cpp/server/server_credentials.cc +++ b/src/cpp/server/server_credentials.cc @@ -31,7 +31,7 @@ * */ -#include <grpc++/server_credentials.h> +#include <grpc++/security/server_credentials.h> namespace grpc { diff --git a/src/csharp/README.md b/src/csharp/README.md index 30523b3bd2..3fbc1c5f05 100644 --- a/src/csharp/README.md +++ b/src/csharp/README.md @@ -158,3 +158,20 @@ Contents An example client that sends some requests to math server. - Grpc.IntegrationTesting: Cross-language gRPC implementation testing (interop testing). + +Troubleshooting +--------------- + +### Problem: Unable to load DLL 'grpc_csharp_ext.dll' + +Internally, gRPC C# uses a native library written in C (gRPC C core) and invokes its functionality via P/Invoke. `grpc_csharp_ext` library is a native extension library that facilitates this by wrapping some C core API into a form that's more digestible for P/Invoke. If you get the above error, it means that the native dependencies could not be located by the C# runtime (or they are incompatible with the current runtime, so they could not be loaded). The solution to this is environment specific. + +- If you are developing on Windows in Visual Studio, the `grpc_csharp_ext.dll` that is shipped by gRPC nuget packages should be automatically copied to your build destination folder once you build. By adjusting project properties in your VS project file, you can influence which exact configuration of `grpc_csharp_ext.dll` will be used (based on VS version, bitness, debug/release configuration). + +- If you are running your application that is using gRPC on Windows machine that doesn't have Visual Studio installed, you might need to install [Visual C++ 2013 redistributable](https://www.microsoft.com/en-us/download/details.aspx?id=40784) that contains some system .dll libraries that `grpc_csharp_ext.dll` depends on (see #905 for more details). + +- On Linux (or Docker), you need to first install gRPC C core and `libgrpc_csharp_ext.so` shared libraries. Currently, the libraries can be installed by `make install_grpc_csharp_ext` or using Linuxbrew (a Debian package is coming soon). Installation on a machine where your application is going to be deployed is no different. + +- On Mac, you need to first install gRPC C core and `libgrpc_csharp_ext.dylib` shared libraries using Homebrew. See above for installation instruction. Installation on a machine where your application is going to be deployed is no different. + +- Possible cause for the problem is that the `grpc_csharp_ext` library is installed, but it has different bitness (32/64bit) than your C# runtime (in case you are using mono) or C# application. diff --git a/src/node/README.md b/src/node/README.md index b6411537c7..c96bc96642 100644 --- a/src/node/README.md +++ b/src/node/README.md @@ -11,10 +11,10 @@ Alpha : Ready for early adopters **Linux (Debian):** -Add [Debian unstable][] to your `sources.list` file. Example: +Add [Debian testing][] to your `sources.list` file. Example: ```sh -echo "deb http://ftp.us.debian.org/debian unstable main contrib non-free" | \ +echo "deb http://ftp.us.debian.org/debian testing main contrib non-free" | \ sudo tee -a /etc/apt/sources.list ``` @@ -113,4 +113,4 @@ An object with factory methods for creating credential objects for servers. [homebrew]:http://brew.sh [gRPC install script]:https://raw.githubusercontent.com/grpc/homebrew-grpc/master/scripts/install -[Debian unstable]:https://www.debian.org/releases/sid/ +[Debian testing]:https://www.debian.org/releases/stretch/ diff --git a/src/node/ext/call.cc b/src/node/ext/call.cc index fddc1e214f..560869e6fa 100644 --- a/src/node/ext/call.cc +++ b/src/node/ext/call.cc @@ -461,6 +461,9 @@ void Call::Init(Handle<Object> exports) { NanNew<FunctionTemplate>(StartBatch)->GetFunction()); NanSetPrototypeTemplate(tpl, "cancel", NanNew<FunctionTemplate>(Cancel)->GetFunction()); + NanSetPrototypeTemplate( + tpl, "cancelWithStatus", + NanNew<FunctionTemplate>(CancelWithStatus)->GetFunction()); NanSetPrototypeTemplate(tpl, "getPeer", NanNew<FunctionTemplate>(GetPeer)->GetFunction()); NanAssignPersistent(fun_tpl, tpl); @@ -643,6 +646,26 @@ NAN_METHOD(Call::Cancel) { NanReturnUndefined(); } +NAN_METHOD(Call::CancelWithStatus) { + NanScope(); + if (!HasInstance(args.This())) { + return NanThrowTypeError("cancel can only be called on Call objects"); + } + if (!args[0]->IsUint32()) { + return NanThrowTypeError( + "cancelWithStatus's first argument must be a status code"); + } + if (!args[1]->IsString()) { + return NanThrowTypeError( + "cancelWithStatus's second argument must be a string"); + } + Call *call = ObjectWrap::Unwrap<Call>(args.This()); + grpc_status_code code = static_cast<grpc_status_code>(args[0]->Uint32Value()); + NanUtf8String details(args[0]); + grpc_call_cancel_with_status(call->wrapped_call, code, *details, NULL); + NanReturnUndefined(); +} + NAN_METHOD(Call::GetPeer) { NanScope(); if (!HasInstance(args.This())) { diff --git a/src/node/ext/call.h b/src/node/ext/call.h index ef6e5fcd21..89f81dcf4d 100644 --- a/src/node/ext/call.h +++ b/src/node/ext/call.h @@ -133,6 +133,7 @@ class Call : public ::node::ObjectWrap { static NAN_METHOD(New); static NAN_METHOD(StartBatch); static NAN_METHOD(Cancel); + static NAN_METHOD(CancelWithStatus); static NAN_METHOD(GetPeer); static NanCallback *constructor; // Used for typechecking instances of this javascript class diff --git a/src/node/ext/credentials.cc b/src/node/ext/credentials.cc index 85a823a108..c3b04dcea7 100644 --- a/src/node/ext/credentials.cc +++ b/src/node/ext/credentials.cc @@ -186,7 +186,7 @@ NAN_METHOD(Credentials::CreateComposite) { NAN_METHOD(Credentials::CreateGce) { NanScope(); - grpc_credentials *creds = grpc_compute_engine_credentials_create(NULL); + grpc_credentials *creds = grpc_google_compute_engine_credentials_create(NULL); if (creds == NULL) { NanReturnNull(); } @@ -204,7 +204,7 @@ NAN_METHOD(Credentials::CreateIam) { NanUtf8String auth_token(args[0]); NanUtf8String auth_selector(args[1]); grpc_credentials *creds = - grpc_iam_credentials_create(*auth_token, *auth_selector, NULL); + grpc_google_iam_credentials_create(*auth_token, *auth_selector, NULL); if (creds == NULL) { NanReturnNull(); } diff --git a/src/node/src/client.js b/src/node/src/client.js index e1bed3512e..6a49490910 100644 --- a/src/node/src/client.js +++ b/src/node/src/client.js @@ -142,7 +142,14 @@ function _read(size) { return; } var data = event.read; - if (self.push(self.deserialize(data)) && data !== null) { + var deserialized; + try { + deserialized = self.deserialize(data); + } catch (e) { + self.call.cancelWithStatus(grpc.status.INTERNAL, + 'Failed to parse server response'); + } + if (self.push(deserialized) && data !== null) { var read_batch = {}; read_batch[grpc.opType.RECV_MESSAGE] = true; self.call.startBatch(read_batch, readCallback); @@ -296,23 +303,38 @@ function makeUnaryRequestFunction(method, serialize, deserialize) { call.startBatch(client_batch, function(err, response) { response.status.metadata = Metadata._fromCoreRepresentation( response.status.metadata); - emitter.emit('status', response.status); - if (response.status.code !== grpc.status.OK) { - var error = new Error(response.status.details); - error.code = response.status.code; - error.metadata = response.status.metadata; - callback(error); - return; - } else { + var status = response.status; + var error; + var deserialized; + if (status.code === grpc.status.OK) { if (err) { // Got a batch error, but OK status. Something went wrong callback(err); return; + } else { + try { + deserialized = deserialize(response.read); + } catch (e) { + /* Change status to indicate bad server response. This will result + * in passing an error to the callback */ + status = { + code: grpc.status.INTERNAL, + details: 'Failed to parse server response' + }; + } } } + if (status.code !== grpc.status.OK) { + error = new Error(response.status.details); + error.code = status.code; + error.metadata = status.metadata; + callback(error); + } else { + callback(null, deserialized); + } + emitter.emit('status', status); emitter.emit('metadata', Metadata._fromCoreRepresentation( response.metadata)); - callback(null, deserialize(response.read)); }); }); return emitter; @@ -374,21 +396,36 @@ function makeClientStreamRequestFunction(method, serialize, deserialize) { call.startBatch(client_batch, function(err, response) { response.status.metadata = Metadata._fromCoreRepresentation( response.status.metadata); - stream.emit('status', response.status); - if (response.status.code !== grpc.status.OK) { - var error = new Error(response.status.details); - error.code = response.status.code; - error.metadata = response.status.metadata; - callback(error); - return; - } else { + var status = response.status; + var error; + var deserialized; + if (status.code === grpc.status.OK) { if (err) { // Got a batch error, but OK status. Something went wrong callback(err); return; + } else { + try { + deserialized = deserialize(response.read); + } catch (e) { + /* Change status to indicate bad server response. This will result + * in passing an error to the callback */ + status = { + code: grpc.status.INTERNAL, + details: 'Failed to parse server response' + }; + } } } - callback(null, deserialize(response.read)); + if (status.code !== grpc.status.OK) { + error = new Error(response.status.details); + error.code = status.code; + error.metadata = status.metadata; + callback(error); + } else { + callback(null, deserialized); + } + stream.emit('status', status); }); }); return stream; diff --git a/src/php/README.md b/src/php/README.md index f432935fde..afa09d79a1 100644 --- a/src/php/README.md +++ b/src/php/README.md @@ -32,10 +32,10 @@ $ sudo php -d detect_unicode=0 go-pear.phar **Linux (Debian):** -Add [Debian unstable][] to your `sources.list` file. Example: +Add [Debian testing][] to your `sources.list` file. Example: ```sh -echo "deb http://ftp.us.debian.org/debian unstable main contrib non-free" | \ +echo "deb http://ftp.us.debian.org/debian testing main contrib non-free" | \ sudo tee -a /etc/apt/sources.list ``` @@ -73,29 +73,24 @@ This will download and run the [gRPC install script][] and compile the gRPC PHP Clone this repository -``` +```sh $ git clone https://github.com/grpc/grpc.git ``` -Build and install the Protocol Buffers compiler (protoc) +Build and install the gRPC C core libraries -``` +```sh $ cd grpc $ git pull --recurse-submodules && git submodule update --init --recursive -$ cd third_party/protobuf -$ ./autogen.sh -$ ./configure $ make -$ make check $ sudo make install ``` -Build and install the gRPC C core libraries +Note: you may encounter a warning about the Protobuf compiler `protoc` 3.0.0+ not being installed. The following might help, and will be useful later on when we need to compile the `protoc-gen-php` tool. ```sh -$ cd grpc -$ make -$ sudo make install +$ cd grpc/third_party/protobuf +$ sudo make install # 'make' should have been run by core grpc ``` Install the gRPC PHP extension @@ -172,4 +167,4 @@ $ ./bin/run_gen_code_test.sh [homebrew]:http://brew.sh [gRPC install script]:https://raw.githubusercontent.com/grpc/homebrew-grpc/master/scripts/install [Node]:https://github.com/grpc/grpc/tree/master/src/node/examples -[Debian unstable]:https://www.debian.org/releases/sid/ +[Debian testing]:https://www.debian.org/releases/stretch/ diff --git a/src/php/ext/grpc/credentials.c b/src/php/ext/grpc/credentials.c index 0eba6608bb..8e3b7ff212 100644 --- a/src/php/ext/grpc/credentials.c +++ b/src/php/ext/grpc/credentials.c @@ -170,7 +170,7 @@ PHP_METHOD(Credentials, createComposite) { * @return Credentials The new GCE credentials object */ PHP_METHOD(Credentials, createGce) { - grpc_credentials *creds = grpc_compute_engine_credentials_create(NULL); + grpc_credentials *creds = grpc_google_compute_engine_credentials_create(NULL); zval *creds_object = grpc_php_wrap_credentials(creds); RETURN_DESTROY_ZVAL(creds_object); } diff --git a/src/python/README.md b/src/python/README.md index affce64884..a21deb33ef 100644 --- a/src/python/README.md +++ b/src/python/README.md @@ -16,10 +16,10 @@ INSTALLATION **Linux (Debian):** -Add [Debian unstable][] to your `sources.list` file. Example: +Add [Debian testing][] to your `sources.list` file. Example: ```sh -echo "deb http://ftp.us.debian.org/debian unstable main contrib non-free" | \ +echo "deb http://ftp.us.debian.org/debian testing main contrib non-free" | \ sudo tee -a /etc/apt/sources.list ``` @@ -92,4 +92,4 @@ $ ../../tools/distrib/python/submit.py [gRPC install script]:https://raw.githubusercontent.com/grpc/homebrew-grpc/master/scripts/install [Quick Start]:http://www.grpc.io/docs/tutorials/basic/python.html [detailed example]:http://www.grpc.io/docs/installation/python.html -[Debian unstable]:https://www.debian.org/releases/sid/ +[Debian testing]:https://www.debian.org/releases/stretch/ diff --git a/src/python/grpcio/grpc/_adapter/_c/types.h b/src/python/grpcio/grpc/_adapter/_c/types.h index f6ff957baa..ec0687a9fd 100644 --- a/src/python/grpcio/grpc/_adapter/_c/types.h +++ b/src/python/grpcio/grpc/_adapter/_c/types.h @@ -57,8 +57,6 @@ ClientCredentials *pygrpc_ClientCredentials_composite( PyTypeObject *type, PyObject *args, PyObject *kwargs); ClientCredentials *pygrpc_ClientCredentials_compute_engine( PyTypeObject *type, PyObject *ignored); -ClientCredentials *pygrpc_ClientCredentials_service_account( - PyTypeObject *type, PyObject *args, PyObject *kwargs); ClientCredentials *pygrpc_ClientCredentials_jwt( PyTypeObject *type, PyObject *args, PyObject *kwargs); ClientCredentials *pygrpc_ClientCredentials_refresh_token( diff --git a/src/python/grpcio/grpc/_adapter/_c/types/client_credentials.c b/src/python/grpcio/grpc/_adapter/_c/types/client_credentials.c index 36fd207464..90652b7b47 100644 --- a/src/python/grpcio/grpc/_adapter/_c/types/client_credentials.c +++ b/src/python/grpcio/grpc/_adapter/_c/types/client_credentials.c @@ -48,8 +48,6 @@ PyMethodDef pygrpc_ClientCredentials_methods[] = { METH_CLASS|METH_KEYWORDS, ""}, {"compute_engine", (PyCFunction)pygrpc_ClientCredentials_compute_engine, METH_CLASS|METH_NOARGS, ""}, - {"service_account", (PyCFunction)pygrpc_ClientCredentials_service_account, - METH_CLASS|METH_KEYWORDS, ""}, {"jwt", (PyCFunction)pygrpc_ClientCredentials_jwt, METH_CLASS|METH_KEYWORDS, ""}, {"refresh_token", (PyCFunction)pygrpc_ClientCredentials_refresh_token, @@ -173,7 +171,7 @@ ClientCredentials *pygrpc_ClientCredentials_composite( ClientCredentials *pygrpc_ClientCredentials_compute_engine( PyTypeObject *type, PyObject *ignored) { ClientCredentials *self = (ClientCredentials *)type->tp_alloc(type, 0); - self->c_creds = grpc_compute_engine_credentials_create(NULL); + self->c_creds = grpc_google_compute_engine_credentials_create(NULL); if (!self->c_creds) { Py_DECREF(self); PyErr_SetString(PyExc_RuntimeError, @@ -183,29 +181,6 @@ ClientCredentials *pygrpc_ClientCredentials_compute_engine( return self; } -ClientCredentials *pygrpc_ClientCredentials_service_account( - PyTypeObject *type, PyObject *args, PyObject *kwargs) { - ClientCredentials *self; - const char *json_key; - const char *scope; - double lifetime; - static char *keywords[] = {"json_key", "scope", "token_lifetime", NULL}; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "ssd:service_account", keywords, - &json_key, &scope, &lifetime)) { - return NULL; - } - self = (ClientCredentials *)type->tp_alloc(type, 0); - self->c_creds = grpc_service_account_credentials_create( - json_key, scope, pygrpc_cast_double_to_gpr_timespec(lifetime), NULL); - if (!self->c_creds) { - Py_DECREF(self); - PyErr_SetString(PyExc_RuntimeError, - "couldn't create service account credentials"); - return NULL; - } - return self; -} - /* TODO: Rename this credentials to something like service_account_jwt_access */ ClientCredentials *pygrpc_ClientCredentials_jwt( PyTypeObject *type, PyObject *args, PyObject *kwargs) { @@ -239,7 +214,7 @@ ClientCredentials *pygrpc_ClientCredentials_refresh_token( } self = (ClientCredentials *)type->tp_alloc(type, 0); self->c_creds = - grpc_refresh_token_credentials_create(json_refresh_token, NULL); + grpc_google_refresh_token_credentials_create(json_refresh_token, NULL); if (!self->c_creds) { Py_DECREF(self); PyErr_SetString(PyExc_RuntimeError, @@ -260,8 +235,8 @@ ClientCredentials *pygrpc_ClientCredentials_iam( return NULL; } self = (ClientCredentials *)type->tp_alloc(type, 0); - self->c_creds = grpc_iam_credentials_create(authorization_token, - authority_selector, NULL); + self->c_creds = grpc_google_iam_credentials_create(authorization_token, + authority_selector, NULL); if (!self->c_creds) { Py_DECREF(self); PyErr_SetString(PyExc_RuntimeError, "couldn't create IAM credentials"); diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx index 2d74702fbd..dc40a7a611 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx +++ b/src/python/grpcio/grpc/_cython/_cygrpc/credentials.pyx @@ -106,26 +106,6 @@ def client_credentials_compute_engine(): credentials.c_credentials = grpc.grpc_compute_engine_credentials_create() return credentials -def client_credentials_service_account( - json_key, scope, records.Timespec token_lifetime not None): - if isinstance(json_key, bytes): - pass - elif isinstance(json_key, basestring): - json_key = json_key.encode() - else: - raise TypeError("expected json_key to be str or bytes") - if isinstance(scope, bytes): - pass - elif isinstance(scope, basestring): - scope = scope.encode() - else: - raise TypeError("expected scope to be str or bytes") - cdef ClientCredentials credentials = ClientCredentials() - credentials.c_credentials = grpc.grpc_service_account_credentials_create( - json_key, scope, token_lifetime.c_time) - credentials.references.extend([json_key, scope]) - return credentials - #TODO rename to something like client_credentials_service_account_jwt_access. def client_credentials_jwt(json_key, records.Timespec token_lifetime not None): if isinstance(json_key, bytes): diff --git a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd index c793774c8d..8b46972490 100644 --- a/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd +++ b/src/python/grpcio/grpc/_cython/_cygrpc/grpc.pxd @@ -311,8 +311,6 @@ cdef extern from "grpc/grpc_security.h": grpc_credentials *grpc_composite_credentials_create(grpc_credentials *creds1, grpc_credentials *creds2) grpc_credentials *grpc_compute_engine_credentials_create() - grpc_credentials *grpc_service_account_credentials_create( - const char *json_key, const char *scope, gpr_timespec token_lifetime) grpc_credentials *grpc_service_account_jwt_access_credentials_create(const char *json_key, gpr_timespec token_lifetime) grpc_credentials *grpc_refresh_token_credentials_create( diff --git a/src/python/grpcio/grpc/_cython/adapter_low.py b/src/python/grpcio/grpc/_cython/adapter_low.py index 2bb468eece..4f24da330f 100644 --- a/src/python/grpcio/grpc/_cython/adapter_low.py +++ b/src/python/grpcio/grpc/_cython/adapter_low.py @@ -60,10 +60,6 @@ class ClientCredentials(object): raise NotImplementedError() @staticmethod - def service_account(): - raise NotImplementedError() - - @staticmethod def jwt(): raise NotImplementedError() diff --git a/src/python/grpcio/grpc/_links/invocation.py b/src/python/grpcio/grpc/_links/invocation.py index ee3d72fdbc..1676fe7941 100644 --- a/src/python/grpcio/grpc/_links/invocation.py +++ b/src/python/grpcio/grpc/_links/invocation.py @@ -41,6 +41,15 @@ from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay from grpc.framework.interfaces.links import links +_IDENTITY = lambda x: x + +_STOP = _intermediary_low.Event.Kind.STOP +_WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED +_COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED +_READ = _intermediary_low.Event.Kind.READ_ACCEPTED +_METADATA = _intermediary_low.Event.Kind.METADATA_ACCEPTED +_FINISH = _intermediary_low.Event.Kind.FINISH + @enum.unique class _Read(enum.Enum): @@ -67,7 +76,7 @@ class _RPCState(object): def __init__( self, call, request_serializer, response_deserializer, sequence_number, - read, allowance, high_write, low_write): + read, allowance, high_write, low_write, due): self.call = call self.request_serializer = request_serializer self.response_deserializer = response_deserializer @@ -76,27 +85,37 @@ class _RPCState(object): self.allowance = allowance self.high_write = high_write self.low_write = low_write + self.due = due + + +def _no_longer_due(kind, rpc_state, key, rpc_states): + rpc_state.due.remove(kind) + if not rpc_state.due: + del rpc_states[key] class _Kernel(object): def __init__( - self, channel, host, request_serializers, response_deserializers, - ticket_relay): + self, channel, host, metadata_transformer, request_serializers, + response_deserializers, ticket_relay): self._lock = threading.Lock() self._channel = channel self._host = host + self._metadata_transformer = metadata_transformer self._request_serializers = request_serializers self._response_deserializers = response_deserializers self._relay = ticket_relay self._completion_queue = None - self._rpc_states = None + self._rpc_states = {} self._pool = None def _on_write_event(self, operation_id, unused_event, rpc_state): if rpc_state.high_write is _HighWrite.CLOSED: rpc_state.call.complete(operation_id) + rpc_state.due.add(_COMPLETE) + rpc_state.due.remove(_WRITE) rpc_state.low_write = _LowWrite.CLOSED else: ticket = links.Ticket( @@ -105,16 +124,19 @@ class _Kernel(object): rpc_state.sequence_number += 1 self._relay.add_value(ticket) rpc_state.low_write = _LowWrite.OPEN + _no_longer_due(_WRITE, rpc_state, operation_id, self._rpc_states) def _on_read_event(self, operation_id, event, rpc_state): - if event.bytes is None: + if event.bytes is None or _FINISH not in rpc_state.due: rpc_state.read = _Read.CLOSED + _no_longer_due(_READ, rpc_state, operation_id, self._rpc_states) else: if 0 < rpc_state.allowance: rpc_state.allowance -= 1 rpc_state.call.read(operation_id) else: rpc_state.read = _Read.AWAITING_ALLOWANCE + _no_longer_due(_READ, rpc_state, operation_id, self._rpc_states) ticket = links.Ticket( operation_id, rpc_state.sequence_number, None, None, None, None, None, None, rpc_state.response_deserializer(event.bytes), None, None, None, @@ -123,18 +145,23 @@ class _Kernel(object): self._relay.add_value(ticket) def _on_metadata_event(self, operation_id, event, rpc_state): - rpc_state.allowance -= 1 - rpc_state.call.read(operation_id) - rpc_state.read = _Read.READING - ticket = links.Ticket( - operation_id, rpc_state.sequence_number, None, None, - links.Ticket.Subscription.FULL, None, None, event.metadata, None, None, - None, None, None, None) - rpc_state.sequence_number += 1 - self._relay.add_value(ticket) + if _FINISH in rpc_state.due: + rpc_state.allowance -= 1 + rpc_state.call.read(operation_id) + rpc_state.read = _Read.READING + rpc_state.due.add(_READ) + rpc_state.due.remove(_METADATA) + ticket = links.Ticket( + operation_id, rpc_state.sequence_number, None, None, + links.Ticket.Subscription.FULL, None, None, event.metadata, None, + None, None, None, None, None) + rpc_state.sequence_number += 1 + self._relay.add_value(ticket) + else: + _no_longer_due(_METADATA, rpc_state, operation_id, self._rpc_states) def _on_finish_event(self, operation_id, event, rpc_state): - self._rpc_states.pop(operation_id, None) + _no_longer_due(_FINISH, rpc_state, operation_id, self._rpc_states) if event.status.code is _intermediary_low.Code.OK: termination = links.Ticket.Termination.COMPLETION elif event.status.code is _intermediary_low.Code.CANCELLED: @@ -155,26 +182,26 @@ class _Kernel(object): def _spin(self, completion_queue): while True: event = completion_queue.get(None) - if event.kind is _intermediary_low.Event.Kind.STOP: - return - operation_id = event.tag with self._lock: - if self._completion_queue is None: - continue - rpc_state = self._rpc_states.get(operation_id) - if rpc_state is not None: - if event.kind is _intermediary_low.Event.Kind.WRITE_ACCEPTED: - self._on_write_event(operation_id, event, rpc_state) - elif event.kind is _intermediary_low.Event.Kind.METADATA_ACCEPTED: - self._on_metadata_event(operation_id, event, rpc_state) - elif event.kind is _intermediary_low.Event.Kind.READ_ACCEPTED: - self._on_read_event(operation_id, event, rpc_state) - elif event.kind is _intermediary_low.Event.Kind.FINISH: - self._on_finish_event(operation_id, event, rpc_state) - elif event.kind is _intermediary_low.Event.Kind.COMPLETE_ACCEPTED: - pass - else: - logging.error('Illegal RPC event! %s', (event,)) + rpc_state = self._rpc_states.get(event.tag, None) + if event.kind is _STOP: + pass + elif event.kind is _WRITE: + self._on_write_event(event.tag, event, rpc_state) + elif event.kind is _METADATA: + self._on_metadata_event(event.tag, event, rpc_state) + elif event.kind is _READ: + self._on_read_event(event.tag, event, rpc_state) + elif event.kind is _FINISH: + self._on_finish_event(event.tag, event, rpc_state) + elif event.kind is _COMPLETE: + _no_longer_due(_COMPLETE, rpc_state, event.tag, self._rpc_states) + else: + logging.error('Illegal RPC event! %s', (event,)) + + if self._completion_queue is None and not self._rpc_states: + completion_queue.stop() + return def _invoke( self, operation_id, group, method, initial_metadata, payload, termination, @@ -201,46 +228,48 @@ class _Kernel(object): else: return - request_serializer = self._request_serializers.get((group, method)) - response_deserializer = self._response_deserializers.get((group, method)) - if request_serializer is None or response_deserializer is None: - cancellation_ticket = links.Ticket( - operation_id, 0, None, None, None, None, None, None, None, None, None, - None, links.Ticket.Termination.CANCELLATION) - self._relay.add_value(cancellation_ticket) - return + transformed_initial_metadata = self._metadata_transformer(initial_metadata) + request_serializer = self._request_serializers.get( + (group, method), _IDENTITY) + response_deserializer = self._response_deserializers.get( + (group, method), _IDENTITY) call = _intermediary_low.Call( self._channel, self._completion_queue, '/%s/%s' % (group, method), self._host, time.time() + timeout) - if initial_metadata is not None: - for metadata_key, metadata_value in initial_metadata: + if transformed_initial_metadata is not None: + for metadata_key, metadata_value in transformed_initial_metadata: call.add_metadata(metadata_key, metadata_value) call.invoke(self._completion_queue, operation_id, operation_id) if payload is None: if high_write is _HighWrite.CLOSED: call.complete(operation_id) low_write = _LowWrite.CLOSED + due = set((_METADATA, _COMPLETE, _FINISH,)) else: low_write = _LowWrite.OPEN + due = set((_METADATA, _FINISH,)) else: call.write(request_serializer(payload), operation_id) low_write = _LowWrite.ACTIVE + due = set((_WRITE, _METADATA, _FINISH,)) self._rpc_states[operation_id] = _RPCState( call, request_serializer, response_deserializer, 0, _Read.AWAITING_METADATA, 1 if allowance is None else (1 + allowance), - high_write, low_write) + high_write, low_write, due) def _advance(self, operation_id, rpc_state, payload, termination, allowance): if payload is not None: rpc_state.call.write(rpc_state.request_serializer(payload), operation_id) rpc_state.low_write = _LowWrite.ACTIVE + rpc_state.due.add(_WRITE) if allowance is not None: if rpc_state.read is _Read.AWAITING_ALLOWANCE: rpc_state.allowance += allowance - 1 rpc_state.call.read(operation_id) rpc_state.read = _Read.READING + rpc_state.due.add(_READ) else: rpc_state.allowance += allowance @@ -248,19 +277,21 @@ class _Kernel(object): rpc_state.high_write = _HighWrite.CLOSED if rpc_state.low_write is _LowWrite.OPEN: rpc_state.call.complete(operation_id) + rpc_state.due.add(_COMPLETE) rpc_state.low_write = _LowWrite.CLOSED elif termination is not None: rpc_state.call.cancel() def add_ticket(self, ticket): with self._lock: - if self._completion_queue is None: - return if ticket.sequence_number == 0: - self._invoke( - ticket.operation_id, ticket.group, ticket.method, - ticket.initial_metadata, ticket.payload, ticket.termination, - ticket.timeout, ticket.allowance) + if self._completion_queue is None: + logging.error('Received invocation ticket %s after stop!', ticket) + else: + self._invoke( + ticket.operation_id, ticket.group, ticket.method, + ticket.initial_metadata, ticket.payload, ticket.termination, + ticket.timeout, ticket.allowance) else: rpc_state = self._rpc_states.get(ticket.operation_id) if rpc_state is not None: @@ -276,7 +307,6 @@ class _Kernel(object): """ with self._lock: self._completion_queue = _intermediary_low.CompletionQueue() - self._rpc_states = {} self._pool = logging_pool.pool(1) self._pool.submit(self._spin, self._completion_queue) @@ -288,11 +318,10 @@ class _Kernel(object): has been called. """ with self._lock: - self._completion_queue.stop() + if not self._rpc_states: + self._completion_queue.stop() self._completion_queue = None pool = self._pool - self._pool = None - self._rpc_states = None pool.shutdown(wait=True) @@ -307,10 +336,15 @@ class InvocationLink(links.Link, activated.Activated): class _InvocationLink(InvocationLink): def __init__( - self, channel, host, request_serializers, response_deserializers): + self, channel, host, metadata_transformer, request_serializers, + response_deserializers): self._relay = relay.relay(None) self._kernel = _Kernel( - channel, host, request_serializers, response_deserializers, self._relay) + channel, host, + _IDENTITY if metadata_transformer is None else metadata_transformer, + {} if request_serializers is None else request_serializers, + {} if response_deserializers is None else response_deserializers, + self._relay) def _start(self): self._relay.start() @@ -347,12 +381,17 @@ class _InvocationLink(InvocationLink): self._stop() -def invocation_link(channel, host, request_serializers, response_deserializers): +def invocation_link( + channel, host, metadata_transformer, request_serializers, + response_deserializers): """Creates an InvocationLink. Args: channel: An _intermediary_low.Channel for use by the link. host: The host to specify when invoking RPCs. + metadata_transformer: A callable that takes an invocation-side initial + metadata value and returns another metadata value to send in its place. + May be None. request_serializers: A dict from group-method pair to request object serialization behavior. response_deserializers: A dict from group-method pair to response object @@ -362,4 +401,5 @@ def invocation_link(channel, host, request_serializers, response_deserializers): An InvocationLink. """ return _InvocationLink( - channel, host, request_serializers, response_deserializers) + channel, host, metadata_transformer, request_serializers, + response_deserializers) diff --git a/src/python/grpcio/grpc/_links/service.py b/src/python/grpcio/grpc/_links/service.py index c5ecc47cd9..94e7cfc716 100644 --- a/src/python/grpcio/grpc/_links/service.py +++ b/src/python/grpcio/grpc/_links/service.py @@ -40,6 +40,8 @@ from grpc.framework.foundation import logging_pool from grpc.framework.foundation import relay from grpc.framework.interfaces.links import links +_IDENTITY = lambda x: x + _TERMINATION_KIND_TO_CODE = { links.Ticket.Termination.COMPLETION: _intermediary_low.Code.OK, links.Ticket.Termination.CANCELLATION: _intermediary_low.Code.CANCELLED, @@ -53,6 +55,13 @@ _TERMINATION_KIND_TO_CODE = { links.Ticket.Termination.REMOTE_FAILURE: _intermediary_low.Code.UNKNOWN, } +_STOP = _intermediary_low.Event.Kind.STOP +_WRITE = _intermediary_low.Event.Kind.WRITE_ACCEPTED +_COMPLETE = _intermediary_low.Event.Kind.COMPLETE_ACCEPTED +_SERVICE = _intermediary_low.Event.Kind.SERVICE_ACCEPTED +_READ = _intermediary_low.Event.Kind.READ_ACCEPTED +_FINISH = _intermediary_low.Event.Kind.FINISH + @enum.unique class _Read(enum.Enum): @@ -84,7 +93,7 @@ class _RPCState(object): def __init__( self, request_deserializer, response_serializer, sequence_number, read, early_read, allowance, high_write, low_write, premetadataed, - terminal_metadata, code, message): + terminal_metadata, code, message, due): self.request_deserializer = request_deserializer self.response_serializer = response_serializer self.sequence_number = sequence_number @@ -99,6 +108,13 @@ class _RPCState(object): self.terminal_metadata = terminal_metadata self.code = code self.message = message + self.due = due + + +def _no_longer_due(kind, rpc_state, key, rpc_states): + rpc_state.due.remove(kind) + if not rpc_state.due: + del rpc_states[key] def _metadatafy(call, metadata): @@ -124,6 +140,7 @@ class _Kernel(object): self._relay = ticket_relay self._completion_queue = None + self._due = set() self._server = None self._rpc_states = {} self._pool = None @@ -139,17 +156,16 @@ class _Kernel(object): except ValueError: logging.info('Illegal path "%s"!', service_acceptance.method) return - request_deserializer = self._request_deserializers.get((group, method)) - response_serializer = self._response_serializers.get((group, method)) - if request_deserializer is None or response_serializer is None: - # TODO(nathaniel): Terminate the RPC with code NOT_FOUND. - call.cancel() - return + request_deserializer = self._request_deserializers.get( + (group, method), _IDENTITY) + response_serializer = self._response_serializers.get( + (group, method), _IDENTITY) call.read(call) self._rpc_states[call] = _RPCState( request_deserializer, response_serializer, 1, _Read.READING, None, 1, - _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None) + _HighWrite.OPEN, _LowWrite.OPEN, False, None, None, None, + set((_READ, _FINISH,))) ticket = links.Ticket( call, 0, group, method, links.Ticket.Subscription.FULL, service_acceptance.deadline - time.time(), None, event.metadata, None, @@ -158,14 +174,13 @@ class _Kernel(object): def _on_read_event(self, event): call = event.tag - rpc_state = self._rpc_states.get(call, None) - if rpc_state is None: - return + rpc_state = self._rpc_states[call] if event.bytes is None: rpc_state.read = _Read.CLOSED payload = None termination = links.Ticket.Termination.COMPLETION + _no_longer_due(_READ, rpc_state, call, self._rpc_states) else: if 0 < rpc_state.allowance: payload = rpc_state.request_deserializer(event.bytes) @@ -174,6 +189,7 @@ class _Kernel(object): call.read(call) else: rpc_state.early_read = event.bytes + _no_longer_due(_READ, rpc_state, call, self._rpc_states) return # TODO(issue 2916): Instead of returning: # rpc_state.read = _Read.AWAITING_ALLOWANCE @@ -185,9 +201,7 @@ class _Kernel(object): def _on_write_event(self, event): call = event.tag - rpc_state = self._rpc_states.get(call, None) - if rpc_state is None: - return + rpc_state = self._rpc_states[call] if rpc_state.high_write is _HighWrite.CLOSED: if rpc_state.terminal_metadata is not None: @@ -197,6 +211,8 @@ class _Kernel(object): rpc_state.message) call.status(status, call) rpc_state.low_write = _LowWrite.CLOSED + rpc_state.due.add(_COMPLETE) + rpc_state.due.remove(_WRITE) else: ticket = links.Ticket( call, rpc_state.sequence_number, None, None, None, None, 1, None, @@ -204,12 +220,12 @@ class _Kernel(object): rpc_state.sequence_number += 1 self._relay.add_value(ticket) rpc_state.low_write = _LowWrite.OPEN + _no_longer_due(_WRITE, rpc_state, call, self._rpc_states) def _on_finish_event(self, event): call = event.tag - rpc_state = self._rpc_states.pop(call, None) - if rpc_state is None: - return + rpc_state = self._rpc_states[call] + _no_longer_due(_FINISH, rpc_state, call, self._rpc_states) code = event.status.code if code is _intermediary_low.Code.OK: return @@ -229,28 +245,33 @@ class _Kernel(object): def _spin(self, completion_queue, server): while True: event = completion_queue.get(None) - if event.kind is _intermediary_low.Event.Kind.STOP: - return with self._lock: - if self._server is None: - continue - elif event.kind is _intermediary_low.Event.Kind.SERVICE_ACCEPTED: - self._on_service_acceptance_event(event, server) - elif event.kind is _intermediary_low.Event.Kind.READ_ACCEPTED: + if event.kind is _STOP: + self._due.remove(_STOP) + elif event.kind is _READ: self._on_read_event(event) - elif event.kind is _intermediary_low.Event.Kind.WRITE_ACCEPTED: + elif event.kind is _WRITE: self._on_write_event(event) - elif event.kind is _intermediary_low.Event.Kind.COMPLETE_ACCEPTED: - pass + elif event.kind is _COMPLETE: + _no_longer_due( + _COMPLETE, self._rpc_states.get(event.tag), event.tag, + self._rpc_states) elif event.kind is _intermediary_low.Event.Kind.FINISH: self._on_finish_event(event) + elif event.kind is _SERVICE: + if self._server is None: + self._due.remove(_SERVICE) + else: + self._on_service_acceptance_event(event, server) else: logging.error('Illegal event! %s', (event,)) + if not self._due and not self._rpc_states: + completion_queue.stop() + return + def add_ticket(self, ticket): with self._lock: - if self._server is None: - return call = ticket.operation_id rpc_state = self._rpc_states.get(call) if rpc_state is None: @@ -278,6 +299,7 @@ class _Kernel(object): rpc_state.early_read = None if rpc_state.read is _Read.READING: call.read(call) + rpc_state.due.add(_READ) termination = None else: termination = links.Ticket.Termination.COMPLETION @@ -289,6 +311,7 @@ class _Kernel(object): if ticket.payload is not None: call.write(rpc_state.response_serializer(ticket.payload), call) + rpc_state.due.add(_WRITE) rpc_state.low_write = _LowWrite.ACTIVE if ticket.terminal_metadata is not None: @@ -307,6 +330,7 @@ class _Kernel(object): links.Ticket.Termination.COMPLETION, rpc_state.code, rpc_state.message) call.status(status, call) + rpc_state.due.add(_COMPLETE) rpc_state.low_write = _LowWrite.CLOSED elif ticket.termination is not None: if rpc_state.terminal_metadata is not None: @@ -314,7 +338,7 @@ class _Kernel(object): status = _status( ticket.termination, rpc_state.code, rpc_state.message) call.status(status, call) - self._rpc_states.pop(call, None) + rpc_state.due.add(_COMPLETE) def add_port(self, address, server_credentials): with self._lock: @@ -335,19 +359,17 @@ class _Kernel(object): self._pool.submit(self._spin, self._completion_queue, self._server) self._server.start() self._server.service(None) + self._due.add(_SERVICE) def begin_stop(self): with self._lock: self._server.stop() + self._due.add(_STOP) self._server = None def end_stop(self): with self._lock: - self._completion_queue.stop() - self._completion_queue = None pool = self._pool - self._pool = None - self._rpc_states = None pool.shutdown(wait=True) @@ -369,7 +391,7 @@ class ServiceLink(links.Link): None for insecure service. Returns: - A integer port on which RPCs will be serviced after this link has been + An integer port on which RPCs will be serviced after this link has been started. This is typically the same number as the port number contained in the passed address, but will likely be different if the port number contained in the passed address was zero. @@ -411,7 +433,9 @@ class _ServiceLink(ServiceLink): def __init__(self, request_deserializers, response_serializers): self._relay = relay.relay(None) self._kernel = _Kernel( - request_deserializers, response_serializers, self._relay) + {} if request_deserializers is None else request_deserializers, + {} if response_serializers is None else response_serializers, + self._relay) def accept_ticket(self, ticket): self._kernel.add_ticket(ticket) diff --git a/src/python/grpcio/grpc/beta/_server.py b/src/python/grpcio/grpc/beta/_server.py new file mode 100644 index 0000000000..4e46ffd17f --- /dev/null +++ b/src/python/grpcio/grpc/beta/_server.py @@ -0,0 +1,112 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Beta API server implementation.""" + +import threading + +from grpc._links import service +from grpc.framework.core import implementations as _core_implementations +from grpc.framework.crust import implementations as _crust_implementations +from grpc.framework.foundation import logging_pool +from grpc.framework.interfaces.links import utilities + +_DEFAULT_POOL_SIZE = 8 +_DEFAULT_TIMEOUT = 300 +_MAXIMUM_TIMEOUT = 24 * 60 * 60 + + +def _disassemble(grpc_link, end_link, pool, event, grace): + grpc_link.begin_stop() + end_link.stop(grace).wait() + grpc_link.end_stop() + grpc_link.join_link(utilities.NULL_LINK) + end_link.join_link(utilities.NULL_LINK) + if pool is not None: + pool.shutdown(wait=True) + event.set() + + +class Server(object): + + def __init__(self, grpc_link, end_link, pool): + self._grpc_link = grpc_link + self._end_link = end_link + self._pool = pool + + def add_insecure_port(self, address): + return self._grpc_link.add_port(address, None) + + def add_secure_port(self, address, intermediary_low_server_credentials): + return self._grpc_link.add_port( + address, intermediary_low_server_credentials) + + def start(self): + self._grpc_link.join_link(self._end_link) + self._end_link.join_link(self._grpc_link) + self._grpc_link.start() + self._end_link.start() + + def stop(self, grace): + stop_event = threading.Event() + if 0 < grace: + disassembly_thread = threading.Thread( + target=_disassemble, + args=( + self._grpc_link, self._end_link, self._pool, stop_event, grace,)) + disassembly_thread.start() + return stop_event + else: + _disassemble(self._grpc_link, self._end_link, self._pool, stop_event, 0) + return stop_event + + +def server( + implementations, multi_implementation, request_deserializers, + response_serializers, thread_pool, thread_pool_size, default_timeout, + maximum_timeout): + if thread_pool is None: + service_thread_pool = logging_pool.pool( + _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size) + assembly_thread_pool = service_thread_pool + else: + service_thread_pool = thread_pool + assembly_thread_pool = None + + servicer = _crust_implementations.servicer( + implementations, multi_implementation, service_thread_pool) + + grpc_link = service.service_link(request_deserializers, response_serializers) + + end_link = _core_implementations.service_end_link( + servicer, + _DEFAULT_TIMEOUT if default_timeout is None else default_timeout, + _MAXIMUM_TIMEOUT if maximum_timeout is None else maximum_timeout) + + return Server(grpc_link, end_link, assembly_thread_pool) diff --git a/src/python/grpcio/grpc/beta/_stub.py b/src/python/grpcio/grpc/beta/_stub.py new file mode 100644 index 0000000000..cfbecb852b --- /dev/null +++ b/src/python/grpcio/grpc/beta/_stub.py @@ -0,0 +1,111 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Beta API stub implementation.""" + +import threading + +from grpc._links import invocation +from grpc.framework.core import implementations as _core_implementations +from grpc.framework.crust import implementations as _crust_implementations +from grpc.framework.foundation import logging_pool +from grpc.framework.interfaces.links import utilities + +_DEFAULT_POOL_SIZE = 6 + + +class _AutoIntermediary(object): + + def __init__(self, delegate, on_deletion): + self._delegate = delegate + self._on_deletion = on_deletion + + def __getattr__(self, attr): + return getattr(self._delegate, attr) + + def __del__(self): + self._on_deletion() + + +def _assemble( + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size): + end_link = _core_implementations.invocation_end_link() + grpc_link = invocation.invocation_link( + channel, host, metadata_transformer, request_serializers, + response_deserializers) + if thread_pool is None: + invocation_pool = logging_pool.pool( + _DEFAULT_POOL_SIZE if thread_pool_size is None else thread_pool_size) + assembly_pool = invocation_pool + else: + invocation_pool = thread_pool + assembly_pool = None + end_link.join_link(grpc_link) + grpc_link.join_link(end_link) + end_link.start() + grpc_link.start() + return end_link, grpc_link, invocation_pool, assembly_pool + + +def _disassemble(end_link, grpc_link, pool): + end_link.stop(24 * 60 * 60).wait() + grpc_link.stop() + end_link.join_link(utilities.NULL_LINK) + grpc_link.join_link(utilities.NULL_LINK) + if pool is not None: + pool.shutdown(wait=True) + + +def _wrap_assembly(stub, end_link, grpc_link, assembly_pool): + disassembly_thread = threading.Thread( + target=_disassemble, args=(end_link, grpc_link, assembly_pool)) + return _AutoIntermediary(stub, disassembly_thread.start) + + +def generic_stub( + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size): + end_link, grpc_link, invocation_pool, assembly_pool = _assemble( + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size) + stub = _crust_implementations.generic_stub(end_link, invocation_pool) + return _wrap_assembly(stub, end_link, grpc_link, assembly_pool) + + +def dynamic_stub( + channel, host, service, cardinalities, metadata_transformer, + request_serializers, response_deserializers, thread_pool, + thread_pool_size): + end_link, grpc_link, invocation_pool, assembly_pool = _assemble( + channel, host, metadata_transformer, request_serializers, + response_deserializers, thread_pool, thread_pool_size) + stub = _crust_implementations.dynamic_stub( + end_link, service, cardinalities, invocation_pool) + return _wrap_assembly(stub, end_link, grpc_link, assembly_pool) diff --git a/src/python/grpcio/grpc/beta/beta.py b/src/python/grpcio/grpc/beta/beta.py index 40cad5e486..b3a161087f 100644 --- a/src/python/grpcio/grpc/beta/beta.py +++ b/src/python/grpcio/grpc/beta/beta.py @@ -27,13 +27,21 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Entry points into gRPC Python Beta.""" +"""Entry points into the Beta API of gRPC Python.""" +# threading is referenced from specification in this module. +import abc import enum +import threading # pylint: disable=unused-import -from grpc._adapter import _low +# cardinality and face are referenced from specification in this module. +from grpc._adapter import _intermediary_low from grpc._adapter import _types from grpc.beta import _connectivity_channel +from grpc.beta import _server +from grpc.beta import _stub +from grpc.framework.common import cardinality # pylint: disable=unused-import +from grpc.framework.interfaces.face import face # pylint: disable=unused-import _CHANNEL_SUBSCRIPTION_CALLBACK_ERROR_LOG_MESSAGE = ( 'Exception calling channel subscription callback!') @@ -65,6 +73,39 @@ _LOW_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY = { } +class ClientCredentials(object): + """A value encapsulating the data required to create a secure Channel. + + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. + """ + + def __init__(self, low_credentials, intermediary_low_credentials): + self._low_credentials = low_credentials + self._intermediary_low_credentials = intermediary_low_credentials + + +def ssl_client_credentials(root_certificates, private_key, certificate_chain): + """Creates a ClientCredentials for use with an SSL-enabled Channel. + + Args: + root_certificates: The PEM-encoded root certificates or None to ask for + them to be retrieved from a default location. + private_key: The PEM-encoded private key to use or None if no private key + should be used. + certificate_chain: The PEM-encoded certificate chain to use or None if no + certificate chain should be used. + + Returns: + A ClientCredentials for use with an SSL-enabled Channel. + """ + intermediary_low_credentials = _intermediary_low.ClientCredentials( + root_certificates, private_key, certificate_chain) + return ClientCredentials( + intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access + + class Channel(object): """A channel to a remote host through which RPCs may be conducted. @@ -73,7 +114,9 @@ class Channel(object): unsupported. """ - def __init__(self, low_channel): + def __init__(self, low_channel, intermediary_low_channel): + self._low_channel = low_channel + self._intermediary_low_channel = intermediary_low_channel self._connectivity_channel = _connectivity_channel.ConnectivityChannel( low_channel, _LOW_CONNECTIVITY_STATE_TO_CHANNEL_CONNECTIVITY) @@ -111,4 +154,338 @@ def create_insecure_channel(host, port): Returns: A Channel to the remote host through which RPCs may be conducted. """ - return Channel(_low.Channel('%s:%d' % (host, port), ())) + intermediary_low_channel = _intermediary_low.Channel( + '%s:%d' % (host, port), None) + return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access + + +def create_secure_channel(host, port, client_credentials): + """Creates a secure Channel to a remote host. + + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + client_credentials: A ClientCredentials. + + Returns: + A secure Channel to the remote host through which RPCs may be conducted. + """ + intermediary_low_channel = _intermediary_low.Channel( + '%s:%d' % (host, port), client_credentials.intermediary_low_credentials) + return Channel(intermediary_low_channel._internal, intermediary_low_channel) # pylint: disable=protected-access + + +class StubOptions(object): + """A value encapsulating the various options for creation of a Stub. + + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. + """ + + def __init__( + self, host, request_serializers, response_deserializers, + metadata_transformer, thread_pool, thread_pool_size): + self.host = host + self.request_serializers = request_serializers + self.response_deserializers = response_deserializers + self.metadata_transformer = metadata_transformer + self.thread_pool = thread_pool + self.thread_pool_size = thread_pool_size + +_EMPTY_STUB_OPTIONS = StubOptions( + None, None, None, None, None, None) + + +def stub_options( + host=None, request_serializers=None, response_deserializers=None, + metadata_transformer=None, thread_pool=None, thread_pool_size=None): + """Creates a StubOptions value to be passed at stub creation. + + All parameters are optional and should always be passed by keyword. + + Args: + host: A host string to set on RPC calls. + request_serializers: A dictionary from service name-method name pair to + request serialization behavior. + response_deserializers: A dictionary from service name-method name pair to + response deserialization behavior. + metadata_transformer: A callable that given a metadata object produces + another metadata object to be used in the underlying communication on the + wire. + thread_pool: A thread pool to use in stubs. + thread_pool_size: The size of thread pool to create for use in stubs; + ignored if thread_pool has been passed. + + Returns: + A StubOptions value created from the passed parameters. + """ + return StubOptions( + host, request_serializers, response_deserializers, + metadata_transformer, thread_pool, thread_pool_size) + + +def generic_stub(channel, options=None): + """Creates a face.GenericStub on which RPCs can be made. + + Args: + channel: A Channel for use by the created stub. + options: A StubOptions customizing the created stub. + + Returns: + A face.GenericStub on which RPCs can be made. + """ + effective_options = _EMPTY_STUB_OPTIONS if options is None else options + return _stub.generic_stub( + channel._intermediary_low_channel, effective_options.host, # pylint: disable=protected-access + effective_options.metadata_transformer, + effective_options.request_serializers, + effective_options.response_deserializers, effective_options.thread_pool, + effective_options.thread_pool_size) + + +def dynamic_stub(channel, service, cardinalities, options=None): + """Creates a face.DynamicStub with which RPCs can be invoked. + + Args: + channel: A Channel for the returned face.DynamicStub to use. + service: The package-qualified full name of the service. + cardinalities: A dictionary from RPC method name to cardinality.Cardinality + value identifying the cardinality of the RPC method. + options: An optional StubOptions value further customizing the functionality + of the returned face.DynamicStub. + + Returns: + A face.DynamicStub with which RPCs can be invoked. + """ + effective_options = StubOptions() if options is None else options + return _stub.dynamic_stub( + channel._intermediary_low_channel, effective_options.host, service, # pylint: disable=protected-access + cardinalities, effective_options.metadata_transformer, + effective_options.request_serializers, + effective_options.response_deserializers, effective_options.thread_pool, + effective_options.thread_pool_size) + + +class ServerCredentials(object): + """A value encapsulating the data required to open a secure port on a Server. + + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. + """ + + def __init__(self, low_credentials, intermediary_low_credentials): + self._low_credentials = low_credentials + self._intermediary_low_credentials = intermediary_low_credentials + + +def ssl_server_credentials( + private_key_certificate_chain_pairs, root_certificates=None, + require_client_auth=False): + """Creates a ServerCredentials for use with an SSL-enabled Server. + + Args: + private_key_certificate_chain_pairs: A nonempty sequence each element of + which is a pair the first element of which is a PEM-encoded private key + and the second element of which is the corresponding PEM-encoded + certificate chain. + root_certificates: PEM-encoded client root certificates to be used for + verifying authenticated clients. If omitted, require_client_auth must also + be omitted or be False. + require_client_auth: A boolean indicating whether or not to require clients + to be authenticated. May only be True if root_certificates is not None. + + Returns: + A ServerCredentials for use with an SSL-enabled Server. + """ + if len(private_key_certificate_chain_pairs) == 0: + raise ValueError( + 'At least one private key-certificate chain pairis required!') + elif require_client_auth and root_certificates is None: + raise ValueError( + 'Illegal to require client auth without providing root certificates!') + else: + intermediary_low_credentials = _intermediary_low.ServerCredentials( + root_certificates, private_key_certificate_chain_pairs, + require_client_auth) + return ServerCredentials( + intermediary_low_credentials._internal, intermediary_low_credentials) # pylint: disable=protected-access + + +class Server(object): + """Services RPCs.""" + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def add_insecure_port(self, address): + """Reserves a port for insecure RPC service once this Server becomes active. + + This method may only be called before calling this Server's start method is + called. + + Args: + address: The address for which to open a port. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ + raise NotImplementedError() + + @abc.abstractmethod + def add_secure_port(self, address, server_credentials): + """Reserves a port for secure RPC service after this Server becomes active. + + This method may only be called before calling this Server's start method is + called. + + Args: + address: The address for which to open a port. + server_credentials: A ServerCredentials. + + Returns: + An integer port on which RPCs will be serviced after this link has been + started. This is typically the same number as the port number contained + in the passed address, but will likely be different if the port number + contained in the passed address was zero. + """ + raise NotImplementedError() + + @abc.abstractmethod + def start(self): + """Starts this Server's service of RPCs. + + This method may only be called while the server is not serving RPCs (i.e. it + is not idempotent). + """ + raise NotImplementedError() + + @abc.abstractmethod + def stop(self, grace): + """Stops this Server's service of RPCs. + + All calls to this method immediately stop service of new RPCs. When existing + RPCs are aborted is controlled by the grace period parameter passed to this + method. + + This method may be called at any time and is idempotent. Passing a smaller + grace value than has been passed in a previous call will have the effect of + stopping the Server sooner. Passing a larger grace value than has been + passed in a previous call will not have the effect of stopping the sooner + later. + + Args: + grace: A duration of time in seconds to allow existing RPCs to complete + before being aborted by this Server's stopping. May be zero for + immediate abortion of all in-progress RPCs. + + Returns: + A threading.Event that will be set when this Server has completely + stopped. The returned event may not be set until after the full grace + period (if some ongoing RPC continues for the full length of the period) + of it may be set much sooner (such as if this Server had no RPCs underway + at the time it was stopped or if all RPCs that it had underway completed + very early in the grace period). + """ + raise NotImplementedError() + + +class ServerOptions(object): + """A value encapsulating the various options for creation of a Server. + + This class and its instances have no supported interface - it exists to define + the type of its instances and its instances exist to be passed to other + functions. + """ + + def __init__( + self, multi_method_implementation, request_deserializers, + response_serializers, thread_pool, thread_pool_size, default_timeout, + maximum_timeout): + self.multi_method_implementation = multi_method_implementation + self.request_deserializers = request_deserializers + self.response_serializers = response_serializers + self.thread_pool = thread_pool + self.thread_pool_size = thread_pool_size + self.default_timeout = default_timeout + self.maximum_timeout = maximum_timeout + +_EMPTY_SERVER_OPTIONS = ServerOptions( + None, None, None, None, None, None, None) + + +def server_options( + multi_method_implementation=None, request_deserializers=None, + response_serializers=None, thread_pool=None, thread_pool_size=None, + default_timeout=None, maximum_timeout=None): + """Creates a ServerOptions value to be passed at server creation. + + All parameters are optional and should always be passed by keyword. + + Args: + multi_method_implementation: A face.MultiMethodImplementation to be called + to service an RPC if the server has no specific method implementation for + the name of the RPC for which service was requested. + request_deserializers: A dictionary from service name-method name pair to + request deserialization behavior. + response_serializers: A dictionary from service name-method name pair to + response serialization behavior. + thread_pool: A thread pool to use in stubs. + thread_pool_size: The size of thread pool to create for use in stubs; + ignored if thread_pool has been passed. + default_timeout: A duration in seconds to allow for RPC service when + servicing RPCs that did not include a timeout value when invoked. + maximum_timeout: A duration in seconds to allow for RPC service when + servicing RPCs no matter what timeout value was passed when the RPC was + invoked. + + Returns: + A StubOptions value created from the passed parameters. + """ + return ServerOptions( + multi_method_implementation, request_deserializers, response_serializers, + thread_pool, thread_pool_size, default_timeout, maximum_timeout) + + +class _Server(Server): + + def __init__(self, underserver): + self._underserver = underserver + + def add_insecure_port(self, address): + return self._underserver.add_insecure_port(address) + + def add_secure_port(self, address, server_credentials): + return self._underserver.add_secure_port( + address, server_credentials._intermediary_low_credentials) # pylint: disable=protected-access + + def start(self): + self._underserver.start() + + def stop(self, grace): + return self._underserver.stop(grace) + + +def server(service_implementations, options=None): + """Creates a Server with which RPCs can be serviced. + + Args: + service_implementations: A dictionary from service name-method name pair to + face.MethodImplementation. + options: An optional ServerOptions value further customizing the + functionality of the returned Server. + + Returns: + A Server with which RPCs can be serviced. + """ + effective_options = _EMPTY_SERVER_OPTIONS if options is None else options + underserver = _server.server( + service_implementations, effective_options.multi_method_implementation, + effective_options.request_deserializers, + effective_options.response_serializers, effective_options.thread_pool, + effective_options.thread_pool_size, effective_options.default_timeout, + effective_options.maximum_timeout) + return _Server(underserver) diff --git a/src/python/grpcio/grpc/framework/core/_end.py b/src/python/grpcio/grpc/framework/core/_end.py index 5ef2f6d3a3..f57cde4e58 100644 --- a/src/python/grpcio/grpc/framework/core/_end.py +++ b/src/python/grpcio/grpc/framework/core/_end.py @@ -203,11 +203,11 @@ class _End(End): def accept_ticket(self, ticket): """See links.Link.accept_ticket for specification.""" with self._lock: - if self._cycle is not None and not self._cycle.grace: + if self._cycle is not None: operation = self._cycle.operations.get(ticket.operation_id) if operation is not None: operation.handle_ticket(ticket) - elif self._servicer_package is not None: + elif self._servicer_package is not None and not self._cycle.grace: termination_action = _termination_action( self._lock, self._stats, ticket.operation_id, self._cycle) operation = _operation.service_operate( diff --git a/src/python/grpcio_test/grpc_interop/methods.py b/src/python/grpcio_test/grpc_interop/methods.py index 642458e892..52b800af7a 100644 --- a/src/python/grpcio_test/grpc_interop/methods.py +++ b/src/python/grpcio_test/grpc_interop/methods.py @@ -346,20 +346,6 @@ def _compute_engine_creds(stub, args): response.username)) -def _service_account_creds(stub, args): - json_key_filename = os.environ[ - oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] - wanted_email = json.load(open(json_key_filename, 'rb'))['client_email'] - response = _large_unary_common_behavior(stub, True, True) - if wanted_email != response.username: - raise ValueError( - 'expected username %s, got %s' % (wanted_email, response.username)) - if args.oauth_scope.find(response.oauth_scope) == -1: - raise ValueError( - 'expected to find oauth scope "%s" in received "%s"' % - (response.oauth_scope, args.oauth_scope)) - - def _oauth2_auth_token(stub, args): json_key_filename = os.environ[ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS] @@ -383,7 +369,6 @@ class TestCase(enum.Enum): CANCEL_AFTER_BEGIN = 'cancel_after_begin' CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response' COMPUTE_ENGINE_CREDS = 'compute_engine_creds' - SERVICE_ACCOUNT_CREDS = 'service_account_creds' OAUTH2_AUTH_TOKEN = 'oauth2_auth_token' TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server' @@ -406,8 +391,6 @@ class TestCase(enum.Enum): _timeout_on_sleeping_server(stub) elif self is TestCase.COMPUTE_ENGINE_CREDS: _compute_engine_creds(stub, args) - elif self is TestCase.SERVICE_ACCOUNT_CREDS: - _service_account_creds(stub, args) elif self is TestCase.OAUTH2_AUTH_TOKEN: _oauth2_auth_token(stub, args) else: diff --git a/src/python/grpcio_test/grpc_protoc_plugin/alpha_python_plugin_test.py b/src/python/grpcio_test/grpc_protoc_plugin/alpha_python_plugin_test.py new file mode 100644 index 0000000000..b200d129a9 --- /dev/null +++ b/src/python/grpcio_test/grpc_protoc_plugin/alpha_python_plugin_test.py @@ -0,0 +1,541 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import contextlib +import distutils.spawn +import errno +import itertools +import os +import pkg_resources +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import unittest + +from grpc.framework.alpha import exceptions +from grpc.framework.foundation import future + +# Identifiers of entities we expect to find in the generated module. +SERVICER_IDENTIFIER = 'EarlyAdopterTestServiceServicer' +SERVER_IDENTIFIER = 'EarlyAdopterTestServiceServer' +STUB_IDENTIFIER = 'EarlyAdopterTestServiceStub' +SERVER_FACTORY_IDENTIFIER = 'early_adopter_create_TestService_server' +STUB_FACTORY_IDENTIFIER = 'early_adopter_create_TestService_stub' + +# The timeout used in tests of RPCs that are supposed to expire. +SHORT_TIMEOUT = 2 +# The timeout used in tests of RPCs that are not supposed to expire. The +# absurdly large value doesn't matter since no passing execution of this test +# module will ever wait the duration. +LONG_TIMEOUT = 600 +NO_DELAY = 0 + + +class _ServicerMethods(object): + + def __init__(self, test_pb2, delay): + self._condition = threading.Condition() + self._delay = delay + self._paused = False + self._fail = False + self._test_pb2 = test_pb2 + + @contextlib.contextmanager + def pause(self): # pylint: disable=invalid-name + with self._condition: + self._paused = True + yield + with self._condition: + self._paused = False + self._condition.notify_all() + + @contextlib.contextmanager + def fail(self): # pylint: disable=invalid-name + with self._condition: + self._fail = True + yield + with self._condition: + self._fail = False + + def _control(self): # pylint: disable=invalid-name + with self._condition: + if self._fail: + raise ValueError() + while self._paused: + self._condition.wait() + time.sleep(self._delay) + + def UnaryCall(self, request, unused_rpc_context): + response = self._test_pb2.SimpleResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * request.response_size + self._control() + return response + + def StreamingOutputCall(self, request, unused_rpc_context): + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def StreamingInputCall(self, request_iter, unused_rpc_context): + response = self._test_pb2.StreamingInputCallResponse() + aggregated_payload_size = 0 + for request in request_iter: + aggregated_payload_size += len(request.payload.payload_compressable) + response.aggregated_payload_size = aggregated_payload_size + self._control() + return response + + def FullDuplexCall(self, request_iter, unused_rpc_context): + for request in request_iter: + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def HalfDuplexCall(self, request_iter, unused_rpc_context): + responses = [] + for request in request_iter: + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + responses.append(response) + for response in responses: + yield response + + +@contextlib.contextmanager +def _CreateService(test_pb2, delay): + """Provides a servicer backend and a stub. + + The servicer is just the implementation + of the actual servicer passed to the face player of the python RPC + implementation; the two are detached. + + Non-zero delay puts a delay on each call to the servicer, representative of + communication latency. Timeout is the default timeout for the stub while + waiting for the service. + + Args: + test_pb2: The test_pb2 module generated by this test. + delay: Delay in seconds per response from the servicer. + + Yields: + A (servicer_methods, servicer, stub) three-tuple where servicer_methods is + the back-end of the service bound to the stub and the server and stub + are both activated and ready for use. + """ + servicer_methods = _ServicerMethods(test_pb2, delay) + + class Servicer(getattr(test_pb2, SERVICER_IDENTIFIER)): + + def UnaryCall(self, request, context): + return servicer_methods.UnaryCall(request, context) + + def StreamingOutputCall(self, request, context): + return servicer_methods.StreamingOutputCall(request, context) + + def StreamingInputCall(self, request_iter, context): + return servicer_methods.StreamingInputCall(request_iter, context) + + def FullDuplexCall(self, request_iter, context): + return servicer_methods.FullDuplexCall(request_iter, context) + + def HalfDuplexCall(self, request_iter, context): + return servicer_methods.HalfDuplexCall(request_iter, context) + + servicer = Servicer() + server = getattr( + test_pb2, SERVER_FACTORY_IDENTIFIER)(servicer, 0) + with server: + port = server.port() + stub = getattr(test_pb2, STUB_FACTORY_IDENTIFIER)('localhost', port) + with stub: + yield servicer_methods, stub, server + + +def _streaming_input_request_iterator(test_pb2): + for _ in range(3): + request = test_pb2.StreamingInputCallRequest() + request.payload.payload_type = test_pb2.COMPRESSABLE + request.payload.payload_compressable = 'a' + yield request + + +def _streaming_output_request(test_pb2): + request = test_pb2.StreamingOutputCallRequest() + sizes = [1, 2, 3] + request.response_parameters.add(size=sizes[0], interval_us=0) + request.response_parameters.add(size=sizes[1], interval_us=0) + request.response_parameters.add(size=sizes[2], interval_us=0) + return request + + +def _full_duplex_request_iterator(test_pb2): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + + +class PythonPluginTest(unittest.TestCase): + """Test case for the gRPC Python protoc-plugin. + + While reading these tests, remember that the futures API + (`stub.method.async()`) only gives futures for the *non-streaming* responses, + else it behaves like its blocking cousin. + """ + + def setUp(self): + # Assume that the appropriate protoc and grpc_python_plugins are on the + # path. + protoc_command = 'protoc' + protoc_plugin_filename = distutils.spawn.find_executable( + 'grpc_python_plugin') + test_proto_filename = pkg_resources.resource_filename( + 'grpc_protoc_plugin', 'test.proto') + if not os.path.isfile(protoc_command): + # Assume that if we haven't built protoc that it's on the system. + protoc_command = 'protoc' + + # Ensure that the output directory exists. + self.outdir = tempfile.mkdtemp() + + # Invoke protoc with the plugin. + cmd = [ + protoc_command, + '--plugin=protoc-gen-python-grpc=%s' % protoc_plugin_filename, + '-I .', + '--python_out=%s' % self.outdir, + '--python-grpc_out=%s' % self.outdir, + os.path.basename(test_proto_filename), + ] + subprocess.check_call(' '.join(cmd), shell=True, env=os.environ, + cwd=os.path.dirname(test_proto_filename)) + sys.path.append(self.outdir) + + def tearDown(self): + try: + shutil.rmtree(self.outdir) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + + # TODO(atash): Figure out which of these tests is hanging flakily with small + # probability. + + def testImportAttributes(self): + # check that we can access the generated module and its members. + import test_pb2 # pylint: disable=g-import-not-at-top + self.assertIsNotNone(getattr(test_pb2, SERVICER_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, SERVER_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, STUB_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, SERVER_FACTORY_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, STUB_FACTORY_IDENTIFIER, None)) + + def testUpDown(self): + import test_pb2 + with _CreateService( + test_pb2, NO_DELAY) as (servicer, stub, unused_server): + request = test_pb2.SimpleRequest(response_size=13) + + def testUnaryCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + timeout = 6 # TODO(issue 2039): LONG_TIMEOUT like the other methods. + request = test_pb2.SimpleRequest(response_size=13) + response = stub.UnaryCall(request, timeout) + expected_response = methods.UnaryCall(request, 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testUnaryCallAsync(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + # Check that the call does not block waiting for the server to respond. + with methods.pause(): + response_future = stub.UnaryCall.async(request, LONG_TIMEOUT) + response = response_future.result() + expected_response = methods.UnaryCall(request, 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testUnaryCallAsyncExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + request = test_pb2.SimpleRequest(response_size=13) + with methods.pause(): + response_future = stub.UnaryCall.async(request, SHORT_TIMEOUT) + with self.assertRaises(exceptions.ExpirationError): + response_future.result() + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testUnaryCallAsyncCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + response_future = stub.UnaryCall.async(request, 1) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + + def testUnaryCallAsyncFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.fail(): + response_future = stub.UnaryCall.async(request, LONG_TIMEOUT) + self.assertIsNotNone(response_future.exception()) + + def testStreamingOutputCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + responses = stub.StreamingOutputCall(request, LONG_TIMEOUT) + expected_responses = methods.StreamingOutputCall( + request, 'not a real RpcContext!') + for expected_response, response in itertools.izip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testStreamingOutputCallExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + responses = stub.StreamingOutputCall(request, SHORT_TIMEOUT) + with self.assertRaises(exceptions.ExpirationError): + list(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testStreamingOutputCallCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as ( + unused_methods, stub, unused_server): + responses = stub.StreamingOutputCall(request, SHORT_TIMEOUT) + next(responses) + responses.cancel() + with self.assertRaises(future.CancelledError): + next(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this times out ' + 'instead of raising the proper error.') + def testStreamingOutputCallFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.fail(): + responses = stub.StreamingOutputCall(request, 1) + self.assertIsNotNone(responses) + with self.assertRaises(exceptions.ServicerError): + next(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testStreamingInputCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + response = stub.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), LONG_TIMEOUT) + expected_response = methods.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallAsync(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + response_future = stub.StreamingInputCall.async( + _streaming_input_request_iterator(test_pb2), LONG_TIMEOUT) + response = response_future.result() + expected_response = methods.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallAsyncExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + response_future = stub.StreamingInputCall.async( + _streaming_input_request_iterator(test_pb2), SHORT_TIMEOUT) + with self.assertRaises(exceptions.ExpirationError): + response_future.result() + self.assertIsInstance( + response_future.exception(), exceptions.ExpirationError) + + def testStreamingInputCallAsyncCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + timeout = 6 # TODO(issue 2039): LONG_TIMEOUT like the other methods. + response_future = stub.StreamingInputCall.async( + _streaming_input_request_iterator(test_pb2), timeout) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + with self.assertRaises(future.CancelledError): + response_future.result() + + def testStreamingInputCallAsyncFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.fail(): + response_future = stub.StreamingInputCall.async( + _streaming_input_request_iterator(test_pb2), SHORT_TIMEOUT) + self.assertIsNotNone(response_future.exception()) + + def testFullDuplexCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + responses = stub.FullDuplexCall( + _full_duplex_request_iterator(test_pb2), LONG_TIMEOUT) + expected_responses = methods.FullDuplexCall( + _full_duplex_request_iterator(test_pb2), 'not a real RpcContext!') + for expected_response, response in itertools.izip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testFullDuplexCallExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request_iterator = _full_duplex_request_iterator(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.pause(): + responses = stub.FullDuplexCall(request_iterator, SHORT_TIMEOUT) + with self.assertRaises(exceptions.ExpirationError): + list(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testFullDuplexCallCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + request_iterator = _full_duplex_request_iterator(test_pb2) + responses = stub.FullDuplexCall(request_iterator, LONG_TIMEOUT) + next(responses) + responses.cancel() + with self.assertRaises(future.CancelledError): + next(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this hangs forever ' + 'and fix.') + def testFullDuplexCallFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request_iterator = _full_duplex_request_iterator(test_pb2) + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + with methods.fail(): + responses = stub.FullDuplexCall(request_iterator, LONG_TIMEOUT) + self.assertIsNotNone(responses) + with self.assertRaises(exceptions.ServicerError): + next(responses) + + @unittest.skip('TODO(atash,nathaniel): figure out why this flakily hangs ' + 'forever and fix.') + def testHalfDuplexCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2, NO_DELAY) as ( + methods, stub, unused_server): + def half_duplex_request_iterator(): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), LONG_TIMEOUT) + expected_responses = methods.HalfDuplexCall( + half_duplex_request_iterator(), 'not a real RpcContext!') + for check in itertools.izip_longest(expected_responses, responses): + expected_response, response = check + self.assertEqual(expected_response, response) + + def testHalfDuplexCallWedged(self): + import test_pb2 # pylint: disable=g-import-not-at-top + condition = threading.Condition() + wait_cell = [False] + @contextlib.contextmanager + def wait(): # pylint: disable=invalid-name + # Where's Python 3's 'nonlocal' statement when you need it? + with condition: + wait_cell[0] = True + yield + with condition: + wait_cell[0] = False + condition.notify_all() + def half_duplex_request_iterator(): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + with condition: + while wait_cell[0]: + condition.wait() + with _CreateService(test_pb2, NO_DELAY) as (methods, stub, unused_server): + with wait(): + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), SHORT_TIMEOUT) + # half-duplex waits for the client to send all info + with self.assertRaises(exceptions.ExpirationError): + next(responses) + + +if __name__ == '__main__': + os.chdir(os.path.dirname(sys.argv[0])) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_test/grpc_protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_test/grpc_protoc_plugin/beta_python_plugin_test.py new file mode 100644 index 0000000000..4c8c64b06d --- /dev/null +++ b/src/python/grpcio_test/grpc_protoc_plugin/beta_python_plugin_test.py @@ -0,0 +1,501 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import contextlib +import distutils.spawn +import errno +import itertools +import os +import pkg_resources +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import unittest + +from grpc.beta import beta +from grpc.framework.foundation import future +from grpc.framework.interfaces.face import face +from grpc_test.framework.common import test_constants + +# Identifiers of entities we expect to find in the generated module. +SERVICER_IDENTIFIER = 'BetaTestServiceServicer' +STUB_IDENTIFIER = 'BetaTestServiceStub' +SERVER_FACTORY_IDENTIFIER = 'beta_create_TestService_server' +STUB_FACTORY_IDENTIFIER = 'beta_create_TestService_stub' + + +class _ServicerMethods(object): + + def __init__(self, test_pb2): + self._condition = threading.Condition() + self._paused = False + self._fail = False + self._test_pb2 = test_pb2 + + @contextlib.contextmanager + def pause(self): # pylint: disable=invalid-name + with self._condition: + self._paused = True + yield + with self._condition: + self._paused = False + self._condition.notify_all() + + @contextlib.contextmanager + def fail(self): # pylint: disable=invalid-name + with self._condition: + self._fail = True + yield + with self._condition: + self._fail = False + + def _control(self): # pylint: disable=invalid-name + with self._condition: + if self._fail: + raise ValueError() + while self._paused: + self._condition.wait() + + def UnaryCall(self, request, unused_rpc_context): + response = self._test_pb2.SimpleResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * request.response_size + self._control() + return response + + def StreamingOutputCall(self, request, unused_rpc_context): + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def StreamingInputCall(self, request_iter, unused_rpc_context): + response = self._test_pb2.StreamingInputCallResponse() + aggregated_payload_size = 0 + for request in request_iter: + aggregated_payload_size += len(request.payload.payload_compressable) + response.aggregated_payload_size = aggregated_payload_size + self._control() + return response + + def FullDuplexCall(self, request_iter, unused_rpc_context): + for request in request_iter: + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + yield response + + def HalfDuplexCall(self, request_iter, unused_rpc_context): + responses = [] + for request in request_iter: + for parameter in request.response_parameters: + response = self._test_pb2.StreamingOutputCallResponse() + response.payload.payload_type = self._test_pb2.COMPRESSABLE + response.payload.payload_compressable = 'a' * parameter.size + self._control() + responses.append(response) + for response in responses: + yield response + + +@contextlib.contextmanager +def _CreateService(test_pb2): + """Provides a servicer backend and a stub. + + The servicer is just the implementation of the actual servicer passed to the + face player of the python RPC implementation; the two are detached. + + Args: + test_pb2: The test_pb2 module generated by this test. + + Yields: + A (servicer_methods, stub) pair where servicer_methods is the back-end of + the service bound to the stub and and stub is the stub on which to invoke + RPCs. + """ + servicer_methods = _ServicerMethods(test_pb2) + + class Servicer(getattr(test_pb2, SERVICER_IDENTIFIER)): + + def UnaryCall(self, request, context): + return servicer_methods.UnaryCall(request, context) + + def StreamingOutputCall(self, request, context): + return servicer_methods.StreamingOutputCall(request, context) + + def StreamingInputCall(self, request_iter, context): + return servicer_methods.StreamingInputCall(request_iter, context) + + def FullDuplexCall(self, request_iter, context): + return servicer_methods.FullDuplexCall(request_iter, context) + + def HalfDuplexCall(self, request_iter, context): + return servicer_methods.HalfDuplexCall(request_iter, context) + + servicer = Servicer() + server = getattr(test_pb2, SERVER_FACTORY_IDENTIFIER)(servicer) + port = server.add_insecure_port('[::]:0') + server.start() + channel = beta.create_insecure_channel('localhost', port) + stub = getattr(test_pb2, STUB_FACTORY_IDENTIFIER)(channel) + yield servicer_methods, stub + server.stop(0) + + +def _streaming_input_request_iterator(test_pb2): + for _ in range(3): + request = test_pb2.StreamingInputCallRequest() + request.payload.payload_type = test_pb2.COMPRESSABLE + request.payload.payload_compressable = 'a' + yield request + + +def _streaming_output_request(test_pb2): + request = test_pb2.StreamingOutputCallRequest() + sizes = [1, 2, 3] + request.response_parameters.add(size=sizes[0], interval_us=0) + request.response_parameters.add(size=sizes[1], interval_us=0) + request.response_parameters.add(size=sizes[2], interval_us=0) + return request + + +def _full_duplex_request_iterator(test_pb2): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + + +class PythonPluginTest(unittest.TestCase): + """Test case for the gRPC Python protoc-plugin. + + While reading these tests, remember that the futures API + (`stub.method.future()`) only gives futures for the *response-unary* + methods and does not exist for response-streaming methods. + """ + + def setUp(self): + # Assume that the appropriate protoc and grpc_python_plugins are on the + # path. + protoc_command = 'protoc' + protoc_plugin_filename = distutils.spawn.find_executable( + 'grpc_python_plugin') + test_proto_filename = pkg_resources.resource_filename( + 'grpc_protoc_plugin', 'test.proto') + if not os.path.isfile(protoc_command): + # Assume that if we haven't built protoc that it's on the system. + protoc_command = 'protoc' + + # Ensure that the output directory exists. + self.outdir = tempfile.mkdtemp() + + # Invoke protoc with the plugin. + cmd = [ + protoc_command, + '--plugin=protoc-gen-python-grpc=%s' % protoc_plugin_filename, + '-I .', + '--python_out=%s' % self.outdir, + '--python-grpc_out=%s' % self.outdir, + os.path.basename(test_proto_filename), + ] + subprocess.check_call(' '.join(cmd), shell=True, env=os.environ, + cwd=os.path.dirname(test_proto_filename)) + sys.path.append(self.outdir) + + def tearDown(self): + try: + shutil.rmtree(self.outdir) + except OSError as exc: + if exc.errno != errno.ENOENT: + raise + + def testImportAttributes(self): + # check that we can access the generated module and its members. + import test_pb2 # pylint: disable=g-import-not-at-top + self.assertIsNotNone(getattr(test_pb2, SERVICER_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, STUB_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, SERVER_FACTORY_IDENTIFIER, None)) + self.assertIsNotNone(getattr(test_pb2, STUB_FACTORY_IDENTIFIER, None)) + + def testUpDown(self): + import test_pb2 + with _CreateService(test_pb2) as (servicer, stub): + request = test_pb2.SimpleRequest(response_size=13) + + def testUnaryCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + request = test_pb2.SimpleRequest(response_size=13) + response = stub.UnaryCall(request, test_constants.LONG_TIMEOUT) + expected_response = methods.UnaryCall(request, 'not a real context!') + self.assertEqual(expected_response, response) + + def testUnaryCallFuture(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2) as (methods, stub): + # Check that the call does not block waiting for the server to respond. + with methods.pause(): + response_future = stub.UnaryCall.future( + request, test_constants.LONG_TIMEOUT) + response = response_future.result() + expected_response = methods.UnaryCall(request, 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testUnaryCallFutureExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + request = test_pb2.SimpleRequest(response_size=13) + with methods.pause(): + response_future = stub.UnaryCall.future( + request, test_constants.SHORT_TIMEOUT) + with self.assertRaises(face.ExpirationError): + response_future.result() + + def testUnaryCallFutureCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + response_future = stub.UnaryCall.future(request, 1) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + + def testUnaryCallFutureFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = test_pb2.SimpleRequest(response_size=13) + with _CreateService(test_pb2) as (methods, stub): + with methods.fail(): + response_future = stub.UnaryCall.future( + request, test_constants.LONG_TIMEOUT) + self.assertIsNotNone(response_future.exception()) + + def testStreamingOutputCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2) as (methods, stub): + responses = stub.StreamingOutputCall( + request, test_constants.LONG_TIMEOUT) + expected_responses = methods.StreamingOutputCall( + request, 'not a real RpcContext!') + for expected_response, response in itertools.izip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testStreamingOutputCallExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + responses = stub.StreamingOutputCall( + request, test_constants.SHORT_TIMEOUT) + with self.assertRaises(face.ExpirationError): + list(responses) + + def testStreamingOutputCallCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2) as (unused_methods, stub): + responses = stub.StreamingOutputCall( + request, test_constants.LONG_TIMEOUT) + next(responses) + responses.cancel() + with self.assertRaises(face.CancellationError): + next(responses) + + def testStreamingOutputCallFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request = _streaming_output_request(test_pb2) + with _CreateService(test_pb2) as (methods, stub): + with methods.fail(): + responses = stub.StreamingOutputCall(request, 1) + self.assertIsNotNone(responses) + with self.assertRaises(face.RemoteError): + next(responses) + + def testStreamingInputCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + response = stub.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), + test_constants.LONG_TIMEOUT) + expected_response = methods.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallFuture(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + response_future = stub.StreamingInputCall.future( + _streaming_input_request_iterator(test_pb2), + test_constants.LONG_TIMEOUT) + response = response_future.result() + expected_response = methods.StreamingInputCall( + _streaming_input_request_iterator(test_pb2), 'not a real RpcContext!') + self.assertEqual(expected_response, response) + + def testStreamingInputCallFutureExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + response_future = stub.StreamingInputCall.future( + _streaming_input_request_iterator(test_pb2), + test_constants.SHORT_TIMEOUT) + with self.assertRaises(face.ExpirationError): + response_future.result() + self.assertIsInstance( + response_future.exception(), face.ExpirationError) + + def testStreamingInputCallFutureCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + response_future = stub.StreamingInputCall.future( + _streaming_input_request_iterator(test_pb2), + test_constants.LONG_TIMEOUT) + response_future.cancel() + self.assertTrue(response_future.cancelled()) + with self.assertRaises(future.CancelledError): + response_future.result() + + def testStreamingInputCallFutureFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + with methods.fail(): + response_future = stub.StreamingInputCall.future( + _streaming_input_request_iterator(test_pb2), + test_constants.LONG_TIMEOUT) + self.assertIsNotNone(response_future.exception()) + + def testFullDuplexCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + responses = stub.FullDuplexCall( + _full_duplex_request_iterator(test_pb2), test_constants.LONG_TIMEOUT) + expected_responses = methods.FullDuplexCall( + _full_duplex_request_iterator(test_pb2), 'not a real RpcContext!') + for expected_response, response in itertools.izip_longest( + expected_responses, responses): + self.assertEqual(expected_response, response) + + def testFullDuplexCallExpired(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request_iterator = _full_duplex_request_iterator(test_pb2) + with _CreateService(test_pb2) as (methods, stub): + with methods.pause(): + responses = stub.FullDuplexCall( + request_iterator, test_constants.SHORT_TIMEOUT) + with self.assertRaises(face.ExpirationError): + list(responses) + + def testFullDuplexCallCancelled(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + request_iterator = _full_duplex_request_iterator(test_pb2) + responses = stub.FullDuplexCall( + request_iterator, test_constants.LONG_TIMEOUT) + next(responses) + responses.cancel() + with self.assertRaises(face.CancellationError): + next(responses) + + def testFullDuplexCallFailed(self): + import test_pb2 # pylint: disable=g-import-not-at-top + request_iterator = _full_duplex_request_iterator(test_pb2) + with _CreateService(test_pb2) as (methods, stub): + with methods.fail(): + responses = stub.FullDuplexCall( + request_iterator, test_constants.LONG_TIMEOUT) + self.assertIsNotNone(responses) + with self.assertRaises(face.RemoteError): + next(responses) + + def testHalfDuplexCall(self): + import test_pb2 # pylint: disable=g-import-not-at-top + with _CreateService(test_pb2) as (methods, stub): + def half_duplex_request_iterator(): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=2, interval_us=0) + request.response_parameters.add(size=3, interval_us=0) + yield request + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), test_constants.LONG_TIMEOUT) + expected_responses = methods.HalfDuplexCall( + half_duplex_request_iterator(), 'not a real RpcContext!') + for check in itertools.izip_longest(expected_responses, responses): + expected_response, response = check + self.assertEqual(expected_response, response) + + def testHalfDuplexCallWedged(self): + import test_pb2 # pylint: disable=g-import-not-at-top + condition = threading.Condition() + wait_cell = [False] + @contextlib.contextmanager + def wait(): # pylint: disable=invalid-name + # Where's Python 3's 'nonlocal' statement when you need it? + with condition: + wait_cell[0] = True + yield + with condition: + wait_cell[0] = False + condition.notify_all() + def half_duplex_request_iterator(): + request = test_pb2.StreamingOutputCallRequest() + request.response_parameters.add(size=1, interval_us=0) + yield request + with condition: + while wait_cell[0]: + condition.wait() + with _CreateService(test_pb2) as (methods, stub): + with wait(): + responses = stub.HalfDuplexCall( + half_duplex_request_iterator(), test_constants.SHORT_TIMEOUT) + # half-duplex waits for the client to send all info + with self.assertRaises(face.ExpirationError): + next(responses) + + +if __name__ == '__main__': + os.chdir(os.path.dirname(sys.argv[0])) + unittest.main(verbosity=2) diff --git a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py index 9112c34190..f0bd989ea6 100644 --- a/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py +++ b/src/python/grpcio_test/grpc_test/_core_over_links_base_interface_test.py @@ -94,7 +94,7 @@ class _Implementation(test_interfaces.Implementation): port = service_grpc_link.add_port('[::]:0', None) channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_grpc_link = invocation.invocation_link( - channel, b'localhost', + channel, b'localhost', None, serialization_behaviors.request_serializers, serialization_behaviors.response_deserializers) diff --git a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py index 1401536503..28c0619f7c 100644 --- a/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py +++ b/src/python/grpcio_test/grpc_test/_crust_over_core_over_links_face_interface_test.py @@ -87,7 +87,7 @@ class _Implementation(test_interfaces.Implementation): port = service_grpc_link.add_port('[::]:0', None) channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_grpc_link = invocation.invocation_link( - channel, b'localhost', + channel, b'localhost', None, serialization_behaviors.request_serializers, serialization_behaviors.response_deserializers) diff --git a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py index 373a2b2a1f..8e12e8cc22 100644 --- a/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py +++ b/src/python/grpcio_test/grpc_test/_links/_lonely_invocation_link_test.py @@ -45,7 +45,8 @@ class LonelyInvocationLinkTest(unittest.TestCase): def testUpAndDown(self): channel = _intermediary_low.Channel('nonexistent:54321', None) - invocation_link = invocation.invocation_link(channel, 'nonexistent', {}, {}) + invocation_link = invocation.invocation_link( + channel, 'nonexistent', None, {}, {}) invocation_link.start() invocation_link.stop() @@ -58,8 +59,7 @@ class LonelyInvocationLinkTest(unittest.TestCase): channel = _intermediary_low.Channel('nonexistent:54321', None) invocation_link = invocation.invocation_link( - channel, 'nonexistent', {(test_group, test_method): _NULL_BEHAVIOR}, - {(test_group, test_method): _NULL_BEHAVIOR}) + channel, 'nonexistent', None, {}, {}) invocation_link.join_link(invocation_link_mate) invocation_link.start() diff --git a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py index c114cef6a6..716323cc20 100644 --- a/src/python/grpcio_test/grpc_test/_links/_transmission_test.py +++ b/src/python/grpcio_test/grpc_test/_links/_transmission_test.py @@ -54,7 +54,7 @@ class TransmissionTest(test_cases.TransmissionTest, unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', + channel, 'localhost', None, {self.group_and_method(): self.serialize_request}, {self.group_and_method(): self.deserialize_response}) invocation_link.start() @@ -121,7 +121,7 @@ class RoundTripTest(unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', identity_transformation, identity_transformation) + channel, None, None, identity_transformation, identity_transformation) invocation_mate = test_utilities.RecordingLink() invocation_link.join_link(invocation_mate) invocation_link.start() @@ -166,7 +166,7 @@ class RoundTripTest(unittest.TestCase): service_link.start() channel = _intermediary_low.Channel('localhost:%d' % port, None) invocation_link = invocation.invocation_link( - channel, 'localhost', + channel, 'localhost', None, {(test_group, test_method): scenario.serialize_request}, {(test_group, test_method): scenario.deserialize_response}) invocation_mate = test_utilities.RecordingLink() diff --git a/src/python/grpcio_test/grpc_test/beta/_face_interface_test.py b/src/python/grpcio_test/grpc_test/beta/_face_interface_test.py new file mode 100644 index 0000000000..ce4c59c0ee --- /dev/null +++ b/src/python/grpcio_test/grpc_test/beta/_face_interface_test.py @@ -0,0 +1,137 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests Face interface compliance of the gRPC Python Beta API.""" + +import collections +import unittest + +from grpc._adapter import _intermediary_low +from grpc.beta import beta +from grpc_test import resources +from grpc_test import test_common as grpc_test_common +from grpc_test.beta import test_utilities +from grpc_test.framework.common import test_constants +from grpc_test.framework.interfaces.face import test_cases +from grpc_test.framework.interfaces.face import test_interfaces + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' + + +class _SerializationBehaviors( + collections.namedtuple( + '_SerializationBehaviors', + ('request_serializers', 'request_deserializers', 'response_serializers', + 'response_deserializers',))): + pass + + +def _serialization_behaviors_from_test_methods(test_methods): + request_serializers = {} + request_deserializers = {} + response_serializers = {} + response_deserializers = {} + for (group, method), test_method in test_methods.iteritems(): + request_serializers[group, method] = test_method.serialize_request + request_deserializers[group, method] = test_method.deserialize_request + response_serializers[group, method] = test_method.serialize_response + response_deserializers[group, method] = test_method.deserialize_response + return _SerializationBehaviors( + request_serializers, request_deserializers, response_serializers, + response_deserializers) + + +class _Implementation(test_interfaces.Implementation): + + def instantiate( + self, methods, method_implementations, multi_method_implementation): + serialization_behaviors = _serialization_behaviors_from_test_methods( + methods) + # TODO(nathaniel): Add a "groups" attribute to _digest.TestServiceDigest. + service = next(iter(methods))[0] + # TODO(nathaniel): Add a "cardinalities_by_group" attribute to + # _digest.TestServiceDigest. + cardinalities = { + method: method_object.cardinality() + for (group, method), method_object in methods.iteritems()} + + server_options = beta.server_options( + request_deserializers=serialization_behaviors.request_deserializers, + response_serializers=serialization_behaviors.response_serializers, + thread_pool_size=test_constants.POOL_SIZE) + server = beta.server(method_implementations, options=server_options) + server_credentials = beta.ssl_server_credentials( + [(resources.private_key(), resources.certificate_chain(),),]) + port = server.add_secure_port('[::]:0', server_credentials) + server.start() + client_credentials = beta.ssl_client_credentials( + resources.test_root_certificates(), None, None) + channel = test_utilities.create_not_really_secure_channel( + 'localhost', port, client_credentials, _SERVER_HOST_OVERRIDE) + stub_options = beta.stub_options( + request_serializers=serialization_behaviors.request_serializers, + response_deserializers=serialization_behaviors.response_deserializers, + thread_pool_size=test_constants.POOL_SIZE) + generic_stub = beta.generic_stub(channel, options=stub_options) + dynamic_stub = beta.dynamic_stub( + channel, service, cardinalities, options=stub_options) + return generic_stub, {service: dynamic_stub}, server + + def destantiate(self, memo): + memo.stop(test_constants.SHORT_TIMEOUT).wait() + + def invocation_metadata(self): + return grpc_test_common.INVOCATION_INITIAL_METADATA + + def initial_metadata(self): + return grpc_test_common.SERVICE_INITIAL_METADATA + + def terminal_metadata(self): + return grpc_test_common.SERVICE_TERMINAL_METADATA + + def code(self): + return _intermediary_low.Code.OK + + def details(self): + return grpc_test_common.DETAILS + + def metadata_transmitted(self, original_metadata, transmitted_metadata): + return original_metadata is None or grpc_test_common.metadata_transmitted( + original_metadata, transmitted_metadata) + + +def load_tests(loader, tests, pattern): + return unittest.TestSuite( + tests=tuple( + loader.loadTestsFromTestCase(test_case_class) + for test_case_class in test_cases.test_cases(_Implementation()))) + + +if __name__ == '__main__': + unittest.main(verbosity=2) diff --git a/src/python/grpcio_test/grpc_test/beta/test_utilities.py b/src/python/grpcio_test/grpc_test/beta/test_utilities.py new file mode 100644 index 0000000000..338670478d --- /dev/null +++ b/src/python/grpcio_test/grpc_test/beta/test_utilities.py @@ -0,0 +1,54 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test-appropriate entry points into the gRPC Python Beta API.""" + +from grpc._adapter import _intermediary_low +from grpc.beta import beta + + +def create_not_really_secure_channel( + host, port, client_credentials, server_host_override): + """Creates an insecure Channel to a remote host. + + Args: + host: The name of the remote host to which to connect. + port: The port of the remote host to which to connect. + client_credentials: The beta.ClientCredentials with which to connect. + server_host_override: The target name used for SSL host name checking. + + Returns: + A beta.Channel to the remote host through which RPCs may be conducted. + """ + hostport = '%s:%d' % (host, port) + intermediary_low_channel = _intermediary_low.Channel( + hostport, client_credentials._intermediary_low_credentials, + server_host_override=server_host_override) + return beta.Channel( + intermediary_low_channel._internal, intermediary_low_channel) diff --git a/src/python/grpcio_test/grpc_test/credentials/README b/src/python/grpcio_test/grpc_test/credentials/README new file mode 100644 index 0000000000..cb20dcb49f --- /dev/null +++ b/src/python/grpcio_test/grpc_test/credentials/README @@ -0,0 +1 @@ +These are test keys *NOT* to be used in production. diff --git a/src/python/grpcio_test/grpc_test/credentials/ca.pem b/src/python/grpcio_test/grpc_test/credentials/ca.pem new file mode 100755 index 0000000000..6c8511a73c --- /dev/null +++ b/src/python/grpcio_test/grpc_test/credentials/ca.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICSjCCAbOgAwIBAgIJAJHGGR4dGioHMA0GCSqGSIb3DQEBCwUAMFYxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQxDzANBgNVBAMTBnRlc3RjYTAeFw0xNDExMTEyMjMxMjla +Fw0yNDExMDgyMjMxMjlaMFYxCzAJBgNVBAYTAkFVMRMwEQYDVQQIEwpTb21lLVN0 +YXRlMSEwHwYDVQQKExhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxDzANBgNVBAMT +BnRlc3RjYTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEAwEDfBV5MYdlHVHJ7 ++L4nxrZy7mBfAVXpOc5vMYztssUI7mL2/iYujiIXM+weZYNTEpLdjyJdu7R5gGUu +g1jSVK/EPHfc74O7AyZU34PNIP4Sh33N+/A5YexrNgJlPY+E3GdVYi4ldWJjgkAd +Qah2PH5ACLrIIC6tRka9hcaBlIECAwEAAaMgMB4wDAYDVR0TBAUwAwEB/zAOBgNV +HQ8BAf8EBAMCAgQwDQYJKoZIhvcNAQELBQADgYEAHzC7jdYlzAVmddi/gdAeKPau +sPBG/C2HCWqHzpCUHcKuvMzDVkY/MP2o6JIW2DBbY64bO/FceExhjcykgaYtCH/m +oIU63+CFOTtR7otyQAWHqXa7q4SbCDlG7DyRFxqG0txPtGvy12lgldA2+RgcigQG +Dfcog5wrJytaQ6UA0wE= +-----END CERTIFICATE----- diff --git a/src/python/grpcio_test/grpc_test/credentials/server1.key b/src/python/grpcio_test/grpc_test/credentials/server1.key new file mode 100755 index 0000000000..143a5b8765 --- /dev/null +++ b/src/python/grpcio_test/grpc_test/credentials/server1.key @@ -0,0 +1,16 @@ +-----BEGIN PRIVATE KEY----- +MIICdQIBADANBgkqhkiG9w0BAQEFAASCAl8wggJbAgEAAoGBAOHDFScoLCVJpYDD +M4HYtIdV6Ake/sMNaaKdODjDMsux/4tDydlumN+fm+AjPEK5GHhGn1BgzkWF+slf +3BxhrA/8dNsnunstVA7ZBgA/5qQxMfGAq4wHNVX77fBZOgp9VlSMVfyd9N8YwbBY +AckOeUQadTi2X1S6OgJXgQ0m3MWhAgMBAAECgYAn7qGnM2vbjJNBm0VZCkOkTIWm +V10okw7EPJrdL2mkre9NasghNXbE1y5zDshx5Nt3KsazKOxTT8d0Jwh/3KbaN+YY +tTCbKGW0pXDRBhwUHRcuRzScjli8Rih5UOCiZkhefUTcRb6xIhZJuQy71tjaSy0p +dHZRmYyBYO2YEQ8xoQJBAPrJPhMBkzmEYFtyIEqAxQ/o/A6E+E4w8i+KM7nQCK7q +K4JXzyXVAjLfyBZWHGM2uro/fjqPggGD6QH1qXCkI4MCQQDmdKeb2TrKRh5BY1LR +81aJGKcJ2XbcDu6wMZK4oqWbTX2KiYn9GB0woM6nSr/Y6iy1u145YzYxEV/iMwff +DJULAkB8B2MnyzOg0pNFJqBJuH29bKCcHa8gHJzqXhNO5lAlEbMK95p/P2Wi+4Hd +aiEIAF1BF326QJcvYKmwSmrORp85AkAlSNxRJ50OWrfMZnBgzVjDx3xG6KsFQVk2 +ol6VhqL6dFgKUORFUWBvnKSyhjJxurlPEahV6oo6+A+mPhFY8eUvAkAZQyTdupP3 +XEFQKctGz+9+gKkemDp7LBBMEMBXrGTLPhpEfcjv/7KPdnFHYmhYeBTBnuVmTVWe +F98XJ7tIFfJq +-----END PRIVATE KEY----- diff --git a/src/python/grpcio_test/grpc_test/credentials/server1.pem b/src/python/grpcio_test/grpc_test/credentials/server1.pem new file mode 100755 index 0000000000..8e582e571f --- /dev/null +++ b/src/python/grpcio_test/grpc_test/credentials/server1.pem @@ -0,0 +1,16 @@ +-----BEGIN CERTIFICATE----- +MIICmzCCAgSgAwIBAgIBAzANBgkqhkiG9w0BAQUFADBWMQswCQYDVQQGEwJBVTET +MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMQ8wDQYDVQQDDAZ0ZXN0Y2EwHhcNMTQwNzIyMDYwMDU3WhcNMjQwNzE5 +MDYwMDU3WjBkMQswCQYDVQQGEwJVUzERMA8GA1UECBMISWxsaW5vaXMxEDAOBgNV +BAcTB0NoaWNhZ28xFDASBgNVBAoTC0dvb2dsZSBJbmMuMRowGAYDVQQDFBEqLnRl +c3QuZ29vZ2xlLmNvbTCBnzANBgkqhkiG9w0BAQEFAAOBjQAwgYkCgYEA4cMVJygs +JUmlgMMzgdi0h1XoCR7+ww1pop04OMMyy7H/i0PJ2W6Y35+b4CM8QrkYeEafUGDO +RYX6yV/cHGGsD/x02ye6ey1UDtkGAD/mpDEx8YCrjAc1Vfvt8Fk6Cn1WVIxV/J30 +3xjBsFgByQ55RBp1OLZfVLo6AleBDSbcxaECAwEAAaNrMGkwCQYDVR0TBAIwADAL +BgNVHQ8EBAMCBeAwTwYDVR0RBEgwRoIQKi50ZXN0Lmdvb2dsZS5mcoIYd2F0ZXJ6 +b29pLnRlc3QuZ29vZ2xlLmJlghIqLnRlc3QueW91dHViZS5jb22HBMCoAQMwDQYJ +KoZIhvcNAQEFBQADgYEAM2Ii0LgTGbJ1j4oqX9bxVcxm+/R5Yf8oi0aZqTJlnLYS +wXcBykxTx181s7WyfJ49WwrYXo78zTDAnf1ma0fPq3e4mpspvyndLh1a+OarHa1e +aT0DIIYk7qeEa1YcVljx2KyLd0r1BBAfrwyGaEPVeJQVYWaOJRU2we/KD4ojf9s= +-----END CERTIFICATE----- diff --git a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py index 8804f3f223..b7dd5d4d17 100644 --- a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py +++ b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_blocking_invocation_inline_service.py @@ -73,6 +73,7 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): Overriding implementations must call this implementation. """ + self._invoker = None self.implementation.destantiate(self._memo) def testSuccessfulUnaryRequestUnaryResponse(self): diff --git a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_event_invocation_synchronous_event_service.py b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_event_invocation_synchronous_event_service.py index 5a78b4bed2..7cb273bf78 100644 --- a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_event_invocation_synchronous_event_service.py +++ b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_event_invocation_synchronous_event_service.py @@ -74,6 +74,7 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): Overriding implementations must call this implementation. """ + self._invoker = None self.implementation.destantiate(self._memo) def testSuccessfulUnaryRequestUnaryResponse(self): diff --git a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_future_invocation_asynchronous_event_service.py b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_future_invocation_asynchronous_event_service.py index d1107e1576..272a37f15f 100644 --- a/src/python/grpcio_test/grpc_test/framework/interfaces/face/_future_invocation_asynchronous_event_service.py +++ b/src/python/grpcio_test/grpc_test/framework/interfaces/face/_future_invocation_asynchronous_event_service.py @@ -103,6 +103,7 @@ class TestCase(test_coverage.Coverage, unittest.TestCase): Overriding implementations must call this implementation. """ + self._invoker = None self.implementation.destantiate(self._memo) self._digest_pool.shutdown(wait=True) diff --git a/src/python/grpcio_test/grpc_test/resources.py b/src/python/grpcio_test/grpc_test/resources.py new file mode 100644 index 0000000000..2c3045313d --- /dev/null +++ b/src/python/grpcio_test/grpc_test/resources.py @@ -0,0 +1,56 @@ +# Copyright 2015, Google Inc. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Constants and functions for data used in interoperability testing.""" + +import os + +import pkg_resources + +_ROOT_CERTIFICATES_RESOURCE_PATH = 'credentials/ca.pem' +_PRIVATE_KEY_RESOURCE_PATH = 'credentials/server1.key' +_CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem' + + +def test_root_certificates(): + return pkg_resources.resource_string( + __name__, _ROOT_CERTIFICATES_RESOURCE_PATH) + + +def prod_root_certificates(): + return open(os.environ['SSL_CERT_FILE'], mode='rb').read() + + +def private_key(): + return pkg_resources.resource_string(__name__, _PRIVATE_KEY_RESOURCE_PATH) + + +def certificate_chain(): + return pkg_resources.resource_string( + __name__, _CERTIFICATE_CHAIN_RESOURCE_PATH) diff --git a/src/python/grpcio_test/setup.py b/src/python/grpcio_test/setup.py index 898ea204ac..802dd1e53a 100644 --- a/src/python/grpcio_test/setup.py +++ b/src/python/grpcio_test/setup.py @@ -55,6 +55,11 @@ _PACKAGE_DATA = { 'grpc_protoc_plugin': [ 'test.proto', ], + 'grpc_test': [ + 'credentials/ca.pem', + 'credentials/server1.key', + 'credentials/server1.pem', + ], } _SETUP_REQUIRES = ( diff --git a/src/ruby/README.md b/src/ruby/README.md index f8902e34c5..7f75c0e313 100644 --- a/src/ruby/README.md +++ b/src/ruby/README.md @@ -19,10 +19,10 @@ INSTALLATION **Linux (Debian):** -Add [Debian unstable][] to your `sources.list` file. Example: +Add [Debian testing][] to your `sources.list` file. Example: ```sh -echo "deb http://ftp.us.debian.org/debian unstable main contrib non-free" | \ +echo "deb http://ftp.us.debian.org/debian testing main contrib non-free" | \ sudo tee -a /etc/apt/sources.list ``` @@ -99,4 +99,4 @@ Directory structure is the layout for [ruby extensions][] [ruby extensions]:http://guides.rubygems.org/gems-with-extensions/ [rubydoc]: http://www.rubydoc.info/gems/grpc [grpc.io]: http://www.grpc.io/docs/installation/ruby.html -[Debian unstable]:https://www.debian.org/releases/sid/ +[Debian testing]:https://www.debian.org/releases/stretch/ diff --git a/src/ruby/bin/math_client.rb b/src/ruby/bin/math_client.rb index 6319cda309..0ebd26f780 100755 --- a/src/ruby/bin/math_client.rb +++ b/src/ruby/bin/math_client.rb @@ -50,7 +50,7 @@ def do_div(stub) GRPC.logger.info('----------------') req = Math::DivArgs.new(dividend: 7, divisor: 3) GRPC.logger.info("div(7/3): req=#{req.inspect}") - resp = stub.div(req, INFINITE_FUTURE) + resp = stub.div(req, timeout: INFINITE_FUTURE) GRPC.logger.info("Answer: #{resp.inspect}") GRPC.logger.info('----------------') end @@ -71,7 +71,7 @@ def do_fib(stub) GRPC.logger.info('----------------') req = Math::FibArgs.new(limit: 11) GRPC.logger.info("fib(11): req=#{req.inspect}") - resp = stub.fib(req, INFINITE_FUTURE) + resp = stub.fib(req, timeout: INFINITE_FUTURE) resp.each do |r| GRPC.logger.info("Answer: #{r.inspect}") end @@ -86,7 +86,7 @@ def do_div_many(stub) reqs << Math::DivArgs.new(dividend: 5, divisor: 2) reqs << Math::DivArgs.new(dividend: 7, divisor: 2) GRPC.logger.info("div(7/3), div(5/2), div(7/2): reqs=#{reqs.inspect}") - resp = stub.div_many(reqs, 10) + resp = stub.div_many(reqs, timeout: INFINITE_FUTURE) resp.each do |r| GRPC.logger.info("Answer: #{r.inspect}") end diff --git a/src/ruby/bin/math_server.rb b/src/ruby/bin/math_server.rb index b41ccf6ce1..562f197317 100755 --- a/src/ruby/bin/math_server.rb +++ b/src/ruby/bin/math_server.rb @@ -41,9 +41,25 @@ $LOAD_PATH.unshift(this_dir) unless $LOAD_PATH.include?(this_dir) require 'forwardable' require 'grpc' +require 'logger' require 'math_services' require 'optparse' +# RubyLogger defines a logger for gRPC based on the standard ruby logger. +module RubyLogger + def logger + LOGGER + end + + LOGGER = Logger.new(STDOUT) +end + +# GRPC is the general RPC module +module GRPC + # Inject the noop #logger if no module-level logger method has been injected. + extend RubyLogger +end + # Holds state for a fibonacci series class Fibber def initialize(limit) @@ -155,7 +171,8 @@ end def test_server_creds certs = load_test_certs - GRPC::Core::ServerCredentials.new(nil, certs[1], certs[2]) + GRPC::Core::ServerCredentials.new( + nil, [{ private_key: certs[1], cert_chain: certs[2] }], false) end def main diff --git a/src/ruby/bin/noproto_server.rb b/src/ruby/bin/noproto_server.rb index 90baaf9a2e..72a5762040 100755 --- a/src/ruby/bin/noproto_server.rb +++ b/src/ruby/bin/noproto_server.rb @@ -77,7 +77,8 @@ end def test_server_creds certs = load_test_certs - GRPC::Core::ServerCredentials.new(nil, certs[1], certs[2]) + GRPC::Core::ServerCredentials.new( + nil, [{ private_key: certs[1], cert_chain: certs[2] }], false) end def main diff --git a/src/ruby/ext/grpc/rb_credentials.c b/src/ruby/ext/grpc/rb_credentials.c index ac3804df6f..ae757f6986 100644 --- a/src/ruby/ext/grpc/rb_credentials.c +++ b/src/ruby/ext/grpc/rb_credentials.c @@ -154,7 +154,7 @@ static VALUE grpc_rb_default_credentials_create(VALUE cls) { Creates the default credential instances. */ static VALUE grpc_rb_compute_engine_credentials_create(VALUE cls) { grpc_rb_credentials *wrapper = ALLOC(grpc_rb_credentials); - wrapper->wrapped = grpc_compute_engine_credentials_create(NULL); + wrapper->wrapped = grpc_google_compute_engine_credentials_create(NULL); if (wrapper->wrapped == NULL) { rb_raise(rb_eRuntimeError, "could not create composite engine credentials, not sure why"); diff --git a/src/ruby/ext/grpc/rb_server.c b/src/ruby/ext/grpc/rb_server.c index 7e76349d2e..4469658869 100644 --- a/src/ruby/ext/grpc/rb_server.c +++ b/src/ruby/ext/grpc/rb_server.c @@ -49,6 +49,9 @@ static VALUE grpc_rb_cServer = Qnil; /* id_at is the constructor method of the ruby standard Time class. */ static ID id_at; +/* id_insecure_server is used to indicate that a server is insecure */ +static VALUE id_insecure_server; + /* grpc_rb_server wraps a grpc_server. It provides a peer ruby object, 'mark' to minimize copying when a server is created from ruby. */ typedef struct grpc_rb_server { @@ -234,6 +237,7 @@ static VALUE grpc_rb_server_request_call(VALUE self, VALUE cqueue, grpc_call_error_detail_of(err), err); return Qnil; } + ev = grpc_rb_completion_queue_pluck_event(cqueue, tag_new, timeout); if (ev.type == GRPC_QUEUE_TIMEOUT) { grpc_request_call_stack_cleanup(&st); @@ -298,43 +302,22 @@ static VALUE grpc_rb_server_destroy(int argc, VALUE *argv, VALUE self) { if (s->wrapped != NULL) { grpc_server_shutdown_and_notify(s->wrapped, cq, NULL); ev = grpc_rb_completion_queue_pluck_event(cqueue, Qnil, timeout); - if (!ev.success) { - rb_warn("server shutdown failed, there will be a LEAKED object warning"); - return Qnil; - /* - TODO: renable the rb_raise below. - - At the moment if the timeout is INFINITE_FUTURE as recommended, the - pluck blocks forever, even though - - the outstanding server_request_calls correctly fail on the other - thread that they are running on. - - it's almost as if calls that fail on the other thread do not get - cleaned up by shutdown request, even though it caused htem to - terminate. - - rb_raise(rb_eRuntimeError, "grpc server shutdown did not succeed"); - return Qnil; - - The workaround is just to use a timeout and return without really - shutting down the server, and rely on the grpc core garbage collection - it down as a 'LEAKED OBJECT'. - - */ + rb_warn("server shutdown failed, cancelling the calls, objects may leak"); + grpc_server_cancel_all_calls(s->wrapped); + return Qfalse; } grpc_server_destroy(s->wrapped); s->wrapped = NULL; } - return Qnil; + return Qtrue; } /* call-seq: // insecure port insecure_server = Server.new(cq, {'arg1': 'value1'}) - insecure_server.add_http2_port('mydomain:50051') + insecure_server.add_http2_port('mydomain:50051', :this_port_is_insecure) // secure port server_creds = ... @@ -342,21 +325,22 @@ static VALUE grpc_rb_server_destroy(int argc, VALUE *argv, VALUE self) { secure_server.add_http_port('mydomain:50051', server_creds) Adds a http2 port to server */ -static VALUE grpc_rb_server_add_http2_port(int argc, VALUE *argv, VALUE self) { - VALUE port = Qnil; - VALUE rb_creds = Qnil; +static VALUE grpc_rb_server_add_http2_port(VALUE self, VALUE port, + VALUE rb_creds) { grpc_rb_server *s = NULL; grpc_server_credentials *creds = NULL; int recvd_port = 0; - /* "11" == 1 mandatory args, 1 (rb_creds) is optional */ - rb_scan_args(argc, argv, "11", &port, &rb_creds); - TypedData_Get_Struct(self, grpc_rb_server, &grpc_rb_server_data_type, s); if (s->wrapped == NULL) { rb_raise(rb_eRuntimeError, "destroyed!"); return Qnil; - } else if (rb_creds == Qnil) { + } else if (TYPE(rb_creds) == T_SYMBOL) { + if (id_insecure_server != SYM2ID(rb_creds)) { + rb_raise(rb_eTypeError, + "bad creds symbol, want :this_port_is_insecure"); + return Qnil; + } recvd_port = grpc_server_add_insecure_http2_port(s->wrapped, StringValueCStr(port)); if (recvd_port == 0) { @@ -398,8 +382,9 @@ void Init_grpc_server() { rb_define_alias(grpc_rb_cServer, "close", "destroy"); rb_define_method(grpc_rb_cServer, "add_http2_port", grpc_rb_server_add_http2_port, - -1); + 2); id_at = rb_intern("at"); + id_insecure_server = rb_intern("this_port_is_insecure"); } /* Gets the wrapped server from the ruby wrapper */ diff --git a/src/ruby/ext/grpc/rb_server_credentials.c b/src/ruby/ext/grpc/rb_server_credentials.c index 6af4c86c45..ea4d0d864e 100644 --- a/src/ruby/ext/grpc/rb_server_credentials.c +++ b/src/ruby/ext/grpc/rb_server_credentials.c @@ -135,63 +135,117 @@ static VALUE grpc_rb_server_credentials_init_copy(VALUE copy, VALUE orig) { return copy; } -/* The attribute used on the mark object to hold the pem_root_certs. */ +/* The attribute used on the mark object to preserve the pem_root_certs. */ static ID id_pem_root_certs; -/* The attribute used on the mark object to hold the pem_private_key. */ -static ID id_pem_private_key; +/* The attribute used on the mark object to preserve the pem_key_certs */ +static ID id_pem_key_certs; -/* The attribute used on the mark object to hold the pem_private_key. */ -static ID id_pem_cert_chain; +/* The key used to access the pem cert in a key_cert pair hash */ +static VALUE sym_cert_chain; + +/* The key used to access the pem private key in a key_cert pair hash */ +static VALUE sym_private_key; /* call-seq: - creds = ServerCredentials.new(pem_root_certs, pem_private_key, - pem_cert_chain) - creds = ServerCredentials.new(nil, pem_private_key, - pem_cert_chain) - - pem_root_certs: (required) PEM encoding of the server root certificate - pem_private_key: (optional) PEM encoding of the server's private key - pem_cert_chain: (optional) PEM encoding of the server's cert chain + creds = ServerCredentials.new(nil, + [{private_key: <pem_private_key1>, + {cert_chain: <pem_cert_chain1>}], + force_client_auth) + creds = ServerCredentials.new(pem_root_certs, + [{private_key: <pem_private_key1>, + {cert_chain: <pem_cert_chain1>}], + force_client_auth) + + pem_root_certs: (optional) PEM encoding of the server root certificate + pem_private_key: (required) PEM encoding of the server's private keys + force_client_auth: indicatees Initializes ServerCredential instances. */ static VALUE grpc_rb_server_credentials_init(VALUE self, VALUE pem_root_certs, - VALUE pem_private_key, - VALUE pem_cert_chain) { - /* TODO support multiple key cert pairs in the ruby API. */ + VALUE pem_key_certs, + VALUE force_client_auth) { grpc_rb_server_credentials *wrapper = NULL; grpc_server_credentials *creds = NULL; - grpc_ssl_pem_key_cert_pair key_cert_pair = {NULL, NULL}; - TypedData_Get_Struct(self, grpc_rb_server_credentials, - &grpc_rb_server_credentials_data_type, wrapper); - if (pem_cert_chain == Qnil) { - rb_raise(rb_eRuntimeError, - "could not create a server credential: nil pem_cert_chain"); + grpc_ssl_pem_key_cert_pair *key_cert_pairs = NULL; + VALUE cert = Qnil; + VALUE key = Qnil; + VALUE key_cert = Qnil; + int auth_client = 0; + int num_key_certs = 0; + int i; + + if (NIL_P(force_client_auth) || + !(force_client_auth == Qfalse || force_client_auth == Qtrue)) { + rb_raise(rb_eTypeError, + "bad force_client_auth: got:<%s> want: <True|False|nil>", + rb_obj_classname(force_client_auth)); return Qnil; - } else if (pem_private_key == Qnil) { - rb_raise(rb_eRuntimeError, - "could not create a server credential: nil pem_private_key"); + } + if (NIL_P(pem_key_certs) || TYPE(pem_key_certs) != T_ARRAY) { + rb_raise(rb_eTypeError, "bad pem_key_certs: got:<%s> want: <Array>", + rb_obj_classname(pem_key_certs)); + return Qnil; + } + num_key_certs = RARRAY_LEN(pem_key_certs); + if (num_key_certs == 0) { + rb_raise(rb_eTypeError, "bad pem_key_certs: it had no elements"); return Qnil; } - key_cert_pair.private_key = RSTRING_PTR(pem_private_key); - key_cert_pair.cert_chain = RSTRING_PTR(pem_cert_chain); - /* TODO Add a force_client_auth parameter and pass it here. */ + for (i = 0; i < num_key_certs; i++) { + key_cert = rb_ary_entry(pem_key_certs, i); + if (key_cert == Qnil) { + rb_raise(rb_eTypeError, + "could not create a server credential: nil key_cert"); + return Qnil; + } else if (TYPE(key_cert) != T_HASH) { + rb_raise(rb_eTypeError, + "could not create a server credential: want <Hash>, got <%s>", + rb_obj_classname(key_cert)); + return Qnil; + } else if (rb_hash_aref(key_cert, sym_private_key) == Qnil) { + rb_raise(rb_eTypeError, + "could not create a server credential: want nil private key"); + return Qnil; + } else if (rb_hash_aref(key_cert, sym_cert_chain) == Qnil) { + rb_raise(rb_eTypeError, + "could not create a server credential: want nil cert chain"); + return Qnil; + } + } + + auth_client = TYPE(force_client_auth) == T_TRUE; + key_cert_pairs = ALLOC_N(grpc_ssl_pem_key_cert_pair, num_key_certs); + for (i = 0; i < num_key_certs; i++) { + key_cert = rb_ary_entry(pem_key_certs, i); + key = rb_hash_aref(key_cert, sym_private_key); + cert = rb_hash_aref(key_cert, sym_cert_chain); + key_cert_pairs[i].private_key = RSTRING_PTR(key); + key_cert_pairs[i].cert_chain = RSTRING_PTR(cert); + } + + TypedData_Get_Struct(self, grpc_rb_server_credentials, + &grpc_rb_server_credentials_data_type, wrapper); + if (pem_root_certs == Qnil) { - creds = - grpc_ssl_server_credentials_create(NULL, &key_cert_pair, 1, 0, NULL); + creds = grpc_ssl_server_credentials_create(NULL, key_cert_pairs, + num_key_certs, + auth_client, NULL); } else { creds = grpc_ssl_server_credentials_create(RSTRING_PTR(pem_root_certs), - &key_cert_pair, 1, 0, NULL); + key_cert_pairs, num_key_certs, + auth_client, NULL); } + xfree(key_cert_pairs); if (creds == NULL) { rb_raise(rb_eRuntimeError, "could not create a credentials, not sure why"); + return Qnil; } wrapper->wrapped = creds; /* Add the input objects as hidden fields to preserve them. */ - rb_ivar_set(self, id_pem_cert_chain, pem_cert_chain); - rb_ivar_set(self, id_pem_private_key, pem_private_key); + rb_ivar_set(self, id_pem_key_certs, pem_key_certs); rb_ivar_set(self, id_pem_root_certs, pem_root_certs); return self; @@ -211,9 +265,10 @@ void Init_grpc_server_credentials() { rb_define_method(grpc_rb_cServerCredentials, "initialize_copy", grpc_rb_server_credentials_init_copy, 1); - id_pem_cert_chain = rb_intern("__pem_cert_chain"); - id_pem_private_key = rb_intern("__pem_private_key"); + id_pem_key_certs = rb_intern("__pem_key_certs"); id_pem_root_certs = rb_intern("__pem_root_certs"); + sym_private_key = ID2SYM(rb_intern("private_key")); + sym_cert_chain = ID2SYM(rb_intern("cert_chain")); } /* Gets the wrapped grpc_server_credentials from the ruby wrapper */ diff --git a/src/ruby/lib/grpc/generic/rpc_server.rb b/src/ruby/lib/grpc/generic/rpc_server.rb index 67bf35ce02..38ea333413 100644 --- a/src/ruby/lib/grpc/generic/rpc_server.rb +++ b/src/ruby/lib/grpc/generic/rpc_server.rb @@ -277,10 +277,11 @@ module GRPC @stop_mutex.synchronize do @stopped = true end - @pool.stop deadline = from_relative_time(@poll_period) - + return if @server.close(@cq, deadline) + deadline = from_relative_time(@poll_period) @server.close(@cq, deadline) + @pool.stop end # determines if the server has been stopped @@ -383,7 +384,6 @@ module GRPC @pool.start @server.start loop_handle_server_calls - @running = false end # Sends UNAVAILABLE if there are too many unprocessed jobs @@ -414,14 +414,13 @@ module GRPC fail 'not running' unless @running loop_tag = Object.new until stopped? - deadline = from_relative_time(@poll_period) begin - an_rpc = @server.request_call(@cq, loop_tag, deadline) + an_rpc = @server.request_call(@cq, loop_tag, INFINITE_FUTURE) c = new_active_server_call(an_rpc) rescue Core::CallError, RuntimeError => e # these might happen for various reasonse. The correct behaviour of - # the server is to log them and continue. - GRPC.logger.warn("server call failed: #{e}") + # the server is to log them and continue, if it's not shutting down. + GRPC.logger.warn("server call failed: #{e}") unless stopped? next end unless c.nil? @@ -431,6 +430,8 @@ module GRPC end end end + @running = false + GRPC.logger.info("stopped: #{self}") end def new_active_server_call(an_rpc) diff --git a/src/ruby/pb/test/server.rb b/src/ruby/pb/test/server.rb index e2e1ecbd62..a311bb76e6 100755 --- a/src/ruby/pb/test/server.rb +++ b/src/ruby/pb/test/server.rb @@ -64,7 +64,8 @@ end # creates a ServerCredentials from the test certificates. def test_server_creds certs = load_test_certs - GRPC::Core::ServerCredentials.new(nil, certs[1], certs[2]) + GRPC::Core::ServerCredentials.new( + nil, [{private_key: certs[1], cert_chain: certs[2]}], false) end # produces a string of null chars (\0) of length l. diff --git a/src/ruby/spec/client_server_spec.rb b/src/ruby/spec/client_server_spec.rb index 2e673ff413..387f2baec2 100644 --- a/src/ruby/spec/client_server_spec.rb +++ b/src/ruby/spec/client_server_spec.rb @@ -32,12 +32,6 @@ require 'spec_helper' include GRPC::Core -def load_test_certs - test_root = File.join(File.dirname(__FILE__), 'testdata') - files = ['ca.pem', 'server1.key', 'server1.pem'] - files.map { |f| File.open(File.join(test_root, f)).read } -end - shared_context 'setup: tags' do let(:sent_message) { 'sent message' } let(:reply_text) { 'the reply' } @@ -402,7 +396,7 @@ describe 'the http client/server' do @client_queue = GRPC::Core::CompletionQueue.new @server_queue = GRPC::Core::CompletionQueue.new @server = GRPC::Core::Server.new(@server_queue, nil) - server_port = @server.add_http2_port(server_host) + server_port = @server.add_http2_port(server_host, :this_port_is_insecure) @server.start @ch = Channel.new("0.0.0.0:#{server_port}", nil) end @@ -420,12 +414,19 @@ describe 'the http client/server' do end describe 'the secure http client/server' do + def load_test_certs + test_root = File.join(File.dirname(__FILE__), 'testdata') + files = ['ca.pem', 'server1.key', 'server1.pem'] + files.map { |f| File.open(File.join(test_root, f)).read } + end + before(:example) do certs = load_test_certs server_host = '0.0.0.0:0' @client_queue = GRPC::Core::CompletionQueue.new @server_queue = GRPC::Core::CompletionQueue.new - server_creds = GRPC::Core::ServerCredentials.new(nil, certs[1], certs[2]) + server_creds = GRPC::Core::ServerCredentials.new( + nil, [{ private_key: certs[1], cert_chain: certs[2] }], false) @server = GRPC::Core::Server.new(@server_queue, nil) server_port = @server.add_http2_port(server_host, server_creds) @server.start diff --git a/src/ruby/spec/credentials_spec.rb b/src/ruby/spec/credentials_spec.rb index 8e72e85d54..b02219dfdb 100644 --- a/src/ruby/spec/credentials_spec.rb +++ b/src/ruby/spec/credentials_spec.rb @@ -29,15 +29,15 @@ require 'grpc' -def load_test_certs - test_root = File.join(File.dirname(__FILE__), 'testdata') - files = ['ca.pem', 'server1.pem', 'server1.key'] - files.map { |f| File.open(File.join(test_root, f)).read } -end +describe GRPC::Core::Credentials do + Credentials = GRPC::Core::Credentials -Credentials = GRPC::Core::Credentials + def load_test_certs + test_root = File.join(File.dirname(__FILE__), 'testdata') + files = ['ca.pem', 'server1.pem', 'server1.key'] + files.map { |f| File.open(File.join(test_root, f)).read } + end -describe Credentials do describe '#new' do it 'can be constructed with fake inputs' do expect { Credentials.new('root_certs', 'key', 'cert') }.not_to raise_error diff --git a/src/ruby/spec/generic/active_call_spec.rb b/src/ruby/spec/generic/active_call_spec.rb index fcd7bd082f..b05e3284fe 100644 --- a/src/ruby/spec/generic/active_call_spec.rb +++ b/src/ruby/spec/generic/active_call_spec.rb @@ -46,7 +46,7 @@ describe GRPC::ActiveCall do @server_queue = GRPC::Core::CompletionQueue.new host = '0.0.0.0:0' @server = GRPC::Core::Server.new(@server_queue, nil) - server_port = @server.add_http2_port(host) + server_port = @server.add_http2_port(host, :this_port_is_insecure) @server.start @ch = GRPC::Core::Channel.new("0.0.0.0:#{server_port}", nil) end diff --git a/src/ruby/spec/generic/client_stub_spec.rb b/src/ruby/spec/generic/client_stub_spec.rb index edcc962a7d..a05433df75 100644 --- a/src/ruby/spec/generic/client_stub_spec.rb +++ b/src/ruby/spec/generic/client_stub_spec.rb @@ -498,7 +498,7 @@ describe 'ClientStub' do def create_test_server @server_queue = GRPC::Core::CompletionQueue.new @server = GRPC::Core::Server.new(@server_queue, nil) - @server.add_http2_port('0.0.0.0:0') + @server.add_http2_port('0.0.0.0:0', :this_port_is_insecure) end def expect_server_to_be_invoked(notifier) diff --git a/src/ruby/spec/generic/rpc_server_spec.rb b/src/ruby/spec/generic/rpc_server_spec.rb index 1295fd7fdd..e484a9ea50 100644 --- a/src/ruby/spec/generic/rpc_server_spec.rb +++ b/src/ruby/spec/generic/rpc_server_spec.rb @@ -139,7 +139,7 @@ describe GRPC::RpcServer do @server_queue = GRPC::Core::CompletionQueue.new server_host = '0.0.0.0:0' @server = GRPC::Core::Server.new(@server_queue, nil) - server_port = @server.add_http2_port(server_host) + server_port = @server.add_http2_port(server_host, :this_port_is_insecure) @host = "localhost:#{server_port}" @ch = GRPC::Core::Channel.new(@host, nil) end diff --git a/src/ruby/spec/pb/health/checker_spec.rb b/src/ruby/spec/pb/health/checker_spec.rb index 6999a69105..d7b7535cbe 100644 --- a/src/ruby/spec/pb/health/checker_spec.rb +++ b/src/ruby/spec/pb/health/checker_spec.rb @@ -186,7 +186,7 @@ describe Grpc::Health::Checker do @server_queue = GRPC::Core::CompletionQueue.new server_host = '0.0.0.0:0' @server = GRPC::Core::Server.new(@server_queue, nil) - server_port = @server.add_http2_port(server_host) + server_port = @server.add_http2_port(server_host, :this_port_is_insecure) @host = "localhost:#{server_port}" @ch = GRPC::Core::Channel.new(@host, nil) @client_opts = { channel_override: @ch } diff --git a/src/ruby/spec/server_credentials_spec.rb b/src/ruby/spec/server_credentials_spec.rb index 55598bc8df..8ae577009d 100644 --- a/src/ruby/spec/server_credentials_spec.rb +++ b/src/ruby/spec/server_credentials_spec.rb @@ -31,8 +31,9 @@ require 'grpc' def load_test_certs test_root = File.join(File.dirname(__FILE__), 'testdata') - files = ['ca.pem', 'server1.pem', 'server1.key'] - files.map { |f| File.open(File.join(test_root, f)).read } + files = ['ca.pem', 'server1.key', 'server1.pem'] + contents = files.map { |f| File.open(File.join(test_root, f)).read } + [contents[0], [{ private_key: contents[1], cert_chain: contents[2] }], false] end describe GRPC::Core::ServerCredentials do @@ -40,7 +41,8 @@ describe GRPC::Core::ServerCredentials do describe '#new' do it 'can be constructed from a fake CA PEM, server PEM and a server key' do - expect { Creds.new('a', 'b', 'c') }.not_to raise_error + creds = Creds.new('a', [{ private_key: 'a', cert_chain: 'b' }], false) + expect(creds).to_not be_nil end it 'can be constructed using the test certificates' do @@ -48,21 +50,44 @@ describe GRPC::Core::ServerCredentials do expect { Creds.new(*certs) }.not_to raise_error end + it 'cannot be constructed without a nil key_cert pair array' do + root_cert, _, _ = load_test_certs + blk = proc do + Creds.new(root_cert, nil, false) + end + expect(&blk).to raise_error + end + + it 'cannot be constructed without any key_cert pairs' do + root_cert, _, _ = load_test_certs + blk = proc do + Creds.new(root_cert, [], false) + end + expect(&blk).to raise_error + end + it 'cannot be constructed without a server cert chain' do root_cert, server_key, _ = load_test_certs - blk = proc { Creds.new(root_cert, server_key, nil) } + blk = proc do + Creds.new(root_cert, + [{ server_key: server_key, cert_chain: nil }], + false) + end expect(&blk).to raise_error end it 'cannot be constructed without a server key' do root_cert, _, _ = load_test_certs - blk = proc { Creds.new(root_cert, nil, cert_chain) } + blk = proc do + Creds.new(root_cert, + [{ server_key: nil, cert_chain: cert_chain }]) + end expect(&blk).to raise_error end it 'can be constructed without a root_cret' do - _, server_key, cert_chain = load_test_certs - blk = proc { Creds.new(nil, server_key, cert_chain) } + _, cert_pairs, _ = load_test_certs + blk = proc { Creds.new(nil, cert_pairs, false) } expect(&blk).to_not raise_error end end diff --git a/src/ruby/spec/server_spec.rb b/src/ruby/spec/server_spec.rb index 47fe575343..439b19fb8d 100644 --- a/src/ruby/spec/server_spec.rb +++ b/src/ruby/spec/server_spec.rb @@ -32,7 +32,8 @@ require 'grpc' def load_test_certs test_root = File.join(File.dirname(__FILE__), 'testdata') files = ['ca.pem', 'server1.key', 'server1.pem'] - files.map { |f| File.open(File.join(test_root, f)).read } + contents = files.map { |f| File.open(File.join(test_root, f)).read } + [contents[0], [{ private_key: contents[1], cert_chain: contents[2] }], false] end Server = GRPC::Core::Server @@ -104,7 +105,7 @@ describe Server do it 'runs without failing' do blk = proc do s = Server.new(@cq, nil) - s.add_http2_port('localhost:0') + s.add_http2_port('localhost:0', :this_port_is_insecure) s.close(@cq) end expect(&blk).to_not raise_error @@ -113,7 +114,10 @@ describe Server do it 'fails if the server is closed' do s = Server.new(@cq, nil) s.close(@cq) - expect { s.add_http2_port('localhost:0') }.to raise_error(RuntimeError) + blk = proc do + s.add_http2_port('localhost:0', :this_port_is_insecure) + end + expect(&blk).to raise_error(RuntimeError) end end @@ -198,7 +202,7 @@ describe Server do def start_a_server s = Server.new(@cq, nil) - s.add_http2_port('0.0.0.0:0') + s.add_http2_port('0.0.0.0:0', :this_port_is_insecure) s.start s end |