From db13f68dab0e01acb9c0186eec423a0419326fd4 Mon Sep 17 00:00:00 2001 From: Nathaniel Manista Date: Sat, 7 Mar 2015 00:18:13 +0000 Subject: Add a server_host_override to stub creation This optional value should only be passed in tests. --- src/python/src/grpc/_adapter/_c_test.py | 3 ++- src/python/src/grpc/_adapter/_channel.c | 28 +++++++++++++++++----- src/python/src/grpc/_adapter/rear.py | 23 +++++++++++++----- .../src/grpc/early_adopter/implementations.py | 7 ++++-- 4 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/python/src/grpc/_adapter/_c_test.py b/src/python/src/grpc/_adapter/_c_test.py index d81c63e346..e11835746b 100644 --- a/src/python/src/grpc/_adapter/_c_test.py +++ b/src/python/src/grpc/_adapter/_c_test.py @@ -70,7 +70,8 @@ class _CTest(unittest.TestCase): def testChannel(self): _c.init() - channel = _c.Channel('test host:12345', None) + channel = _c.Channel( + 'test host:12345', None, server_host_override='ignored') del channel _c.shut_down() diff --git a/src/python/src/grpc/_adapter/_channel.c b/src/python/src/grpc/_adapter/_channel.c index 9cf580bcfb..6be8f1c364 100644 --- a/src/python/src/grpc/_adapter/_channel.c +++ b/src/python/src/grpc/_adapter/_channel.c @@ -42,19 +42,35 @@ static int pygrpc_channel_init(Channel *self, PyObject *args, PyObject *kwds) { const char *hostport; PyObject *client_credentials; - static char *kwlist[] = {"hostport", "client_credentials", NULL}; + char *server_host_override = NULL; + static char *kwlist[] = {"hostport", "client_credentials", + "server_host_override", NULL}; + grpc_arg server_host_override_arg; + grpc_channel_args channel_args; - if (!(PyArg_ParseTupleAndKeywords(args, kwds, "sO:Channel", kwlist, - &hostport, &client_credentials))) { + if (!(PyArg_ParseTupleAndKeywords(args, kwds, "sO|z:Channel", kwlist, + &hostport, &client_credentials, + &server_host_override))) { return -1; } if (client_credentials == Py_None) { self->c_channel = grpc_channel_create(hostport, NULL); return 0; } else { - self->c_channel = grpc_secure_channel_create( - ((ClientCredentials *)client_credentials)->c_client_credentials, - hostport, NULL); + if (server_host_override == NULL) { + self->c_channel = grpc_secure_channel_create( + ((ClientCredentials *)client_credentials)->c_client_credentials, + hostport, NULL); + } else { + server_host_override_arg.type = GRPC_ARG_STRING; + server_host_override_arg.key = GRPC_SSL_TARGET_NAME_OVERRIDE_ARG; + server_host_override_arg.value.string = server_host_override; + channel_args.num_args = 1; + channel_args.args = &server_host_override_arg; + self->c_channel = grpc_secure_channel_create( + ((ClientCredentials *)client_credentials)->c_client_credentials, + hostport, &channel_args); + } return 0; } } diff --git a/src/python/src/grpc/_adapter/rear.py b/src/python/src/grpc/_adapter/rear.py index bfde5f5c57..fc71bf0a6c 100644 --- a/src/python/src/grpc/_adapter/rear.py +++ b/src/python/src/grpc/_adapter/rear.py @@ -93,7 +93,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): def __init__( self, host, port, pool, request_serializers, response_deserializers, - secure, root_certificates, private_key, certificate_chain): + secure, root_certificates, private_key, certificate_chain, + server_host_override=None): """Constructor. Args: @@ -111,6 +112,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): key should be used. certificate_chain: The PEM-encoded certificate chain to use or None if no certificate chain should be used. + server_host_override: (For testing only) the target name used for SSL + host name checking. """ self._condition = threading.Condition() self._host = host @@ -132,6 +135,7 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): self._root_certificates = root_certificates self._private_key = private_key self._certificate_chain = certificate_chain + self._server_host_override = server_host_override def _on_write_event(self, operation_id, event, rpc_state): if event.write_accepted: @@ -327,7 +331,8 @@ class RearLink(ticket_interfaces.RearLink, activated.Activated): with self._condition: self._completion_queue = _low.CompletionQueue() self._channel = _low.Channel( - '%s:%d' % (self._host, self._port), self._client_credentials) + '%s:%d' % (self._host, self._port), self._client_credentials, + server_host_override=self._server_host_override) return self def _stop(self): @@ -388,7 +393,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): def __init__( self, host, port, request_serializers, response_deserializers, secure, - root_certificates, private_key, certificate_chain): + root_certificates, private_key, certificate_chain, + server_host_override=None): self._host = host self._port = port self._request_serializers = request_serializers @@ -397,6 +403,7 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): self._root_certificates = root_certificates self._private_key = private_key self._certificate_chain = certificate_chain + self._server_host_override = server_host_override self._lock = threading.Lock() self._pool = None @@ -415,7 +422,8 @@ class _ActivatedRearLink(ticket_interfaces.RearLink, activated.Activated): self._rear_link = RearLink( self._host, self._port, self._pool, self._request_serializers, self._response_deserializers, self._secure, self._root_certificates, - self._private_key, self._certificate_chain) + self._private_key, self._certificate_chain, + server_host_override=self._server_host_override) self._rear_link.join_fore_link(self._fore_link) self._rear_link.start() return self @@ -477,7 +485,7 @@ def activated_rear_link( def secure_activated_rear_link( host, port, request_serializers, response_deserializers, root_certificates, - private_key, certificate_chain): + private_key, certificate_chain, server_host_override=None): """Creates a RearLink that is also an activated.Activated. The returned object is only valid for use between calls to its start and stop @@ -496,7 +504,10 @@ def secure_activated_rear_link( should be used. certificate_chain: The PEM-encoded certificate chain to use or None if no certificate chain should be used. + server_host_override: (For testing only) the target name used for SSL + host name checking. """ return _ActivatedRearLink( host, port, request_serializers, response_deserializers, True, - root_certificates, private_key, certificate_chain) + root_certificates, private_key, certificate_chain, + server_host_override=server_host_override) diff --git a/src/python/src/grpc/early_adopter/implementations.py b/src/python/src/grpc/early_adopter/implementations.py index 6195958624..87ea18d666 100644 --- a/src/python/src/grpc/early_adopter/implementations.py +++ b/src/python/src/grpc/early_adopter/implementations.py @@ -125,7 +125,8 @@ def insecure_stub(methods, host, port): def secure_stub( - methods, host, port, root_certificates, private_key, certificate_chain): + methods, host, port, root_certificates, private_key, certificate_chain, + server_host_override=None): """Constructs an insecure interfaces.Stub. Args: @@ -140,6 +141,8 @@ def secure_stub( should be used. certificate_chain: The PEM-encoded certificate chain to use or None if no certificate chain should be used. + server_host_override: (For testing only) the target name used for SSL + host name checking. Returns: An interfaces.Stub affording RPC invocation. @@ -148,7 +151,7 @@ def secure_stub( activated_rear_link = _rear.secure_activated_rear_link( host, port, breakdown.request_serializers, breakdown.response_deserializers, root_certificates, private_key, - certificate_chain) + certificate_chain, server_host_override=server_host_override) return _build_stub(breakdown, activated_rear_link) -- cgit v1.2.3