diff options
Diffstat (limited to 'src/python/grpcio_tests/tests')
-rw-r--r-- | src/python/grpcio_tests/tests/_loader.py | 11 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/_result.py | 41 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/interop/methods.py | 4 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/interop/server.py | 7 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/tests.json | 1 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_auth_context_test.py | 45 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py | 13 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py | 164 | ||||
-rw-r--r-- | src/python/grpcio_tests/tests/unit/_session_cache_test.py | 145 |
9 files changed, 218 insertions, 213 deletions
diff --git a/src/python/grpcio_tests/tests/_loader.py b/src/python/grpcio_tests/tests/_loader.py index be0af64646..80c107aa8e 100644 --- a/src/python/grpcio_tests/tests/_loader.py +++ b/src/python/grpcio_tests/tests/_loader.py @@ -48,12 +48,13 @@ class Loader(object): # measure unnecessarily suffers) coverage_context = coverage.Coverage(data_suffix=True) coverage_context.start() - modules = [importlib.import_module(name) for name in names] - for module in modules: - self.visit_module(module) - for module in modules: + imported_modules = tuple( + importlib.import_module(name) for name in names) + for imported_module in imported_modules: + self.visit_module(imported_module) + for imported_module in imported_modules: try: - package_paths = module.__path__ + package_paths = imported_module.__path__ except AttributeError: continue self.walk_packages(package_paths) diff --git a/src/python/grpcio_tests/tests/_result.py b/src/python/grpcio_tests/tests/_result.py index b105f18e78..e5378b7ea3 100644 --- a/src/python/grpcio_tests/tests/_result.py +++ b/src/python/grpcio_tests/tests/_result.py @@ -144,10 +144,6 @@ class AugmentedResult(unittest.TestResult): super(AugmentedResult, self).startTestRun() self.cases = dict() - def stopTestRun(self): - """See unittest.TestResult.stopTestRun.""" - super(AugmentedResult, self).stopTestRun() - def startTest(self, test): """See unittest.TestResult.startTest.""" super(AugmentedResult, self).startTest(test) @@ -155,19 +151,19 @@ class AugmentedResult(unittest.TestResult): self.cases[case_id] = CaseResult( id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING) - def addError(self, test, error): + def addError(self, test, err): """See unittest.TestResult.addError.""" - super(AugmentedResult, self).addError(test, error) + super(AugmentedResult, self).addError(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.ERROR, traceback=error) + kind=CaseResult.Kind.ERROR, traceback=err) - def addFailure(self, test, error): + def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" - super(AugmentedResult, self).addFailure(test, error) + super(AugmentedResult, self).addFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.FAILURE, traceback=error) + kind=CaseResult.Kind.FAILURE, traceback=err) def addSuccess(self, test): """See unittest.TestResult.addSuccess.""" @@ -183,12 +179,12 @@ class AugmentedResult(unittest.TestResult): self.cases[case_id] = self.cases[case_id].updated( kind=CaseResult.Kind.SKIP, skip_reason=reason) - def addExpectedFailure(self, test, error): + def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" - super(AugmentedResult, self).addExpectedFailure(test, error) + super(AugmentedResult, self).addExpectedFailure(test, err) case_id = self.id_map(test) self.cases[case_id] = self.cases[case_id].updated( - kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=error) + kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=err) def addUnexpectedSuccess(self, test): """See unittest.TestResult.addUnexpectedSuccess.""" @@ -249,13 +245,6 @@ class CoverageResult(AugmentedResult): self.coverage_context.save() self.coverage_context = None - def stopTestRun(self): - """See unittest.TestResult.stopTestRun.""" - super(CoverageResult, self).stopTestRun() - # TODO(atash): Dig deeper into why the following line fails to properly - # combine coverage data from the Cython plugin. - #coverage.Coverage().combine() - class _Colors(object): """Namespaced constants for terminal color magic numbers.""" @@ -295,16 +284,16 @@ class TerminalResult(CoverageResult): self.out.write(summary(self)) self.out.flush() - def addError(self, test, error): + def addError(self, test, err): """See unittest.TestResult.addError.""" - super(TerminalResult, self).addError(test, error) + super(TerminalResult, self).addError(test, err) self.out.write( _Colors.FAIL + 'ERROR {}\n'.format(test.id()) + _Colors.END) self.out.flush() - def addFailure(self, test, error): + def addFailure(self, test, err): """See unittest.TestResult.addFailure.""" - super(TerminalResult, self).addFailure(test, error) + super(TerminalResult, self).addFailure(test, err) self.out.write( _Colors.FAIL + 'FAILURE {}\n'.format(test.id()) + _Colors.END) self.out.flush() @@ -323,9 +312,9 @@ class TerminalResult(CoverageResult): _Colors.INFO + 'SKIP {}\n'.format(test.id()) + _Colors.END) self.out.flush() - def addExpectedFailure(self, test, error): + def addExpectedFailure(self, test, err): """See unittest.TestResult.addExpectedFailure.""" - super(TerminalResult, self).addExpectedFailure(test, error) + super(TerminalResult, self).addExpectedFailure(test, err) self.out.write( _Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) + _Colors.END) self.out.flush() diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py index b728ffd704..cda15a68a3 100644 --- a/src/python/grpcio_tests/tests/interop/methods.py +++ b/src/python/grpcio_tests/tests/interop/methods.py @@ -144,8 +144,8 @@ def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope, def _empty_unary(stub): response = stub.EmptyCall(empty_pb2.Empty()) if not isinstance(response, empty_pb2.Empty): - raise TypeError('response is of type "%s", not empty_pb2.Empty!', - type(response)) + raise TypeError( + 'response is of type "%s", not empty_pb2.Empty!' % type(response)) def _large_unary(stub): diff --git a/src/python/grpcio_tests/tests/interop/server.py b/src/python/grpcio_tests/tests/interop/server.py index 0810de2394..fd28d498a1 100644 --- a/src/python/grpcio_tests/tests/interop/server.py +++ b/src/python/grpcio_tests/tests/interop/server.py @@ -26,6 +26,7 @@ from tests.interop import resources from tests.unit import test_common _ONE_DAY_IN_SECONDS = 60 * 60 * 24 +_LOGGER = logging.getLogger(__name__) def serve(): @@ -52,14 +53,14 @@ def serve(): server.add_insecure_port('[::]:{}'.format(args.port)) server.start() - logging.info('Server serving.') + _LOGGER.info('Server serving.') try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except BaseException as e: - logging.info('Caught exception "%s"; stopping server...', e) + _LOGGER.info('Caught exception "%s"; stopping server...', e) server.stop(None) - logging.info('Server stopped; exiting.') + _LOGGER.info('Server stopped; exiting.') if __name__ == '__main__': diff --git a/src/python/grpcio_tests/tests/tests.json b/src/python/grpcio_tests/tests/tests.json index 0d94426413..65460a9540 100644 --- a/src/python/grpcio_tests/tests/tests.json +++ b/src/python/grpcio_tests/tests/tests.json @@ -53,6 +53,7 @@ "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestCertConfigReuse", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithClientAuth", "unit._server_ssl_cert_config_test.ServerSSLCertReloadTestWithoutClientAuth", + "unit._session_cache_test.SSLSessionCacheTest", "unit.beta._beta_features_test.BetaFeaturesTest", "unit.beta._beta_features_test.ContextManagementAndLifecycleTest", "unit.beta._connectivity_channel_test.ConnectivityStatesTest", diff --git a/src/python/grpcio_tests/tests/unit/_auth_context_test.py b/src/python/grpcio_tests/tests/unit/_auth_context_test.py index 8c1a30e032..d174051070 100644 --- a/src/python/grpcio_tests/tests/unit/_auth_context_test.py +++ b/src/python/grpcio_tests/tests/unit/_auth_context_test.py @@ -18,6 +18,7 @@ import unittest import grpc from grpc import _channel +from grpc.experimental import session_cache import six from tests.unit import test_common @@ -140,6 +141,50 @@ class AuthContextTest(unittest.TestCase): self.assertSequenceEqual([b'*.test.google.com'], auth_ctx['x509_common_name']) + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel( + 'localhost:{}'.format(port), channel_creds, options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSessionResumption(self): + # Set up a secure server + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + # Create a cache for TLS session tickets + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'false']) + + # Subsequent connections resume sessions + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port, + expect_ssl_session_reused=[b'true']) + server.stop(None) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py deleted file mode 100644 index 5fb4f3c3cf..0000000000 --- a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py deleted file mode 100644 index 2bf1e1cc0d..0000000000 --- a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2015 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(nathaniel): Remove this from source control after having made -# generation from the stock.proto source part of GRPC's build-and-test -# process. - -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: stock.proto - -import sys -_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1')) -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -from google.protobuf import descriptor_pb2 -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - -DESCRIPTOR = _descriptor.FileDescriptor( - name='stock.proto', - package='stock', - serialized_pb=_b( - '\n\x0bstock.proto\x12\x05stock\">\n\x0cStockRequest\x12\x0e\n\x06symbol\x18\x01 \x01(\t\x12\x1e\n\x13num_trades_to_watch\x18\x02 \x01(\x05:\x01\x30\"+\n\nStockReply\x12\r\n\x05price\x18\x01 \x01(\x02\x12\x0e\n\x06symbol\x18\x02 \x01(\t2\x96\x02\n\x05Stock\x12=\n\x11GetLastTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x12I\n\x19GetLastTradePriceMultiple\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01\x30\x01\x12?\n\x11WatchFutureTrades\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x30\x01\x12\x42\n\x14GetHighestTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01' - )) -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -_STOCKREQUEST = _descriptor.Descriptor( - name='StockRequest', - full_name='stock.StockRequest', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='symbol', - full_name='stock.StockRequest.symbol', - index=0, - number=1, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='num_trades_to_watch', - full_name='stock.StockRequest.num_trades_to_watch', - index=1, - number=2, - type=5, - cpp_type=1, - label=1, - has_default_value=True, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - extension_ranges=[], - oneofs=[], - serialized_start=22, - serialized_end=84,) - -_STOCKREPLY = _descriptor.Descriptor( - name='StockReply', - full_name='stock.StockReply', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='price', - full_name='stock.StockReply.price', - index=0, - number=1, - type=2, - cpp_type=6, - label=1, - has_default_value=False, - default_value=0, - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - _descriptor.FieldDescriptor( - name='symbol', - full_name='stock.StockReply.symbol', - index=1, - number=2, - type=9, - cpp_type=9, - label=1, - has_default_value=False, - default_value=_b("").decode('utf-8'), - message_type=None, - enum_type=None, - containing_type=None, - is_extension=False, - extension_scope=None, - options=None), - ], - extensions=[], - nested_types=[], - enum_types=[], - options=None, - is_extendable=False, - extension_ranges=[], - oneofs=[], - serialized_start=86, - serialized_end=129,) - -DESCRIPTOR.message_types_by_name['StockRequest'] = _STOCKREQUEST -DESCRIPTOR.message_types_by_name['StockReply'] = _STOCKREPLY - -StockRequest = _reflection.GeneratedProtocolMessageType( - 'StockRequest', - (_message.Message,), - dict( - DESCRIPTOR=_STOCKREQUEST, - __module__='stock_pb2' - # @@protoc_insertion_point(class_scope:stock.StockRequest) - )) -_sym_db.RegisterMessage(StockRequest) - -StockReply = _reflection.GeneratedProtocolMessageType( - 'StockReply', - (_message.Message,), - dict( - DESCRIPTOR=_STOCKREPLY, - __module__='stock_pb2' - # @@protoc_insertion_point(class_scope:stock.StockReply) - )) -_sym_db.RegisterMessage(StockReply) - -# @@protoc_insertion_point(module_scope) diff --git a/src/python/grpcio_tests/tests/unit/_session_cache_test.py b/src/python/grpcio_tests/tests/unit/_session_cache_test.py new file mode 100644 index 0000000000..b4e4670fa7 --- /dev/null +++ b/src/python/grpcio_tests/tests/unit/_session_cache_test.py @@ -0,0 +1,145 @@ +# Copyright 2018 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests experimental TLS Session Resumption API""" + +import pickle +import unittest + +import grpc +from grpc import _channel +from grpc.experimental import session_cache + +from tests.unit import test_common +from tests.unit import resources + +_REQUEST = b'\x00\x00\x00' +_RESPONSE = b'\x00\x00\x00' + +_UNARY_UNARY = '/test/UnaryUnary' + +_SERVER_HOST_OVERRIDE = 'foo.test.google.fr' +_ID = 'id' +_ID_KEY = 'id_key' +_AUTH_CTX = 'auth_ctx' + +_PRIVATE_KEY = resources.private_key() +_CERTIFICATE_CHAIN = resources.certificate_chain() +_TEST_ROOT_CERTIFICATES = resources.test_root_certificates() +_SERVER_CERTS = ((_PRIVATE_KEY, _CERTIFICATE_CHAIN),) +_PROPERTY_OPTIONS = (( + 'grpc.ssl_target_name_override', + _SERVER_HOST_OVERRIDE, +),) + + +def handle_unary_unary(request, servicer_context): + return pickle.dumps({ + _ID: servicer_context.peer_identities(), + _ID_KEY: servicer_context.peer_identity_key(), + _AUTH_CTX: servicer_context.auth_context() + }) + + +def start_secure_server(): + handler = grpc.method_handlers_generic_handler('test', { + 'UnaryUnary': + grpc.unary_unary_rpc_method_handler(handle_unary_unary) + }) + server = test_common.test_server() + server.add_generic_rpc_handlers((handler,)) + server_cred = grpc.ssl_server_credentials(_SERVER_CERTS) + port = server.add_secure_port('[::]:0', server_cred) + server.start() + + return server, port + + +class SSLSessionCacheTest(unittest.TestCase): + + def _do_one_shot_client_rpc(self, channel_creds, channel_options, port, + expect_ssl_session_reused): + channel = grpc.secure_channel( + 'localhost:{}'.format(port), channel_creds, options=channel_options) + response = channel.unary_unary(_UNARY_UNARY)(_REQUEST) + auth_data = pickle.loads(response) + self.assertEqual(expect_ssl_session_reused, + auth_data[_AUTH_CTX]['ssl_session_reused']) + channel.close() + + def testSSLSessionCacheLRU(self): + server_1, port_1 = start_secure_server() + + cache = session_cache.ssl_session_cache_lru(1) + channel_creds = grpc.ssl_channel_credentials( + root_certificates=_TEST_ROOT_CERTIFICATES) + channel_options = _PROPERTY_OPTIONS + ( + ('grpc.ssl_session_cache', cache),) + + # Initial connection has no session to resume + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Connection to server_1 resumes from initial session + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + + # Connection to a different server with the same name overwrites the cache entry + server_2, port_2 = start_secure_server() + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'false']) + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_2, + expect_ssl_session_reused=[b'true']) + server_2.stop(None) + + # Connection to server_1 now falls back to full TLS handshake + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Re-creating server_1 causes old sessions to become invalid + server_1.stop(None) + server_1, port_1 = start_secure_server() + + # Old sessions should no longer be valid + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'false']) + + # Resumption should work for subsequent connections + self._do_one_shot_client_rpc( + channel_creds, + channel_options, + port_1, + expect_ssl_session_reused=[b'true']) + server_1.stop(None) + + +if __name__ == '__main__': + unittest.main(verbosity=2) |