diff options
author | Nathaniel Manista <nathaniel@google.com> | 2015-02-15 01:18:41 +0000 |
---|---|---|
committer | Nathaniel Manista <nathaniel@google.com> | 2015-02-15 01:18:41 +0000 |
commit | 81337bb41f705a854568e4c5cdf213393c3500b4 (patch) | |
tree | ec8cc283cd6f4b364670608362730d919ad96247 | |
parent | a73f7464c266622f853f4bd91868e4de57e5025c (diff) |
Add security to fore.ForeLink.
-rw-r--r-- | src/python/src/_adapter/_face_test_case.py | 2 | ||||
-rw-r--r-- | src/python/src/_adapter/_links_test.py | 6 | ||||
-rw-r--r-- | src/python/src/_adapter/_low.py | 1 | ||||
-rw-r--r-- | src/python/src/_adapter/_server.c | 15 | ||||
-rw-r--r-- | src/python/src/_adapter/fore.py | 21 |
5 files changed, 37 insertions, 8 deletions
diff --git a/src/python/src/_adapter/_face_test_case.py b/src/python/src/_adapter/_face_test_case.py index 112dcfb928..2c6e6286b5 100644 --- a/src/python/src/_adapter/_face_test_case.py +++ b/src/python/src/_adapter/_face_test_case.py @@ -80,7 +80,7 @@ class FaceTestCase(test_case.FaceTestCase, coverage.BlockingCoverage): fore_link = fore.ForeLink( pool, serialization.request_deserializers, - serialization.response_serializers) + serialization.response_serializers, None, ()) port = fore_link.start() rear_link = rear.RearLink( 'localhost', port, pool, diff --git a/src/python/src/_adapter/_links_test.py b/src/python/src/_adapter/_links_test.py index 8341460a9a..d8bbb27127 100644 --- a/src/python/src/_adapter/_links_test.py +++ b/src/python/src/_adapter/_links_test.py @@ -67,7 +67,7 @@ class RoundTripTest(unittest.TestCase): test_rear_link = _test_links.RearLink(rear_action, None) fore_link = fore.ForeLink( - self.fore_link_pool, {test_method: None}, {test_method: None}) + self.fore_link_pool, {test_method: None}, {test_method: None}, None, ()) fore_link.join_rear_link(test_rear_link) test_rear_link.join_fore_link(fore_link) port = fore_link.start() @@ -120,7 +120,7 @@ class RoundTripTest(unittest.TestCase): fore_link = fore.ForeLink( self.fore_link_pool, {test_method: _IDENTITY}, - {test_method: _IDENTITY}) + {test_method: _IDENTITY}, None, ()) fore_link.join_rear_link(test_rear_link) test_rear_link.join_fore_link(fore_link) port = fore_link.start() @@ -182,7 +182,7 @@ class RoundTripTest(unittest.TestCase): fore_link = fore.ForeLink( self.fore_link_pool, {test_method: scenario.deserialize_request}, - {test_method: scenario.serialize_response}) + {test_method: scenario.serialize_response}, None, ()) fore_link.join_rear_link(test_rear_link) test_rear_link.join_fore_link(fore_link) port = fore_link.start() diff --git a/src/python/src/_adapter/_low.py b/src/python/src/_adapter/_low.py index 6c24087dad..09105eafa0 100644 --- a/src/python/src/_adapter/_low.py +++ b/src/python/src/_adapter/_low.py @@ -52,4 +52,5 @@ Call = _c.Call Channel = _c.Channel CompletionQueue = _c.CompletionQueue Server = _c.Server +ServerCredentials = _c.ServerCredentials # pylint: enable=invalid-name diff --git a/src/python/src/_adapter/_server.c b/src/python/src/_adapter/_server.c index 503be61ab4..2f8cc99e44 100644 --- a/src/python/src/_adapter/_server.c +++ b/src/python/src/_adapter/_server.c @@ -85,6 +85,19 @@ static PyObject *pygrpc_server_add_http2_addr(Server *self, PyObject *args) { return PyInt_FromLong(port); } +static PyObject *pygrpc_server_add_secure_http2_addr(Server *self, + PyObject *args) { + const char *addr; + int port; + PyArg_ParseTuple(args, "s", &addr); + port = grpc_server_add_secure_http2_port(self->c_server, addr); + if (port == 0) { + PyErr_SetString(PyExc_RuntimeError, "Couldn't add port to server!"); + return NULL; + } + return PyInt_FromLong(port); +} + static PyObject *pygrpc_server_start(Server *self) { grpc_server_start(self->c_server); @@ -118,6 +131,8 @@ static PyObject *pygrpc_server_stop(Server *self) { static PyMethodDef methods[] = { {"add_http2_addr", (PyCFunction)pygrpc_server_add_http2_addr, METH_VARARGS, "Add an HTTP2 address."}, + {"add_secure_http2_addr", (PyCFunction)pygrpc_server_add_secure_http2_addr, + METH_VARARGS, "Add a secure HTTP2 address."}, {"start", (PyCFunction)pygrpc_server_start, METH_NOARGS, "Starts the server."}, {"service", (PyCFunction)pygrpc_server_service, METH_VARARGS, diff --git a/src/python/src/_adapter/fore.py b/src/python/src/_adapter/fore.py index 2f102751f2..28aede1fd9 100644 --- a/src/python/src/_adapter/fore.py +++ b/src/python/src/_adapter/fore.py @@ -69,7 +69,8 @@ class ForeLink(ticket_interfaces.ForeLink): """A service-side bridge between RPC Framework and the C-ish _low code.""" def __init__( - self, pool, request_deserializers, response_serializers, port=None): + self, pool, request_deserializers, response_serializers, + root_certificates, key_chain_pairs, port=None): """Constructor. Args: @@ -78,6 +79,10 @@ class ForeLink(ticket_interfaces.ForeLink): deserializer behaviors. response_serializers: A dict from RPC method names to response object serializer behaviors. + root_certificates: The PEM-encoded client root certificates as a + bytestring or None. + key_chain_pairs: A sequence of PEM-encoded private key-certificate chain + pairs. port: The port on which to serve, or None to have a port selected automatically. """ @@ -85,6 +90,8 @@ class ForeLink(ticket_interfaces.ForeLink): self._pool = pool self._request_deserializers = request_deserializers self._response_serializers = response_serializers + self._root_certificates = root_certificates + self._key_chain_pairs = key_chain_pairs self._port = port self._rear_link = null.NULL_REAR_LINK @@ -264,10 +271,16 @@ class ForeLink(ticket_interfaces.ForeLink): object. """ with self._condition: + address = '[::]:%d' % (0 if self._port is None else self._port) self._completion_queue = _low.CompletionQueue() - self._server = _low.Server(self._completion_queue, None) - port = self._server.add_http2_addr( - '[::]:%d' % (0 if self._port is None else self._port)) + if self._root_certificates is None and not self._key_chain_pairs: + self._server = _low.Server(self._completion_queue, None) + port = self._server.add_http2_addr(address) + else: + server_credentials = _low.ServerCredentials( + self._root_certificates, self._key_chain_pairs) + self._server = _low.Server(self._completion_queue, server_credentials) + port = self._server.add_secure_http2_addr(address) self._server.start() self._server.service(None) |