author    Masood Malekghassemi <atash@google.com>  2017-01-13 19:20:10 -0800
committer Masood Malekghassemi <atash@google.com>  2017-01-17 10:55:33 -0800
commit    cc793703bfba6f661f523b6fec82ff8a913e1759 (patch)
tree      f3cb0c7330565e9ed9947a07c6423f81e5c00f72 /src/python/grpcio_tests
parent    06dea573daa2175b244a430bb89b49bb5c8e8c5b (diff)
Run Python formatting
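
As a hedged illustration only (the commit message records no command): a tree-wide formatting pass like this one is typically applied with a formatter such as yapf. The tool, the style name, and the target directory below are assumptions, not taken from the commit.

# Hypothetical sketch of reproducing a formatting pass with yapf; the style
# and the target path are assumptions, not recorded in this commit.
import subprocess

subprocess.check_call([
    'yapf',
    '--in-place',    # rewrite files rather than printing a diff
    '--recursive',   # descend into the package tree
    '--style', 'google',
    'src/python/grpcio_tests',
])
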
Diffstat (limited to 'src/python/grpcio_tests')
-rw-r--r-- src/python/grpcio_tests/commands.py | 266
-rw-r--r-- src/python/grpcio_tests/setup.py | 42
-rw-r--r-- src/python/grpcio_tests/tests/_loader.py | 89
-rw-r--r-- src/python/grpcio_tests/tests/_result.py | 618
-rw-r--r-- src/python/grpcio_tests/tests/_runner.py | 273
-rw-r--r-- src/python/grpcio_tests/tests/health_check/_health_servicer_test.py | 87
-rw-r--r-- src/python/grpcio_tests/tests/http2/_negative_http2_client.py | 175
-rw-r--r-- src/python/grpcio_tests/tests/interop/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py | 24
-rw-r--r-- src/python/grpcio_tests/tests/interop/_intraop_test_case.py | 38
-rw-r--r-- src/python/grpcio_tests/tests/interop/_secure_intraop_test.py | 35
-rw-r--r-- src/python/grpcio_tests/tests/interop/client.py | 163
-rw-r--r-- src/python/grpcio_tests/tests/interop/methods.py | 759
-rw-r--r-- src/python/grpcio_tests/tests/interop/resources.py | 21
-rw-r--r-- src/python/grpcio_tests/tests/interop/server.py | 57
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py | 754
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py | 446
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py | 716
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/qps/benchmark_client.py | 280
-rw-r--r-- src/python/grpcio_tests/tests/qps/benchmark_server.py | 32
-rw-r--r-- src/python/grpcio_tests/tests/qps/client_runner.py | 107
-rw-r--r-- src/python/grpcio_tests/tests/qps/histogram.py | 82
-rw-r--r-- src/python/grpcio_tests/tests/qps/qps_worker.py | 34
-rw-r--r-- src/python/grpcio_tests/tests/qps/worker_server.py | 297
-rw-r--r-- src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py | 234
-rw-r--r-- src/python/grpcio_tests/tests/stress/client.py | 217
-rw-r--r-- src/python/grpcio_tests/tests/stress/metrics_server.py | 41
-rw-r--r-- src/python/grpcio_tests/tests/stress/test_runner.py | 59
-rw-r--r-- src/python/grpcio_tests/tests/unit/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/_api_test.py | 117
-rw-r--r-- src/python/grpcio_tests/tests/unit/_auth_test.py | 69
-rw-r--r-- src/python/grpcio_tests/tests/unit/_channel_args_test.py | 20
-rw-r--r-- src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py | 209
-rw-r--r-- src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py | 108
-rw-r--r-- src/python/grpcio_tests/tests/unit/_compression_test.py | 147
-rw-r--r-- src/python/grpcio_tests/tests/unit/_credentials_test.py | 56
-rw-r--r-- src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py | 302
-rw-r--r-- src/python/grpcio_tests/tests/unit/_cython/_channel_test.py | 55
-rw-r--r-- src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py | 402
-rw-r--r-- src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py | 710
-rw-r--r-- src/python/grpcio_tests/tests/unit/_cython/test_utilities.py | 45
-rw-r--r-- src/python/grpcio_tests/tests/unit/_empty_message_test.py | 125
-rw-r--r-- src/python/grpcio_tests/tests/unit/_exit_scenarios.py | 269
-rw-r--r-- src/python/grpcio_tests/tests/unit/_exit_test.py | 243
-rw-r--r-- src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py | 223
-rw-r--r-- src/python/grpcio_tests/tests/unit/_invocation_defects_test.py | 316
-rw-r--r-- src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py | 209
-rw-r--r-- src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py | 901
-rw-r--r-- src/python/grpcio_tests/tests/unit/_metadata_test.py | 254
-rw-r--r-- src/python/grpcio_tests/tests/unit/_rpc_test.py | 1414
-rw-r--r-- src/python/grpcio_tests/tests/unit/_sanity/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/_sanity/_sanity_test.py | 32
-rw-r--r-- src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py | 143
-rw-r--r-- src/python/grpcio_tests/tests/unit/_thread_pool.py | 25
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py | 528
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py | 15
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py | 160
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_implementations_test.py | 44
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_not_found_test.py | 58
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/_utilities_test.py | 113
-rw-r--r-- src/python/grpcio_tests/tests/unit/beta/test_utilities.py | 17
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/common/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/common/test_constants.py | 1
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/common/test_control.py | 97
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py | 117
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py | 64
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py | 55
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_3069_test_constant.py | 1
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/__init__.py | 2
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py | 465
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py | 550
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py | 876
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py | 187
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_service.py | 159
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py | 519
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py | 27
-rw-r--r-- src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_interfaces.py | 144
-rw-r--r-- src/python/grpcio_tests/tests/unit/resources.py | 11
-rw-r--r-- src/python/grpcio_tests/tests/unit/test_common.py | 69
94 files changed, 8441 insertions, 8216 deletions
diff --git a/src/python/grpcio_tests/commands.py b/src/python/grpcio_tests/commands.py
index e822971fe0..845b7f598c 100644
--- a/src/python/grpcio_tests/commands.py
+++ b/src/python/grpcio_tests/commands.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Provides distutils command classes for the gRPC Python setup process."""
import distutils
@@ -55,163 +54,162 @@ PYTHON_PROTO_TOP_LEVEL = os.path.join(PYTHON_STEM, 'src')
class CommandError(object):
- pass
+ pass
class GatherProto(setuptools.Command):
- description = 'gather proto dependencies'
- user_options = []
+ description = 'gather proto dependencies'
+ user_options = []
- def initialize_options(self):
- pass
+ def initialize_options(self):
+ pass
- def finalize_options(self):
- pass
+ def finalize_options(self):
+ pass
- def run(self):
- # TODO(atash) ensure that we're running from the repository directory when
- # this command is used
- try:
- shutil.rmtree(PROTO_STEM)
- except Exception as error:
- # We don't care if this command fails
- pass
- shutil.copytree(GRPC_PROTO_STEM, PROTO_STEM)
- for root, _, _ in os.walk(PYTHON_PROTO_TOP_LEVEL):
- path = os.path.join(root, '__init__.py')
- open(path, 'a').close()
+ def run(self):
+ # TODO(atash) ensure that we're running from the repository directory when
+ # this command is used
+ try:
+ shutil.rmtree(PROTO_STEM)
+ except Exception as error:
+ # We don't care if this command fails
+ pass
+ shutil.copytree(GRPC_PROTO_STEM, PROTO_STEM)
+ for root, _, _ in os.walk(PYTHON_PROTO_TOP_LEVEL):
+ path = os.path.join(root, '__init__.py')
+ open(path, 'a').close()
class BuildProtoModules(setuptools.Command):
- """Command to generate project *_pb2.py modules from proto files."""
-
- description = 'build protobuf modules'
- user_options = [
- ('include=', None, 'path patterns to include in protobuf generation'),
- ('exclude=', None, 'path patterns to exclude from protobuf generation')
- ]
-
- def initialize_options(self):
- self.exclude = None
- self.include = r'.*\.proto$'
-
- def finalize_options(self):
- pass
-
- def run(self):
- import grpc_tools.protoc as protoc
-
- include_regex = re.compile(self.include)
- exclude_regex = re.compile(self.exclude) if self.exclude else None
- paths = []
- for walk_root, directories, filenames in os.walk(PROTO_STEM):
- for filename in filenames:
- path = os.path.join(walk_root, filename)
- if include_regex.match(path) and not (
- exclude_regex and exclude_regex.match(path)):
- paths.append(path)
-
- # TODO(kpayson): It would be nice to do this in a batch command,
- # but we currently have name conflicts in src/proto
- for path in paths:
- command = [
- 'grpc_tools.protoc',
- '-I {}'.format(PROTO_STEM),
- '--python_out={}'.format(PROTO_STEM),
- '--grpc_python_out={}'.format(PROTO_STEM),
- ] + [path]
- if protoc.main(command) != 0:
- sys.stderr.write(
- 'warning: Command:\n{}\nFailed'.format(
- command))
-
- # Generated proto directories dont include __init__.py, but
- # these are needed for python package resolution
- for walk_root, _, _ in os.walk(PROTO_STEM):
- path = os.path.join(walk_root, '__init__.py')
- open(path, 'a').close()
+ """Command to generate project *_pb2.py modules from proto files."""
+
+ description = 'build protobuf modules'
+ user_options = [
+ ('include=', None, 'path patterns to include in protobuf generation'),
+ ('exclude=', None, 'path patterns to exclude from protobuf generation')
+ ]
+
+ def initialize_options(self):
+ self.exclude = None
+ self.include = r'.*\.proto$'
+
+ def finalize_options(self):
+ pass
+
+ def run(self):
+ import grpc_tools.protoc as protoc
+
+ include_regex = re.compile(self.include)
+ exclude_regex = re.compile(self.exclude) if self.exclude else None
+ paths = []
+ for walk_root, directories, filenames in os.walk(PROTO_STEM):
+ for filename in filenames:
+ path = os.path.join(walk_root, filename)
+ if include_regex.match(path) and not (
+ exclude_regex and exclude_regex.match(path)):
+ paths.append(path)
+
+ # TODO(kpayson): It would be nice to do this in a batch command,
+ # but we currently have name conflicts in src/proto
+ for path in paths:
+ command = [
+ 'grpc_tools.protoc',
+ '-I {}'.format(PROTO_STEM),
+ '--python_out={}'.format(PROTO_STEM),
+ '--grpc_python_out={}'.format(PROTO_STEM),
+ ] + [path]
+ if protoc.main(command) != 0:
+ sys.stderr.write('warning: Command:\n{}\nFailed'.format(
+ command))
+
+ # Generated proto directories dont include __init__.py, but
+ # these are needed for python package resolution
+ for walk_root, _, _ in os.walk(PROTO_STEM):
+ path = os.path.join(walk_root, '__init__.py')
+ open(path, 'a').close()
class BuildPy(build_py.build_py):
- """Custom project build command."""
+ """Custom project build command."""
- def run(self):
- try:
- self.run_command('build_package_protos')
- except CommandError as error:
- sys.stderr.write('warning: %s\n' % error.message)
- build_py.build_py.run(self)
+ def run(self):
+ try:
+ self.run_command('build_package_protos')
+ except CommandError as error:
+ sys.stderr.write('warning: %s\n' % error.message)
+ build_py.build_py.run(self)
class TestLite(setuptools.Command):
- """Command to run tests without fetching or building anything."""
+ """Command to run tests without fetching or building anything."""
- description = 'run tests without fetching or building anything.'
- user_options = []
+ description = 'run tests without fetching or building anything.'
+ user_options = []
- def initialize_options(self):
- pass
+ def initialize_options(self):
+ pass
- def finalize_options(self):
- # distutils requires this override.
- pass
+ def finalize_options(self):
+ # distutils requires this override.
+ pass
- def run(self):
- self._add_eggs_to_path()
+ def run(self):
+ self._add_eggs_to_path()
- import tests
- loader = tests.Loader()
- loader.loadTestsFromNames(['tests'])
- runner = tests.Runner()
- result = runner.run(loader.suite)
- if not result.wasSuccessful():
- sys.exit('Test failure')
+ import tests
+ loader = tests.Loader()
+ loader.loadTestsFromNames(['tests'])
+ runner = tests.Runner()
+ result = runner.run(loader.suite)
+ if not result.wasSuccessful():
+ sys.exit('Test failure')
- def _add_eggs_to_path(self):
- """Fetch install and test requirements"""
- self.distribution.fetch_build_eggs(self.distribution.install_requires)
- self.distribution.fetch_build_eggs(self.distribution.tests_require)
+ def _add_eggs_to_path(self):
+ """Fetch install and test requirements"""
+ self.distribution.fetch_build_eggs(self.distribution.install_requires)
+ self.distribution.fetch_build_eggs(self.distribution.tests_require)
class RunInterop(test.test):
- description = 'run interop test client/server'
- user_options = [
- ('args=', 'a', 'pass-thru arguments for the client/server'),
- ('client', 'c', 'flag indicating to run the client'),
- ('server', 's', 'flag indicating to run the server')
- ]
-
- def initialize_options(self):
- self.args = ''
- self.client = False
- self.server = False
-
- def finalize_options(self):
- if self.client and self.server:
- raise DistutilsOptionError('you may only specify one of client or server')
-
- def run(self):
- if self.distribution.install_requires:
- self.distribution.fetch_build_eggs(self.distribution.install_requires)
- if self.distribution.tests_require:
- self.distribution.fetch_build_eggs(self.distribution.tests_require)
- if self.client:
- self.run_client()
- elif self.server:
- self.run_server()
-
- def run_server(self):
- # We import here to ensure that our setuptools parent has had a chance to
- # edit the Python system path.
- from tests.interop import server
- sys.argv[1:] = self.args.split()
- server.serve()
-
- def run_client(self):
- # We import here to ensure that our setuptools parent has had a chance to
- # edit the Python system path.
- from tests.interop import client
- sys.argv[1:] = self.args.split()
- client.test_interoperability()
+ description = 'run interop test client/server'
+ user_options = [('args=', 'a', 'pass-thru arguments for the client/server'),
+ ('client', 'c', 'flag indicating to run the client'),
+ ('server', 's', 'flag indicating to run the server')]
+
+ def initialize_options(self):
+ self.args = ''
+ self.client = False
+ self.server = False
+
+ def finalize_options(self):
+ if self.client and self.server:
+ raise DistutilsOptionError(
+ 'you may only specify one of client or server')
+
+ def run(self):
+ if self.distribution.install_requires:
+ self.distribution.fetch_build_eggs(
+ self.distribution.install_requires)
+ if self.distribution.tests_require:
+ self.distribution.fetch_build_eggs(self.distribution.tests_require)
+ if self.client:
+ self.run_client()
+ elif self.server:
+ self.run_server()
+
+ def run_server(self):
+ # We import here to ensure that our setuptools parent has had a chance to
+ # edit the Python system path.
+ from tests.interop import server
+ sys.argv[1:] = self.args.split()
+ server.serve()
+
+ def run_client(self):
+ # We import here to ensure that our setuptools parent has had a chance to
+ # edit the Python system path.
+ from tests.interop import client
+ sys.argv[1:] = self.args.split()
+ client.test_interoperability()
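
For readers unfamiliar with the pattern being reformatted above: GatherProto, BuildProtoModules, TestLite, and RunInterop all follow the standard setuptools/distutils command protocol (initialize_options / finalize_options / run). A minimal, self-contained sketch of that skeleton follows; the class and command name are illustrative, not part of the gRPC build.

# Minimal sketch of the setuptools command skeleton used in commands.py;
# the class name and behavior are illustrative only.
import setuptools


class ShowVersion(setuptools.Command):
    """Print the distribution version (demonstrates the command protocol)."""

    description = 'print the package version'
    user_options = []  # no command-line options for this command

    def initialize_options(self):
        pass  # distutils requires this override

    def finalize_options(self):
        pass  # distutils requires this override

    def run(self):
        print(self.distribution.get_version())
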
diff --git a/src/python/grpcio_tests/setup.py b/src/python/grpcio_tests/setup.py
index 375fbd6c77..f0407d1a55 100644
--- a/src/python/grpcio_tests/setup.py
+++ b/src/python/grpcio_tests/setup.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""A setup module for the gRPC Python package."""
import os
@@ -48,9 +47,7 @@ import grpc_version
LICENSE = '3-clause BSD'
-PACKAGE_DIRECTORIES = {
- '': '.',
-}
+PACKAGE_DIRECTORIES = {'': '.',}
INSTALL_REQUIRES = (
'coverage>=4.0',
@@ -61,13 +58,11 @@ INSTALL_REQUIRES = (
'grpcio-health-checking>={version}'.format(version=grpc_version.VERSION),
'oauth2client>=1.4.7',
'protobuf>=3.0.0',
- 'six>=1.10',
-)
+ 'six>=1.10',)
COMMAND_CLASS = {
# Run `preprocess` *before* doing any packaging!
'preprocess': commands.GatherProto,
-
'build_package_protos': grpc_tools.command.BuildPackageProtos,
'build_py': commands.BuildPy,
'run_interop': commands.RunInterop,
@@ -80,9 +75,7 @@ PACKAGE_DATA = {
'credentials/server1.key',
'credentials/server1.pem',
],
- 'tests.protoc_plugin.protos.invocation_testing': [
- 'same.proto',
- ],
+ 'tests.protoc_plugin.protos.invocation_testing': ['same.proto',],
'tests.protoc_plugin.protos.invocation_testing.split_messages': [
'messages.proto',
],
@@ -94,9 +87,7 @@ PACKAGE_DATA = {
'credentials/server1.key',
'credentials/server1.pem',
],
- 'tests': [
- 'tests.json'
- ],
+ 'tests': ['tests.json'],
}
TEST_SUITE = 'tests'
@@ -107,16 +98,15 @@ TESTS_REQUIRE = INSTALL_REQUIRES
PACKAGES = setuptools.find_packages('.')
setuptools.setup(
- name='grpcio-tests',
- version=grpc_version.VERSION,
- license=LICENSE,
- packages=list(PACKAGES),
- package_dir=PACKAGE_DIRECTORIES,
- package_data=PACKAGE_DATA,
- install_requires=INSTALL_REQUIRES,
- cmdclass=COMMAND_CLASS,
- tests_require=TESTS_REQUIRE,
- test_suite=TEST_SUITE,
- test_loader=TEST_LOADER,
- test_runner=TEST_RUNNER,
-)
+ name='grpcio-tests',
+ version=grpc_version.VERSION,
+ license=LICENSE,
+ packages=list(PACKAGES),
+ package_dir=PACKAGE_DIRECTORIES,
+ package_data=PACKAGE_DATA,
+ install_requires=INSTALL_REQUIRES,
+ cmdclass=COMMAND_CLASS,
+ tests_require=TESTS_REQUIRE,
+ test_suite=TEST_SUITE,
+ test_loader=TEST_LOADER,
+ test_runner=TEST_RUNNER,)
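
As a hedged usage note: the keys of COMMAND_CLASS become setup.py sub-commands, so the reformatted commands above would typically be driven along the lines of the sketch below. The working directory and the particular commands chained together are assumptions, not part of the diff.

# Hypothetical driver for commands registered in COMMAND_CLASS; the cwd and
# the choice of commands to run are assumptions.
import subprocess
import sys

for command in ('preprocess', 'build_package_protos', 'build_py'):
    subprocess.check_call(
        [sys.executable, 'setup.py', command],
        cwd='src/python/grpcio_tests')
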
diff --git a/src/python/grpcio_tests/tests/_loader.py b/src/python/grpcio_tests/tests/_loader.py
index 621bedc7bb..42cf9ab4ca 100644
--- a/src/python/grpcio_tests/tests/_loader.py
+++ b/src/python/grpcio_tests/tests/_loader.py
@@ -40,7 +40,7 @@ TEST_MODULE_REGEX = r'^.*_test$'
class Loader(object):
- """Test loader for setuptools test suite support.
+ """Test loader for setuptools test suite support.
Attributes:
suite (unittest.TestSuite): All tests collected by the loader.
@@ -51,57 +51,57 @@ class Loader(object):
contributes to the test suite.
"""
- def __init__(self):
- self.suite = unittest.TestSuite()
- self.loader = unittest.TestLoader()
- self.module_matcher = re.compile(TEST_MODULE_REGEX)
+ def __init__(self):
+ self.suite = unittest.TestSuite()
+ self.loader = unittest.TestLoader()
+ self.module_matcher = re.compile(TEST_MODULE_REGEX)
- def loadTestsFromNames(self, names, module=None):
- """Function mirroring TestLoader::loadTestsFromNames, as expected by
+ def loadTestsFromNames(self, names, module=None):
+ """Function mirroring TestLoader::loadTestsFromNames, as expected by
setuptools.setup argument `test_loader`."""
- # ensure that we capture decorators and definitions (else our coverage
- # measure unnecessarily suffers)
- coverage_context = coverage.Coverage(data_suffix=True)
- coverage_context.start()
- modules = [importlib.import_module(name) for name in names]
- for module in modules:
- self.visit_module(module)
- for module in modules:
- try:
- package_paths = module.__path__
- except:
- continue
- self.walk_packages(package_paths)
- coverage_context.stop()
- coverage_context.save()
- return self.suite
-
- def walk_packages(self, package_paths):
- """Walks over the packages, dispatching `visit_module` calls.
+ # ensure that we capture decorators and definitions (else our coverage
+ # measure unnecessarily suffers)
+ coverage_context = coverage.Coverage(data_suffix=True)
+ coverage_context.start()
+ modules = [importlib.import_module(name) for name in names]
+ for module in modules:
+ self.visit_module(module)
+ for module in modules:
+ try:
+ package_paths = module.__path__
+ except:
+ continue
+ self.walk_packages(package_paths)
+ coverage_context.stop()
+ coverage_context.save()
+ return self.suite
+
+ def walk_packages(self, package_paths):
+ """Walks over the packages, dispatching `visit_module` calls.
Args:
package_paths (list): A list of paths over which to walk through modules
along.
"""
- for importer, module_name, is_package in (
- pkgutil.walk_packages(package_paths)):
- module = importer.find_module(module_name).load_module(module_name)
- self.visit_module(module)
+ for importer, module_name, is_package in (
+ pkgutil.walk_packages(package_paths)):
+ module = importer.find_module(module_name).load_module(module_name)
+ self.visit_module(module)
- def visit_module(self, module):
- """Visits the module, adding discovered tests to the test suite.
+ def visit_module(self, module):
+ """Visits the module, adding discovered tests to the test suite.
Args:
module (module): Module to match against self.module_matcher; if matched
it has its tests loaded via self.loader into self.suite.
"""
- if self.module_matcher.match(module.__name__):
- module_suite = self.loader.loadTestsFromModule(module)
- self.suite.addTest(module_suite)
+ if self.module_matcher.match(module.__name__):
+ module_suite = self.loader.loadTestsFromModule(module)
+ self.suite.addTest(module_suite)
def iterate_suite_cases(suite):
- """Generator over all unittest.TestCases in a unittest.TestSuite.
+ """Generator over all unittest.TestCases in a unittest.TestSuite.
Args:
suite (unittest.TestSuite): Suite to iterate over in the generator.
@@ -109,11 +109,12 @@ def iterate_suite_cases(suite):
Returns:
generator: A generator over all unittest.TestCases in `suite`.
"""
- for item in suite:
- if isinstance(item, unittest.TestSuite):
- for child_item in iterate_suite_cases(item):
- yield child_item
- elif isinstance(item, unittest.TestCase):
- yield item
- else:
- raise ValueError('unexpected suite item of type {}'.format(type(item)))
+ for item in suite:
+ if isinstance(item, unittest.TestSuite):
+ for child_item in iterate_suite_cases(item):
+ yield child_item
+ elif isinstance(item, unittest.TestCase):
+ yield item
+ else:
+ raise ValueError('unexpected suite item of type {}'.format(
+ type(item)))
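
The loader above discovers tests by walking packages and matching module names against TEST_MODULE_REGEX. A condensed, self-contained sketch of that discovery idea using only the standard library follows; build_suite and its argument are illustrative names, not part of the module.

# Sketch of *_test module discovery with pkgutil, mirroring tests._loader;
# build_suite and package_name are illustrative.
import importlib
import pkgutil
import re
import unittest

TEST_MODULE_REGEX = r'^.*_test$'


def build_suite(package_name):
    package = importlib.import_module(package_name)
    matcher = re.compile(TEST_MODULE_REGEX)
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    for _, module_name, _ in pkgutil.walk_packages(
            package.__path__, prefix=package_name + '.'):
        # match only the final component, e.g. 'pkg.foo_test' -> 'foo_test'
        if matcher.match(module_name.rsplit('.', 1)[-1]):
            module = importlib.import_module(module_name)
            suite.addTest(loader.loadTestsFromModule(module))
    return suite
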
diff --git a/src/python/grpcio_tests/tests/_result.py b/src/python/grpcio_tests/tests/_result.py
index 1acec6a9b5..794b7540f1 100644
--- a/src/python/grpcio_tests/tests/_result.py
+++ b/src/python/grpcio_tests/tests/_result.py
@@ -41,9 +41,11 @@ from six import moves
from tests import _loader
-class CaseResult(collections.namedtuple('CaseResult', [
- 'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback'])):
- """A serializable result of a single test case.
+class CaseResult(
+ collections.namedtuple('CaseResult', [
+ 'id', 'name', 'kind', 'stdout', 'stderr', 'skip_reason', 'traceback'
+ ])):
+ """A serializable result of a single test case.
Attributes:
id (object): Any serializable object used to denote the identity of this
@@ -59,62 +61,78 @@ class CaseResult(collections.namedtuple('CaseResult', [
None.
"""
- class Kind:
- UNTESTED = 'untested'
- RUNNING = 'running'
- ERROR = 'error'
- FAILURE = 'failure'
- SUCCESS = 'success'
- SKIP = 'skip'
- EXPECTED_FAILURE = 'expected failure'
- UNEXPECTED_SUCCESS = 'unexpected success'
-
- def __new__(cls, id=None, name=None, kind=None, stdout=None, stderr=None,
- skip_reason=None, traceback=None):
- """Helper keyword constructor for the namedtuple.
+ class Kind:
+ UNTESTED = 'untested'
+ RUNNING = 'running'
+ ERROR = 'error'
+ FAILURE = 'failure'
+ SUCCESS = 'success'
+ SKIP = 'skip'
+ EXPECTED_FAILURE = 'expected failure'
+ UNEXPECTED_SUCCESS = 'unexpected success'
+
+ def __new__(cls,
+ id=None,
+ name=None,
+ kind=None,
+ stdout=None,
+ stderr=None,
+ skip_reason=None,
+ traceback=None):
+ """Helper keyword constructor for the namedtuple.
See this class' attributes for information on the arguments."""
- assert id is not None
- assert name is None or isinstance(name, str)
- if kind is CaseResult.Kind.UNTESTED:
- pass
- elif kind is CaseResult.Kind.RUNNING:
- pass
- elif kind is CaseResult.Kind.ERROR:
- assert traceback is not None
- elif kind is CaseResult.Kind.FAILURE:
- assert traceback is not None
- elif kind is CaseResult.Kind.SUCCESS:
- pass
- elif kind is CaseResult.Kind.SKIP:
- assert skip_reason is not None
- elif kind is CaseResult.Kind.EXPECTED_FAILURE:
- assert traceback is not None
- elif kind is CaseResult.Kind.UNEXPECTED_SUCCESS:
- pass
- else:
- assert False
- return super(cls, CaseResult).__new__(
- cls, id, name, kind, stdout, stderr, skip_reason, traceback)
-
- def updated(self, name=None, kind=None, stdout=None, stderr=None,
- skip_reason=None, traceback=None):
- """Get a new validated CaseResult with the fields updated.
+ assert id is not None
+ assert name is None or isinstance(name, str)
+ if kind is CaseResult.Kind.UNTESTED:
+ pass
+ elif kind is CaseResult.Kind.RUNNING:
+ pass
+ elif kind is CaseResult.Kind.ERROR:
+ assert traceback is not None
+ elif kind is CaseResult.Kind.FAILURE:
+ assert traceback is not None
+ elif kind is CaseResult.Kind.SUCCESS:
+ pass
+ elif kind is CaseResult.Kind.SKIP:
+ assert skip_reason is not None
+ elif kind is CaseResult.Kind.EXPECTED_FAILURE:
+ assert traceback is not None
+ elif kind is CaseResult.Kind.UNEXPECTED_SUCCESS:
+ pass
+ else:
+ assert False
+ return super(cls, CaseResult).__new__(cls, id, name, kind, stdout,
+ stderr, skip_reason, traceback)
+
+ def updated(self,
+ name=None,
+ kind=None,
+ stdout=None,
+ stderr=None,
+ skip_reason=None,
+ traceback=None):
+ """Get a new validated CaseResult with the fields updated.
See this class' attributes for information on the arguments."""
- name = self.name if name is None else name
- kind = self.kind if kind is None else kind
- stdout = self.stdout if stdout is None else stdout
- stderr = self.stderr if stderr is None else stderr
- skip_reason = self.skip_reason if skip_reason is None else skip_reason
- traceback = self.traceback if traceback is None else traceback
- return CaseResult(id=self.id, name=name, kind=kind, stdout=stdout,
- stderr=stderr, skip_reason=skip_reason,
- traceback=traceback)
+ name = self.name if name is None else name
+ kind = self.kind if kind is None else kind
+ stdout = self.stdout if stdout is None else stdout
+ stderr = self.stderr if stderr is None else stderr
+ skip_reason = self.skip_reason if skip_reason is None else skip_reason
+ traceback = self.traceback if traceback is None else traceback
+ return CaseResult(
+ id=self.id,
+ name=name,
+ kind=kind,
+ stdout=stdout,
+ stderr=stderr,
+ skip_reason=skip_reason,
+ traceback=traceback)
class AugmentedResult(unittest.TestResult):
- """unittest.Result that keeps track of additional information.
+ """unittest.Result that keeps track of additional information.
Uses CaseResult objects to store test-case results, providing additional
information beyond that of the standard Python unittest library, such as
@@ -127,228 +145,215 @@ class AugmentedResult(unittest.TestResult):
to CaseResult objects corresponding to those IDs.
"""
- def __init__(self, id_map):
- """Initialize the object with an identifier mapping.
+ def __init__(self, id_map):
+ """Initialize the object with an identifier mapping.
Arguments:
id_map (callable): Corresponds to the attribute `id_map`."""
- super(AugmentedResult, self).__init__()
- self.id_map = id_map
- self.cases = None
-
- def startTestRun(self):
- """See unittest.TestResult.startTestRun."""
- super(AugmentedResult, self).startTestRun()
- self.cases = dict()
-
- def stopTestRun(self):
- """See unittest.TestResult.stopTestRun."""
- super(AugmentedResult, self).stopTestRun()
-
- def startTest(self, test):
- """See unittest.TestResult.startTest."""
- super(AugmentedResult, self).startTest(test)
- case_id = self.id_map(test)
- self.cases[case_id] = CaseResult(
- id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING)
-
- def addError(self, test, error):
- """See unittest.TestResult.addError."""
- super(AugmentedResult, self).addError(test, error)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.ERROR, traceback=error)
-
- def addFailure(self, test, error):
- """See unittest.TestResult.addFailure."""
- super(AugmentedResult, self).addFailure(test, error)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.FAILURE, traceback=error)
-
- def addSuccess(self, test):
- """See unittest.TestResult.addSuccess."""
- super(AugmentedResult, self).addSuccess(test)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.SUCCESS)
-
- def addSkip(self, test, reason):
- """See unittest.TestResult.addSkip."""
- super(AugmentedResult, self).addSkip(test, reason)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.SKIP, skip_reason=reason)
-
- def addExpectedFailure(self, test, error):
- """See unittest.TestResult.addExpectedFailure."""
- super(AugmentedResult, self).addExpectedFailure(test, error)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=error)
-
- def addUnexpectedSuccess(self, test):
- """See unittest.TestResult.addUnexpectedSuccess."""
- super(AugmentedResult, self).addUnexpectedSuccess(test)
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- kind=CaseResult.Kind.UNEXPECTED_SUCCESS)
-
- def set_output(self, test, stdout, stderr):
- """Set the output attributes for the CaseResult corresponding to a test.
+ super(AugmentedResult, self).__init__()
+ self.id_map = id_map
+ self.cases = None
+
+ def startTestRun(self):
+ """See unittest.TestResult.startTestRun."""
+ super(AugmentedResult, self).startTestRun()
+ self.cases = dict()
+
+ def stopTestRun(self):
+ """See unittest.TestResult.stopTestRun."""
+ super(AugmentedResult, self).stopTestRun()
+
+ def startTest(self, test):
+ """See unittest.TestResult.startTest."""
+ super(AugmentedResult, self).startTest(test)
+ case_id = self.id_map(test)
+ self.cases[case_id] = CaseResult(
+ id=case_id, name=test.id(), kind=CaseResult.Kind.RUNNING)
+
+ def addError(self, test, error):
+ """See unittest.TestResult.addError."""
+ super(AugmentedResult, self).addError(test, error)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.ERROR, traceback=error)
+
+ def addFailure(self, test, error):
+ """See unittest.TestResult.addFailure."""
+ super(AugmentedResult, self).addFailure(test, error)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.FAILURE, traceback=error)
+
+ def addSuccess(self, test):
+ """See unittest.TestResult.addSuccess."""
+ super(AugmentedResult, self).addSuccess(test)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.SUCCESS)
+
+ def addSkip(self, test, reason):
+ """See unittest.TestResult.addSkip."""
+ super(AugmentedResult, self).addSkip(test, reason)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.SKIP, skip_reason=reason)
+
+ def addExpectedFailure(self, test, error):
+ """See unittest.TestResult.addExpectedFailure."""
+ super(AugmentedResult, self).addExpectedFailure(test, error)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.EXPECTED_FAILURE, traceback=error)
+
+ def addUnexpectedSuccess(self, test):
+ """See unittest.TestResult.addUnexpectedSuccess."""
+ super(AugmentedResult, self).addUnexpectedSuccess(test)
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ kind=CaseResult.Kind.UNEXPECTED_SUCCESS)
+
+ def set_output(self, test, stdout, stderr):
+ """Set the output attributes for the CaseResult corresponding to a test.
Args:
test (unittest.TestCase): The TestCase to set the outputs of.
stdout (str): Output from stdout to assign to self.id_map(test).
stderr (str): Output from stderr to assign to self.id_map(test).
"""
- case_id = self.id_map(test)
- self.cases[case_id] = self.cases[case_id].updated(
- stdout=stdout.decode(), stderr=stderr.decode())
+ case_id = self.id_map(test)
+ self.cases[case_id] = self.cases[case_id].updated(
+ stdout=stdout.decode(), stderr=stderr.decode())
- def augmented_results(self, filter):
- """Convenience method to retrieve filtered case results.
+ def augmented_results(self, filter):
+ """Convenience method to retrieve filtered case results.
Args:
filter (callable): A unary predicate to filter over CaseResult objects.
"""
- return (self.cases[case_id] for case_id in self.cases
- if filter(self.cases[case_id]))
+ return (self.cases[case_id] for case_id in self.cases
+ if filter(self.cases[case_id]))
class CoverageResult(AugmentedResult):
- """Extension to AugmentedResult adding coverage.py support per test.\
+ """Extension to AugmentedResult adding coverage.py support per test.\
Attributes:
coverage_context (coverage.Coverage): coverage.py management object.
"""
- def __init__(self, id_map):
- """See AugmentedResult.__init__."""
- super(CoverageResult, self).__init__(id_map=id_map)
- self.coverage_context = None
+ def __init__(self, id_map):
+ """See AugmentedResult.__init__."""
+ super(CoverageResult, self).__init__(id_map=id_map)
+ self.coverage_context = None
- def startTest(self, test):
- """See unittest.TestResult.startTest.
+ def startTest(self, test):
+ """See unittest.TestResult.startTest.
Additionally initializes and begins code coverage tracking."""
- super(CoverageResult, self).startTest(test)
- self.coverage_context = coverage.Coverage(data_suffix=True)
- self.coverage_context.start()
+ super(CoverageResult, self).startTest(test)
+ self.coverage_context = coverage.Coverage(data_suffix=True)
+ self.coverage_context.start()
- def stopTest(self, test):
- """See unittest.TestResult.stopTest.
+ def stopTest(self, test):
+ """See unittest.TestResult.stopTest.
Additionally stops and deinitializes code coverage tracking."""
- super(CoverageResult, self).stopTest(test)
- self.coverage_context.stop()
- self.coverage_context.save()
- self.coverage_context = None
+ super(CoverageResult, self).stopTest(test)
+ self.coverage_context.stop()
+ self.coverage_context.save()
+ self.coverage_context = None
- def stopTestRun(self):
- """See unittest.TestResult.stopTestRun."""
- super(CoverageResult, self).stopTestRun()
- # TODO(atash): Dig deeper into why the following line fails to properly
- # combine coverage data from the Cython plugin.
- #coverage.Coverage().combine()
+ def stopTestRun(self):
+ """See unittest.TestResult.stopTestRun."""
+ super(CoverageResult, self).stopTestRun()
+ # TODO(atash): Dig deeper into why the following line fails to properly
+ # combine coverage data from the Cython plugin.
+ #coverage.Coverage().combine()
class _Colors:
- """Namespaced constants for terminal color magic numbers."""
- HEADER = '\033[95m'
- INFO = '\033[94m'
- OK = '\033[92m'
- WARN = '\033[93m'
- FAIL = '\033[91m'
- BOLD = '\033[1m'
- UNDERLINE = '\033[4m'
- END = '\033[0m'
+ """Namespaced constants for terminal color magic numbers."""
+ HEADER = '\033[95m'
+ INFO = '\033[94m'
+ OK = '\033[92m'
+ WARN = '\033[93m'
+ FAIL = '\033[91m'
+ BOLD = '\033[1m'
+ UNDERLINE = '\033[4m'
+ END = '\033[0m'
class TerminalResult(CoverageResult):
- """Extension to CoverageResult adding basic terminal reporting."""
+ """Extension to CoverageResult adding basic terminal reporting."""
- def __init__(self, out, id_map):
- """Initialize the result object.
+ def __init__(self, out, id_map):
+ """Initialize the result object.
Args:
out (file-like): Output file to which terminal-colored live results will
be written.
id_map (callable): See AugmentedResult.__init__.
"""
- super(TerminalResult, self).__init__(id_map=id_map)
- self.out = out
-
- def startTestRun(self):
- """See unittest.TestResult.startTestRun."""
- super(TerminalResult, self).startTestRun()
- self.out.write(
- _Colors.HEADER +
- 'Testing gRPC Python...\n' +
- _Colors.END)
-
- def stopTestRun(self):
- """See unittest.TestResult.stopTestRun."""
- super(TerminalResult, self).stopTestRun()
- self.out.write(summary(self))
- self.out.flush()
-
- def addError(self, test, error):
- """See unittest.TestResult.addError."""
- super(TerminalResult, self).addError(test, error)
- self.out.write(
- _Colors.FAIL +
- 'ERROR {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
-
- def addFailure(self, test, error):
- """See unittest.TestResult.addFailure."""
- super(TerminalResult, self).addFailure(test, error)
- self.out.write(
- _Colors.FAIL +
- 'FAILURE {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
-
- def addSuccess(self, test):
- """See unittest.TestResult.addSuccess."""
- super(TerminalResult, self).addSuccess(test)
- self.out.write(
- _Colors.OK +
- 'SUCCESS {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
-
- def addSkip(self, test, reason):
- """See unittest.TestResult.addSkip."""
- super(TerminalResult, self).addSkip(test, reason)
- self.out.write(
- _Colors.INFO +
- 'SKIP {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
-
- def addExpectedFailure(self, test, error):
- """See unittest.TestResult.addExpectedFailure."""
- super(TerminalResult, self).addExpectedFailure(test, error)
- self.out.write(
- _Colors.INFO +
- 'FAILURE_OK {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
-
- def addUnexpectedSuccess(self, test):
- """See unittest.TestResult.addUnexpectedSuccess."""
- super(TerminalResult, self).addUnexpectedSuccess(test)
- self.out.write(
- _Colors.INFO +
- 'UNEXPECTED_OK {}\n'.format(test.id()) +
- _Colors.END)
- self.out.flush()
+ super(TerminalResult, self).__init__(id_map=id_map)
+ self.out = out
+
+ def startTestRun(self):
+ """See unittest.TestResult.startTestRun."""
+ super(TerminalResult, self).startTestRun()
+ self.out.write(_Colors.HEADER + 'Testing gRPC Python...\n' +
+ _Colors.END)
+
+ def stopTestRun(self):
+ """See unittest.TestResult.stopTestRun."""
+ super(TerminalResult, self).stopTestRun()
+ self.out.write(summary(self))
+ self.out.flush()
+
+ def addError(self, test, error):
+ """See unittest.TestResult.addError."""
+ super(TerminalResult, self).addError(test, error)
+ self.out.write(_Colors.FAIL + 'ERROR {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
+ def addFailure(self, test, error):
+ """See unittest.TestResult.addFailure."""
+ super(TerminalResult, self).addFailure(test, error)
+ self.out.write(_Colors.FAIL + 'FAILURE {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
+ def addSuccess(self, test):
+ """See unittest.TestResult.addSuccess."""
+ super(TerminalResult, self).addSuccess(test)
+ self.out.write(_Colors.OK + 'SUCCESS {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
+ def addSkip(self, test, reason):
+ """See unittest.TestResult.addSkip."""
+ super(TerminalResult, self).addSkip(test, reason)
+ self.out.write(_Colors.INFO + 'SKIP {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
+ def addExpectedFailure(self, test, error):
+ """See unittest.TestResult.addExpectedFailure."""
+ super(TerminalResult, self).addExpectedFailure(test, error)
+ self.out.write(_Colors.INFO + 'FAILURE_OK {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
+ def addUnexpectedSuccess(self, test):
+ """See unittest.TestResult.addUnexpectedSuccess."""
+ super(TerminalResult, self).addUnexpectedSuccess(test)
+ self.out.write(_Colors.INFO + 'UNEXPECTED_OK {}\n'.format(test.id()) +
+ _Colors.END)
+ self.out.flush()
+
def _traceback_string(type, value, trace):
- """Generate a descriptive string of a Python exception traceback.
+ """Generate a descriptive string of a Python exception traceback.
Args:
type (class): The type of the exception.
@@ -358,12 +363,13 @@ def _traceback_string(type, value, trace):
Returns:
str: Formatted exception descriptive string.
"""
- buffer = moves.cStringIO()
- traceback.print_exception(type, value, trace, file=buffer)
- return buffer.getvalue()
+ buffer = moves.cStringIO()
+ traceback.print_exception(type, value, trace, file=buffer)
+ return buffer.getvalue()
+
def summary(result):
- """A summary string of a result object.
+ """A summary string of a result object.
Args:
result (AugmentedResult): The result object to get the summary of.
@@ -371,62 +377,68 @@ def summary(result):
Returns:
str: The summary string.
"""
- assert isinstance(result, AugmentedResult)
- untested = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED))
- running = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.RUNNING))
- failures = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.FAILURE))
- errors = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.ERROR))
- successes = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS))
- skips = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.SKIP))
- expected_failures = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.EXPECTED_FAILURE))
- unexpected_successes = list(result.augmented_results(
- lambda case_result: case_result.kind is CaseResult.Kind.UNEXPECTED_SUCCESS))
- running_names = [case.name for case in running]
- finished_count = (len(failures) + len(errors) + len(successes) +
- len(expected_failures) + len(unexpected_successes))
- statistics = (
- '{finished} tests finished:\n'
- '\t{successful} successful\n'
- '\t{unsuccessful} unsuccessful\n'
- '\t{skipped} skipped\n'
- '\t{expected_fail} expected failures\n'
- '\t{unexpected_successful} unexpected successes\n'
- 'Interrupted Tests:\n'
- '\t{interrupted}\n'
- .format(finished=finished_count,
- successful=len(successes),
- unsuccessful=(len(failures)+len(errors)),
- skipped=len(skips),
- expected_fail=len(expected_failures),
- unexpected_successful=len(unexpected_successes),
- interrupted=str(running_names)))
- tracebacks = '\n\n'.join([
- (_Colors.FAIL + '{test_name}' + _Colors.END + '\n' +
- _Colors.BOLD + 'traceback:' + _Colors.END + '\n' +
- '{traceback}\n' +
- _Colors.BOLD + 'stdout:' + _Colors.END + '\n' +
- '{stdout}\n' +
- _Colors.BOLD + 'stderr:' + _Colors.END + '\n' +
- '{stderr}\n').format(
- test_name=result.name,
- traceback=_traceback_string(*result.traceback),
- stdout=result.stdout, stderr=result.stderr)
- for result in itertools.chain(failures, errors)
- ])
- notes = 'Unexpected successes: {}\n'.format([
- result.name for result in unexpected_successes])
- return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes
+ assert isinstance(result, AugmentedResult)
+ untested = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.UNTESTED))
+ running = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.RUNNING))
+ failures = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.FAILURE))
+ errors = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.ERROR))
+ successes = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.SUCCESS))
+ skips = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.SKIP))
+ expected_failures = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.EXPECTED_FAILURE
+ ))
+ unexpected_successes = list(
+ result.augmented_results(
+ lambda case_result: case_result.kind is CaseResult.Kind.UNEXPECTED_SUCCESS
+ ))
+ running_names = [case.name for case in running]
+ finished_count = (len(failures) + len(errors) + len(successes) +
+ len(expected_failures) + len(unexpected_successes))
+ statistics = ('{finished} tests finished:\n'
+ '\t{successful} successful\n'
+ '\t{unsuccessful} unsuccessful\n'
+ '\t{skipped} skipped\n'
+ '\t{expected_fail} expected failures\n'
+ '\t{unexpected_successful} unexpected successes\n'
+ 'Interrupted Tests:\n'
+ '\t{interrupted}\n'.format(
+ finished=finished_count,
+ successful=len(successes),
+ unsuccessful=(len(failures) + len(errors)),
+ skipped=len(skips),
+ expected_fail=len(expected_failures),
+ unexpected_successful=len(unexpected_successes),
+ interrupted=str(running_names)))
+ tracebacks = '\n\n'.join(
+ [(_Colors.FAIL + '{test_name}' + _Colors.END + '\n' + _Colors.BOLD +
+ 'traceback:' + _Colors.END + '\n' + '{traceback}\n' + _Colors.BOLD +
+ 'stdout:' + _Colors.END + '\n' + '{stdout}\n' + _Colors.BOLD +
+ 'stderr:' + _Colors.END + '\n' + '{stderr}\n').format(
+ test_name=result.name,
+ traceback=_traceback_string(*result.traceback),
+ stdout=result.stdout,
+ stderr=result.stderr)
+ for result in itertools.chain(failures, errors)])
+ notes = 'Unexpected successes: {}\n'.format(
+ [result.name for result in unexpected_successes])
+ return statistics + '\nErrors/Failures: \n' + tracebacks + '\n' + notes
def jenkins_junit_xml(result):
- """An XML tree object that when written is recognizable by Jenkins.
+ """An XML tree object that when written is recognizable by Jenkins.
Args:
result (AugmentedResult): The result object to get the junit xml output of.
@@ -434,20 +446,18 @@ def jenkins_junit_xml(result):
Returns:
ElementTree.ElementTree: The XML tree.
"""
- assert isinstance(result, AugmentedResult)
- root = ElementTree.Element('testsuites')
- suite = ElementTree.SubElement(root, 'testsuite', {
- 'name': 'Python gRPC tests',
- })
- for case in result.cases.values():
- if case.kind is CaseResult.Kind.SUCCESS:
- ElementTree.SubElement(suite, 'testcase', {
- 'name': case.name,
- })
- elif case.kind in (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE):
- case_xml = ElementTree.SubElement(suite, 'testcase', {
- 'name': case.name,
- })
- error_xml = ElementTree.SubElement(case_xml, 'error', {})
- error_xml.text = ''.format(case.stderr, case.traceback)
- return ElementTree.ElementTree(element=root)
+ assert isinstance(result, AugmentedResult)
+ root = ElementTree.Element('testsuites')
+ suite = ElementTree.SubElement(root, 'testsuite', {
+ 'name': 'Python gRPC tests',
+ })
+ for case in result.cases.values():
+ if case.kind is CaseResult.Kind.SUCCESS:
+ ElementTree.SubElement(suite, 'testcase', {'name': case.name,})
+ elif case.kind in (CaseResult.Kind.ERROR, CaseResult.Kind.FAILURE):
+ case_xml = ElementTree.SubElement(suite, 'testcase', {
+ 'name': case.name,
+ })
+ error_xml = ElementTree.SubElement(case_xml, 'error', {})
+ error_xml.text = ''.format(case.stderr, case.traceback)
+ return ElementTree.ElementTree(element=root)
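
jenkins_junit_xml above assembles the report with xml.etree.ElementTree. A stripped-down, hedged sketch of the same structure follows; the helper name, its arguments, and the report.xml default are assumptions for illustration.

# Minimal JUnit-style XML report via ElementTree, mirroring the structure
# produced by jenkins_junit_xml; write_report and its signature are assumed.
from xml.etree import ElementTree


def write_report(passed_names, failed_tracebacks, path='report.xml'):
    """failed_tracebacks: mapping of test name -> traceback string."""
    root = ElementTree.Element('testsuites')
    suite = ElementTree.SubElement(root, 'testsuite',
                                   {'name': 'Python gRPC tests'})
    for name in passed_names:
        ElementTree.SubElement(suite, 'testcase', {'name': name})
    for name, trace in failed_tracebacks.items():
        case_xml = ElementTree.SubElement(suite, 'testcase', {'name': name})
        error_xml = ElementTree.SubElement(case_xml, 'error', {})
        error_xml.text = trace
    with open(path, 'wb') as report_file:
        ElementTree.ElementTree(element=root).write(report_file)
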
diff --git a/src/python/grpcio_tests/tests/_runner.py b/src/python/grpcio_tests/tests/_runner.py
index 926dcbe23a..59964b271c 100644
--- a/src/python/grpcio_tests/tests/_runner.py
+++ b/src/python/grpcio_tests/tests/_runner.py
@@ -49,7 +49,7 @@ from tests import _result
class CaptureFile(object):
- """A context-managed file to redirect output to a byte array.
+ """A context-managed file to redirect output to a byte array.
Use by invoking `start` (`__enter__`) and at some point invoking `stop`
(`__exit__`). At any point after the initial call to `start` call `output` to
@@ -66,57 +66,56 @@ class CaptureFile(object):
Only non-None when self is started.
"""
- def __init__(self, fd):
- self._redirected_fd = fd
- self._saved_fd = os.dup(self._redirected_fd)
- self._into_file = None
+ def __init__(self, fd):
+ self._redirected_fd = fd
+ self._saved_fd = os.dup(self._redirected_fd)
+ self._into_file = None
- def output(self):
- """Get all output from the redirected-to file if it exists."""
- if self._into_file:
- self._into_file.seek(0)
- return bytes(self._into_file.read())
- else:
- return bytes()
+ def output(self):
+ """Get all output from the redirected-to file if it exists."""
+ if self._into_file:
+ self._into_file.seek(0)
+ return bytes(self._into_file.read())
+ else:
+ return bytes()
- def start(self):
- """Start redirection of writes to the file descriptor."""
- self._into_file = tempfile.TemporaryFile()
- os.dup2(self._into_file.fileno(), self._redirected_fd)
+ def start(self):
+ """Start redirection of writes to the file descriptor."""
+ self._into_file = tempfile.TemporaryFile()
+ os.dup2(self._into_file.fileno(), self._redirected_fd)
- def stop(self):
- """Stop redirection of writes to the file descriptor."""
- # n.b. this dup2 call auto-closes self._redirected_fd
- os.dup2(self._saved_fd, self._redirected_fd)
+ def stop(self):
+ """Stop redirection of writes to the file descriptor."""
+ # n.b. this dup2 call auto-closes self._redirected_fd
+ os.dup2(self._saved_fd, self._redirected_fd)
- def write_bypass(self, value):
- """Bypass the redirection and write directly to the original file.
+ def write_bypass(self, value):
+ """Bypass the redirection and write directly to the original file.
Arguments:
value (str): What to write to the original file.
"""
- if six.PY3 and not isinstance(value, six.binary_type):
- value = bytes(value, 'ascii')
- if self._saved_fd is None:
- os.write(self._redirect_fd, value)
- else:
- os.write(self._saved_fd, value)
+ if six.PY3 and not isinstance(value, six.binary_type):
+ value = bytes(value, 'ascii')
+ if self._saved_fd is None:
+ os.write(self._redirect_fd, value)
+ else:
+ os.write(self._saved_fd, value)
- def __enter__(self):
- self.start()
- return self
+ def __enter__(self):
+ self.start()
+ return self
- def __exit__(self, type, value, traceback):
- self.stop()
+ def __exit__(self, type, value, traceback):
+ self.stop()
- def close(self):
- """Close any resources used by self not closed by stop()."""
- os.close(self._saved_fd)
+ def close(self):
+ """Close any resources used by self not closed by stop()."""
+ os.close(self._saved_fd)
-class AugmentedCase(collections.namedtuple('AugmentedCase', [
- 'case', 'id'])):
- """A test case with a guaranteed unique externally specified identifier.
+class AugmentedCase(collections.namedtuple('AugmentedCase', ['case', 'id'])):
+ """A test case with a guaranteed unique externally specified identifier.
Attributes:
case (unittest.TestCase): TestCase we're decorating with an additional
@@ -125,105 +124,107 @@ class AugmentedCase(collections.namedtuple('AugmentedCase', [
purposes.
"""
- def __new__(cls, case, id=None):
- if id is None:
- id = uuid.uuid4()
- return super(cls, AugmentedCase).__new__(cls, case, id)
+ def __new__(cls, case, id=None):
+ if id is None:
+ id = uuid.uuid4()
+ return super(cls, AugmentedCase).__new__(cls, case, id)
class Runner(object):
- def run(self, suite):
- """See setuptools' test_runner setup argument for information."""
- # only run test cases with id starting with given prefix
- testcase_filter = os.getenv('GRPC_PYTHON_TESTRUNNER_FILTER')
- filtered_cases = []
- for case in _loader.iterate_suite_cases(suite):
- if not testcase_filter or case.id().startswith(testcase_filter):
- filtered_cases.append(case)
-
- # Ensure that every test case has no collision with any other test case in
- # the augmented results.
- augmented_cases = [AugmentedCase(case, uuid.uuid4())
- for case in filtered_cases]
- case_id_by_case = dict((augmented_case.case, augmented_case.id)
- for augmented_case in augmented_cases)
- result_out = moves.cStringIO()
- result = _result.TerminalResult(
- result_out, id_map=lambda case: case_id_by_case[case])
- stdout_pipe = CaptureFile(sys.stdout.fileno())
- stderr_pipe = CaptureFile(sys.stderr.fileno())
- kill_flag = [False]
-
- def sigint_handler(signal_number, frame):
- if signal_number == signal.SIGINT:
- kill_flag[0] = True # Python 2.7 not having 'local'... :-(
- signal.signal(signal_number, signal.SIG_DFL)
-
- def fault_handler(signal_number, frame):
- stdout_pipe.write_bypass(
- 'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'
- .format(signal_number, stdout_pipe.output(),
- stderr_pipe.output()))
- os._exit(1)
-
- def check_kill_self():
- if kill_flag[0]:
- stdout_pipe.write_bypass('Stopping tests short...')
- result.stopTestRun()
- stdout_pipe.write_bypass(result_out.getvalue())
- stdout_pipe.write_bypass(
- '\ninterrupted stdout:\n{}\n'.format(stdout_pipe.output().decode()))
- stderr_pipe.write_bypass(
- '\ninterrupted stderr:\n{}\n'.format(stderr_pipe.output().decode()))
- os._exit(1)
- def try_set_handler(name, handler):
- try:
- signal.signal(getattr(signal, name), handler)
- except AttributeError:
- pass
- try_set_handler('SIGINT', sigint_handler)
- try_set_handler('SIGSEGV', fault_handler)
- try_set_handler('SIGBUS', fault_handler)
- try_set_handler('SIGABRT', fault_handler)
- try_set_handler('SIGFPE', fault_handler)
- try_set_handler('SIGILL', fault_handler)
- # Sometimes output will lag after a test has successfully finished; we
- # ignore such writes to our pipes.
- try_set_handler('SIGPIPE', signal.SIG_IGN)
-
- # Run the tests
- result.startTestRun()
- for augmented_case in augmented_cases:
- sys.stdout.write('Running {}\n'.format(augmented_case.case.id()))
- sys.stdout.flush()
- case_thread = threading.Thread(
- target=augmented_case.case.run, args=(result,))
- try:
- with stdout_pipe, stderr_pipe:
- case_thread.start()
- while case_thread.is_alive():
+ def run(self, suite):
+ """See setuptools' test_runner setup argument for information."""
+ # only run test cases with id starting with given prefix
+ testcase_filter = os.getenv('GRPC_PYTHON_TESTRUNNER_FILTER')
+ filtered_cases = []
+ for case in _loader.iterate_suite_cases(suite):
+ if not testcase_filter or case.id().startswith(testcase_filter):
+ filtered_cases.append(case)
+
+ # Ensure that every test case has no collision with any other test case in
+ # the augmented results.
+ augmented_cases = [
+ AugmentedCase(case, uuid.uuid4()) for case in filtered_cases
+ ]
+ case_id_by_case = dict((augmented_case.case, augmented_case.id)
+ for augmented_case in augmented_cases)
+ result_out = moves.cStringIO()
+ result = _result.TerminalResult(
+ result_out, id_map=lambda case: case_id_by_case[case])
+ stdout_pipe = CaptureFile(sys.stdout.fileno())
+ stderr_pipe = CaptureFile(sys.stderr.fileno())
+ kill_flag = [False]
+
+ def sigint_handler(signal_number, frame):
+ if signal_number == signal.SIGINT:
+ kill_flag[0] = True # Python 2.7 not having 'local'... :-(
+ signal.signal(signal_number, signal.SIG_DFL)
+
+ def fault_handler(signal_number, frame):
+ stdout_pipe.write_bypass(
+ 'Received fault signal {}\nstdout:\n{}\n\nstderr:{}\n'.format(
+ signal_number, stdout_pipe.output(), stderr_pipe.output()))
+ os._exit(1)
+
+ def check_kill_self():
+ if kill_flag[0]:
+ stdout_pipe.write_bypass('Stopping tests short...')
+ result.stopTestRun()
+ stdout_pipe.write_bypass(result_out.getvalue())
+ stdout_pipe.write_bypass('\ninterrupted stdout:\n{}\n'.format(
+ stdout_pipe.output().decode()))
+ stderr_pipe.write_bypass('\ninterrupted stderr:\n{}\n'.format(
+ stderr_pipe.output().decode()))
+ os._exit(1)
+
+ def try_set_handler(name, handler):
+ try:
+ signal.signal(getattr(signal, name), handler)
+ except AttributeError:
+ pass
+
+ try_set_handler('SIGINT', sigint_handler)
+ try_set_handler('SIGSEGV', fault_handler)
+ try_set_handler('SIGBUS', fault_handler)
+ try_set_handler('SIGABRT', fault_handler)
+ try_set_handler('SIGFPE', fault_handler)
+ try_set_handler('SIGILL', fault_handler)
+ # Sometimes output will lag after a test has successfully finished; we
+ # ignore such writes to our pipes.
+ try_set_handler('SIGPIPE', signal.SIG_IGN)
+
+ # Run the tests
+ result.startTestRun()
+ for augmented_case in augmented_cases:
+ sys.stdout.write('Running {}\n'.format(augmented_case.case.id(
+ )))
+ sys.stdout.flush()
+ case_thread = threading.Thread(
+ target=augmented_case.case.run, args=(result,))
+ try:
+ with stdout_pipe, stderr_pipe:
+ case_thread.start()
+ while case_thread.is_alive():
+ check_kill_self()
+ time.sleep(0)
+ case_thread.join()
+ except:
+ # re-raise the exception after forcing the with-block to end
+ raise
+ result.set_output(augmented_case.case,
+ stdout_pipe.output(), stderr_pipe.output())
+ sys.stdout.write(result_out.getvalue())
+ sys.stdout.flush()
+ result_out.truncate(0)
check_kill_self()
- time.sleep(0)
- case_thread.join()
- except:
- # re-raise the exception after forcing the with-block to end
- raise
- result.set_output(
- augmented_case.case, stdout_pipe.output(), stderr_pipe.output())
- sys.stdout.write(result_out.getvalue())
- sys.stdout.flush()
- result_out.truncate(0)
- check_kill_self()
- result.stopTestRun()
- stdout_pipe.close()
- stderr_pipe.close()
-
- # Report results
- sys.stdout.write(result_out.getvalue())
- sys.stdout.flush()
- signal.signal(signal.SIGINT, signal.SIG_DFL)
- with open('report.xml', 'wb') as report_xml_file:
- _result.jenkins_junit_xml(result).write(report_xml_file)
- return result
-
+ result.stopTestRun()
+ stdout_pipe.close()
+ stderr_pipe.close()
+
+ # Report results
+ sys.stdout.write(result_out.getvalue())
+ sys.stdout.flush()
+ signal.signal(signal.SIGINT, signal.SIG_DFL)
+ with open('report.xml', 'wb') as report_xml_file:
+ _result.jenkins_junit_xml(result).write(report_xml_file)
+ return result
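
The reformatted Runner.run above reduces to two reusable patterns: an optional prefix filter read from the GRPC_PYTHON_TESTRUNNER_FILTER environment variable, and a per-case worker thread that is polled against a shared kill flag flipped by the SIGINT handler. A minimal, self-contained sketch of those two patterns, assuming a flat list of unittest.TestCase objects (the helper names here are illustrative, not part of the runner):

import os
import threading
import time

FILTER_VAR = 'GRPC_PYTHON_TESTRUNNER_FILTER'  # same variable Runner.run reads


def select_cases(cases):
    """Keep the unittest.TestCase objects whose id() starts with the prefix."""
    prefix = os.getenv(FILTER_VAR)
    return [case for case in cases if not prefix or case.id().startswith(prefix)]


def run_case(case, result, kill_flag):
    """Run one case on a worker thread while polling a one-element kill flag."""
    worker = threading.Thread(target=case.run, args=(result,))
    worker.start()
    while worker.is_alive():
        if kill_flag[0]:       # flipped asynchronously, e.g. by a SIGINT handler
            return False       # caller aborts the run, as check_kill_self() does
        time.sleep(0)          # yield the interpreter between polls
    worker.join()
    return True

time.sleep(0) simply yields the interpreter so the signal handler and the worker thread get a chance to run between polls, which is why the runner loop above does the same.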
diff --git a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
index 5dde72b169..363b4c5f99 100644
--- a/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
+++ b/src/python/grpcio_tests/tests/health_check/_health_servicer_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc_health.v1.health."""
import unittest
@@ -41,55 +40,55 @@ from tests.unit.framework.common import test_constants
class HealthServicerTest(unittest.TestCase):
- def setUp(self):
- servicer = health.HealthServicer()
- servicer.set('', health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceServing',
- health_pb2.HealthCheckResponse.SERVING)
- servicer.set('grpc.test.TestServiceUnknown',
- health_pb2.HealthCheckResponse.UNKNOWN)
- servicer.set('grpc.test.TestServiceNotServing',
- health_pb2.HealthCheckResponse.NOT_SERVING)
- server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(server_pool)
- port = self._server.add_insecure_port('[::]:0')
- health_pb2.add_HealthServicer_to_server(servicer, self._server)
- self._server.start()
+ def setUp(self):
+ servicer = health.HealthServicer()
+ servicer.set('', health_pb2.HealthCheckResponse.SERVING)
+ servicer.set('grpc.test.TestServiceServing',
+ health_pb2.HealthCheckResponse.SERVING)
+ servicer.set('grpc.test.TestServiceUnknown',
+ health_pb2.HealthCheckResponse.UNKNOWN)
+ servicer.set('grpc.test.TestServiceNotServing',
+ health_pb2.HealthCheckResponse.NOT_SERVING)
+ server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ health_pb2.add_HealthServicer_to_server(servicer, self._server)
+ self._server.start()
+
+ channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = health_pb2.HealthStub(channel)
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = health_pb2.HealthStub(channel)
+ def test_empty_service(self):
+ request = health_pb2.HealthCheckRequest()
+ resp = self._stub.Check(request)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_empty_service(self):
- request = health_pb2.HealthCheckRequest()
- resp = self._stub.Check(request)
- self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
+ def test_serving_service(self):
+ request = health_pb2.HealthCheckRequest(
+ service='grpc.test.TestServiceServing')
+ resp = self._stub.Check(request)
+ self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
- def test_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceServing')
- resp = self._stub.Check(request)
- self.assertEqual(health_pb2.HealthCheckResponse.SERVING, resp.status)
+ def test_unknown_serivce(self):
+ request = health_pb2.HealthCheckRequest(
+ service='grpc.test.TestServiceUnknown')
+ resp = self._stub.Check(request)
+ self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status)
- def test_unknown_serivce(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceUnknown')
- resp = self._stub.Check(request)
- self.assertEqual(health_pb2.HealthCheckResponse.UNKNOWN, resp.status)
+ def test_not_serving_service(self):
+ request = health_pb2.HealthCheckRequest(
+ service='grpc.test.TestServiceNotServing')
+ resp = self._stub.Check(request)
+ self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING,
+ resp.status)
- def test_not_serving_service(self):
- request = health_pb2.HealthCheckRequest(
- service='grpc.test.TestServiceNotServing')
- resp = self._stub.Check(request)
- self.assertEqual(health_pb2.HealthCheckResponse.NOT_SERVING, resp.status)
+ def test_not_found_service(self):
+ request = health_pb2.HealthCheckRequest(service='not-found')
+ with self.assertRaises(grpc.RpcError) as context:
+ resp = self._stub.Check(request)
- def test_not_found_service(self):
- request = health_pb2.HealthCheckRequest(
- service='not-found')
- with self.assertRaises(grpc.RpcError) as context:
- resp = self._stub.Check(request)
-
- self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
+ self.assertEqual(grpc.StatusCode.NOT_FOUND, context.exception.code())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
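
Read end to end, the test above exercises a small, reusable flow: register per-service statuses on a health.HealthServicer, attach it to a server, and query it through the generated HealthStub. A consolidated sketch of that flow using the same health_pb2 helpers as the test (the second service name is a placeholder; newer grpcio-health-checking releases generally expose the stub and registration function from health_pb2_grpc instead):

from concurrent import futures

import grpc
from grpc_health.v1 import health
from grpc_health.v1 import health_pb2

servicer = health.HealthServicer()
servicer.set('', health_pb2.HealthCheckResponse.SERVING)
servicer.set('example.ExampleService',  # placeholder service name
             health_pb2.HealthCheckResponse.NOT_SERVING)

server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
health_pb2.add_HealthServicer_to_server(servicer, server)
port = server.add_insecure_port('[::]:0')
server.start()

stub = health_pb2.HealthStub(grpc.insecure_channel('localhost:%d' % port))
status = stub.Check(health_pb2.HealthCheckRequest(service='')).status
assert status == health_pb2.HealthCheckResponse.SERVING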
diff --git a/src/python/grpcio_tests/tests/http2/_negative_http2_client.py b/src/python/grpcio_tests/tests/http2/_negative_http2_client.py
index f8604683b3..c192d827c4 100644
--- a/src/python/grpcio_tests/tests/http2/_negative_http2_client.py
+++ b/src/python/grpcio_tests/tests/http2/_negative_http2_client.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""The Python client used to test negative http2 conditions."""
import argparse
@@ -35,29 +34,32 @@ import grpc
from src.proto.grpc.testing import test_pb2
from src.proto.grpc.testing import messages_pb2
+
def _validate_payload_type_and_length(response, expected_type, expected_length):
- if response.payload.type is not expected_type:
- raise ValueError(
- 'expected payload type %s, got %s' %
- (expected_type, type(response.payload.type)))
- elif len(response.payload.body) != expected_length:
- raise ValueError(
- 'expected payload body size %d, got %d' %
- (expected_length, len(response.payload.body)))
+ if response.payload.type is not expected_type:
+ raise ValueError('expected payload type %s, got %s' %
+ (expected_type, type(response.payload.type)))
+ elif len(response.payload.body) != expected_length:
+ raise ValueError('expected payload body size %d, got %d' %
+ (expected_length, len(response.payload.body)))
+
def _expect_status_code(call, expected_code):
- if call.code() != expected_code:
- raise ValueError(
- 'expected code %s, got %s' % (expected_code, call.code()))
+ if call.code() != expected_code:
+ raise ValueError('expected code %s, got %s' %
+ (expected_code, call.code()))
+
def _expect_status_details(call, expected_details):
- if call.details() != expected_details:
- raise ValueError(
- 'expected message %s, got %s' % (expected_details, call.details()))
+ if call.details() != expected_details:
+ raise ValueError('expected message %s, got %s' %
+ (expected_details, call.details()))
+
def _validate_status_code_and_details(call, expected_code, expected_details):
- _expect_status_code(call, expected_code)
- _expect_status_details(call, expected_details)
+ _expect_status_code(call, expected_code)
+ _expect_status_details(call, expected_details)
+
# common requests
_REQUEST_SIZE = 314159
@@ -68,86 +70,103 @@ _SIMPLE_REQUEST = messages_pb2.SimpleRequest(
response_size=_RESPONSE_SIZE,
payload=messages_pb2.Payload(body=b'\x00' * _REQUEST_SIZE))
+
def _goaway(stub):
- first_response = stub.UnaryCall(_SIMPLE_REQUEST)
- _validate_payload_type_and_length(first_response,
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
- second_response = stub.UnaryCall(_SIMPLE_REQUEST)
- _validate_payload_type_and_length(second_response,
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+ first_response = stub.UnaryCall(_SIMPLE_REQUEST)
+ _validate_payload_type_and_length(first_response, messages_pb2.COMPRESSABLE,
+ _RESPONSE_SIZE)
+ second_response = stub.UnaryCall(_SIMPLE_REQUEST)
+ _validate_payload_type_and_length(second_response,
+ messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+
def _rst_after_header(stub):
- resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
- _validate_status_code_and_details(resp_future, grpc.StatusCode.UNAVAILABLE, "")
+ resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
+ _validate_status_code_and_details(resp_future, grpc.StatusCode.UNAVAILABLE,
+ "")
+
def _rst_during_data(stub):
- resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
- _validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
+ resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
+ _validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
+
def _rst_after_data(stub):
- resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
- _validate_payload_type_and_length(next(resp_future),
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
- _validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
+ resp_future = stub.UnaryCall.future(_SIMPLE_REQUEST)
+ _validate_payload_type_and_length(
+ next(resp_future), messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+ _validate_status_code_and_details(resp_future, grpc.StatusCode.UNKNOWN, "")
+
def _ping(stub):
- response = stub.UnaryCall(_SIMPLE_REQUEST)
- _validate_payload_type_and_length(response,
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+ response = stub.UnaryCall(_SIMPLE_REQUEST)
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
+ _RESPONSE_SIZE)
+
def _max_streams(stub):
- # send one req to ensure server sets MAX_STREAMS
- response = stub.UnaryCall(_SIMPLE_REQUEST)
- _validate_payload_type_and_length(response,
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
-
- # give the streams a workout
- futures = []
- for _ in range(15):
- futures.append(stub.UnaryCall.future(_SIMPLE_REQUEST))
- for future in futures:
- _validate_payload_type_and_length(future.result(),
- messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+ # send one req to ensure server sets MAX_STREAMS
+ response = stub.UnaryCall(_SIMPLE_REQUEST)
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
+ _RESPONSE_SIZE)
+
+ # give the streams a workout
+ futures = []
+ for _ in range(15):
+ futures.append(stub.UnaryCall.future(_SIMPLE_REQUEST))
+ for future in futures:
+ _validate_payload_type_and_length(
+ future.result(), messages_pb2.COMPRESSABLE, _RESPONSE_SIZE)
+
def _run_test_case(test_case, stub):
- if test_case == 'goaway':
- _goaway(stub)
- elif test_case == 'rst_after_header':
- _rst_after_header(stub)
- elif test_case == 'rst_during_data':
- _rst_during_data(stub)
- elif test_case == 'rst_after_data':
- _rst_after_data(stub)
- elif test_case =='ping':
- _ping(stub)
- elif test_case == 'max_streams':
- _max_streams(stub)
- else:
- raise ValueError("Invalid test case: %s" % test_case)
+ if test_case == 'goaway':
+ _goaway(stub)
+ elif test_case == 'rst_after_header':
+ _rst_after_header(stub)
+ elif test_case == 'rst_during_data':
+ _rst_during_data(stub)
+ elif test_case == 'rst_after_data':
+ _rst_after_data(stub)
+ elif test_case == 'ping':
+ _ping(stub)
+ elif test_case == 'max_streams':
+ _max_streams(stub)
+ else:
+ raise ValueError("Invalid test case: %s" % test_case)
+
def _args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--server_host', help='the host to which to connect', type=str,
- default="127.0.0.1")
- parser.add_argument(
- '--server_port', help='the port to which to connect', type=int,
- default="8080")
- parser.add_argument(
- '--test_case', help='the test case to execute', type=str,
- default="goaway")
- return parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--server_host',
+ help='the host to which to connect',
+ type=str,
+ default="127.0.0.1")
+ parser.add_argument(
+ '--server_port',
+ help='the port to which to connect',
+ type=int,
+ default="8080")
+ parser.add_argument(
+ '--test_case',
+ help='the test case to execute',
+ type=str,
+ default="goaway")
+ return parser.parse_args()
+
def _stub(server_host, server_port):
- target = '{}:{}'.format(server_host, server_port)
- channel = grpc.insecure_channel(target)
- return test_pb2.TestServiceStub(channel)
+ target = '{}:{}'.format(server_host, server_port)
+ channel = grpc.insecure_channel(target)
+ return test_pb2.TestServiceStub(channel)
+
def main():
- args = _args()
- stub = _stub(args.server_host, args.server_port)
- _run_test_case(args.test_case, stub)
+ args = _args()
+ stub = _stub(args.server_host, args.server_port)
+ _run_test_case(args.test_case, stub)
if __name__ == '__main__':
- main()
+ main()
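
The client above is meant to be pointed at a test server that deliberately misbehaves at the HTTP/2 level; it simply dispatches one named scenario against that server. A hedged sketch of driving it, with the address, port, and case name as placeholders (the helpers are module-private, so this is for illustration only and assumes the grpcio_tests package is importable):

from tests.http2 import _negative_http2_client

stub = _negative_http2_client._stub('127.0.0.1', 8080)
_negative_http2_client._run_test_case('rst_after_header', stub)

# Roughly equivalent command line, via the module's main():
#   python _negative_http2_client.py --server_host 127.0.0.1 \
#       --server_port 8080 --test_case rst_after_header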
diff --git a/src/python/grpcio_tests/tests/interop/__init__.py b/src/python/grpcio_tests/tests/interop/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/interop/__init__.py
+++ b/src/python/grpcio_tests/tests/interop/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py b/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py
index 4fb22b4d9d..58f3b364ba 100644
--- a/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py
+++ b/src/python/grpcio_tests/tests/interop/_insecure_intraop_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Insecure client-server interoperability as a unit test."""
from concurrent import futures
@@ -40,19 +39,18 @@ from tests.interop import methods
from tests.interop import server
-class InsecureIntraopTest(
- _intraop_test_case.IntraopTestCase,
- unittest.TestCase):
+class InsecureIntraopTest(_intraop_test_case.IntraopTestCase,
+ unittest.TestCase):
- def setUp(self):
- self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- test_pb2.add_TestServiceServicer_to_server(
- methods.TestService(), self.server)
- port = self.server.add_insecure_port('[::]:0')
- self.server.start()
- self.stub = test_pb2.TestServiceStub(
- grpc.insecure_channel('localhost:{}'.format(port)))
+ def setUp(self):
+ self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ test_pb2.add_TestServiceServicer_to_server(methods.TestService(),
+ self.server)
+ port = self.server.add_insecure_port('[::]:0')
+ self.server.start()
+ self.stub = test_pb2.TestServiceStub(
+ grpc.insecure_channel('localhost:{}'.format(port)))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/interop/_intraop_test_case.py b/src/python/grpcio_tests/tests/interop/_intraop_test_case.py
index fe1c173992..424f93980c 100644
--- a/src/python/grpcio_tests/tests/interop/_intraop_test_case.py
+++ b/src/python/grpcio_tests/tests/interop/_intraop_test_case.py
@@ -26,39 +26,41 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Common code for unit tests of the interoperability test code."""
from tests.interop import methods
class IntraopTestCase(object):
- """Unit test methods.
+ """Unit test methods.
This class must be mixed in with unittest.TestCase and a class that defines
setUp and tearDown methods that manage a stub attribute.
"""
- def testEmptyUnary(self):
- methods.TestCase.EMPTY_UNARY.test_interoperability(self.stub, None)
+ def testEmptyUnary(self):
+ methods.TestCase.EMPTY_UNARY.test_interoperability(self.stub, None)
- def testLargeUnary(self):
- methods.TestCase.LARGE_UNARY.test_interoperability(self.stub, None)
+ def testLargeUnary(self):
+ methods.TestCase.LARGE_UNARY.test_interoperability(self.stub, None)
- def testServerStreaming(self):
- methods.TestCase.SERVER_STREAMING.test_interoperability(self.stub, None)
+ def testServerStreaming(self):
+ methods.TestCase.SERVER_STREAMING.test_interoperability(self.stub, None)
- def testClientStreaming(self):
- methods.TestCase.CLIENT_STREAMING.test_interoperability(self.stub, None)
+ def testClientStreaming(self):
+ methods.TestCase.CLIENT_STREAMING.test_interoperability(self.stub, None)
- def testPingPong(self):
- methods.TestCase.PING_PONG.test_interoperability(self.stub, None)
+ def testPingPong(self):
+ methods.TestCase.PING_PONG.test_interoperability(self.stub, None)
- def testCancelAfterBegin(self):
- methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(self.stub, None)
+ def testCancelAfterBegin(self):
+ methods.TestCase.CANCEL_AFTER_BEGIN.test_interoperability(self.stub,
+ None)
- def testCancelAfterFirstResponse(self):
- methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability(self.stub, None)
+ def testCancelAfterFirstResponse(self):
+ methods.TestCase.CANCEL_AFTER_FIRST_RESPONSE.test_interoperability(
+ self.stub, None)
- def testTimeoutOnSleepingServer(self):
- methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability(self.stub, None)
+ def testTimeoutOnSleepingServer(self):
+ methods.TestCase.TIMEOUT_ON_SLEEPING_SERVER.test_interoperability(
+ self.stub, None)
diff --git a/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py b/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py
index 3665c69726..b28406ed3f 100644
--- a/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py
+++ b/src/python/grpcio_tests/tests/interop/_secure_intraop_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Secure client-server interoperability as a unit test."""
from concurrent import futures
@@ -42,24 +41,24 @@ from tests.interop import resources
_SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
-class SecureIntraopTest(
- _intraop_test_case.IntraopTestCase,
- unittest.TestCase):
+class SecureIntraopTest(_intraop_test_case.IntraopTestCase, unittest.TestCase):
- def setUp(self):
- self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- test_pb2.add_TestServiceServicer_to_server(
- methods.TestService(), self.server)
- port = self.server.add_secure_port(
- '[::]:0', grpc.ssl_server_credentials(
- [(resources.private_key(), resources.certificate_chain())]))
- self.server.start()
- self.stub = test_pb2.TestServiceStub(
- grpc.secure_channel(
- 'localhost:{}'.format(port),
- grpc.ssl_channel_credentials(resources.test_root_certificates()),
- (('grpc.ssl_target_name_override', _SERVER_HOST_OVERRIDE,),)))
+ def setUp(self):
+ self.server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ test_pb2.add_TestServiceServicer_to_server(methods.TestService(),
+ self.server)
+ port = self.server.add_secure_port(
+ '[::]:0',
+ grpc.ssl_server_credentials(
+ [(resources.private_key(), resources.certificate_chain())]))
+ self.server.start()
+ self.stub = test_pb2.TestServiceStub(
+ grpc.secure_channel('localhost:{}'.format(port),
+ grpc.ssl_channel_credentials(
+ resources.test_root_certificates()), ((
+ 'grpc.ssl_target_name_override',
+ _SERVER_HOST_OVERRIDE,),)))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/interop/client.py b/src/python/grpcio_tests/tests/interop/client.py
index afaa466254..f177896e8e 100644
--- a/src/python/grpcio_tests/tests/interop/client.py
+++ b/src/python/grpcio_tests/tests/interop/client.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""The Python implementation of the GRPC interoperability test client."""
import argparse
@@ -41,93 +40,107 @@ from tests.interop import resources
def _args():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--server_host', help='the host to which to connect', type=str,
- default="127.0.0.1")
- parser.add_argument(
- '--server_port', help='the port to which to connect', type=int)
- parser.add_argument(
- '--test_case', help='the test case to execute', type=str,
- default="large_unary")
- parser.add_argument(
- '--use_tls', help='require a secure connection', default=False,
- type=resources.parse_bool)
- parser.add_argument(
- '--use_test_ca', help='replace platform root CAs with ca.pem',
- default=False, type=resources.parse_bool)
- parser.add_argument(
- '--server_host_override', default="foo.test.google.fr",
- help='the server host to which to claim to connect', type=str)
- parser.add_argument('--oauth_scope', help='scope for OAuth tokens', type=str)
- parser.add_argument(
- '--default_service_account',
- help='email address of the default service account', type=str)
- return parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--server_host',
+ help='the host to which to connect',
+ type=str,
+ default="127.0.0.1")
+ parser.add_argument(
+ '--server_port', help='the port to which to connect', type=int)
+ parser.add_argument(
+ '--test_case',
+ help='the test case to execute',
+ type=str,
+ default="large_unary")
+ parser.add_argument(
+ '--use_tls',
+ help='require a secure connection',
+ default=False,
+ type=resources.parse_bool)
+ parser.add_argument(
+ '--use_test_ca',
+ help='replace platform root CAs with ca.pem',
+ default=False,
+ type=resources.parse_bool)
+ parser.add_argument(
+ '--server_host_override',
+ default="foo.test.google.fr",
+ help='the server host to which to claim to connect',
+ type=str)
+ parser.add_argument(
+ '--oauth_scope', help='scope for OAuth tokens', type=str)
+ parser.add_argument(
+ '--default_service_account',
+ help='email address of the default service account',
+ type=str)
+ return parser.parse_args()
def _application_default_credentials():
- return oauth2client_client.GoogleCredentials.get_application_default()
+ return oauth2client_client.GoogleCredentials.get_application_default()
def _stub(args):
- target = '{}:{}'.format(args.server_host, args.server_port)
- if args.test_case == 'oauth2_auth_token':
- google_credentials = _application_default_credentials()
- scoped_credentials = google_credentials.create_scoped([args.oauth_scope])
- access_token = scoped_credentials.get_access_token().access_token
- call_credentials = grpc.access_token_call_credentials(access_token)
- elif args.test_case == 'compute_engine_creds':
- google_credentials = _application_default_credentials()
- scoped_credentials = google_credentials.create_scoped([args.oauth_scope])
- # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
- # remaining use of the Beta API.
- call_credentials = implementations.google_call_credentials(
- scoped_credentials)
- elif args.test_case == 'jwt_token_creds':
- google_credentials = _application_default_credentials()
- # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
- # remaining use of the Beta API.
- call_credentials = implementations.google_call_credentials(
- google_credentials)
- else:
- call_credentials = None
- if args.use_tls:
- if args.use_test_ca:
- root_certificates = resources.test_root_certificates()
+ target = '{}:{}'.format(args.server_host, args.server_port)
+ if args.test_case == 'oauth2_auth_token':
+ google_credentials = _application_default_credentials()
+ scoped_credentials = google_credentials.create_scoped(
+ [args.oauth_scope])
+ access_token = scoped_credentials.get_access_token().access_token
+ call_credentials = grpc.access_token_call_credentials(access_token)
+ elif args.test_case == 'compute_engine_creds':
+ google_credentials = _application_default_credentials()
+ scoped_credentials = google_credentials.create_scoped(
+ [args.oauth_scope])
+ # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
+ # remaining use of the Beta API.
+ call_credentials = implementations.google_call_credentials(
+ scoped_credentials)
+ elif args.test_case == 'jwt_token_creds':
+ google_credentials = _application_default_credentials()
+ # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
+ # remaining use of the Beta API.
+ call_credentials = implementations.google_call_credentials(
+ google_credentials)
else:
- root_certificates = None # will load default roots.
-
- channel_credentials = grpc.ssl_channel_credentials(root_certificates)
- if call_credentials is not None:
- channel_credentials = grpc.composite_channel_credentials(
- channel_credentials, call_credentials)
-
- channel = grpc.secure_channel(
- target, channel_credentials,
- (('grpc.ssl_target_name_override', args.server_host_override,),))
- else:
- channel = grpc.insecure_channel(target)
- if args.test_case == "unimplemented_service":
- return test_pb2.UnimplementedServiceStub(channel)
- else:
- return test_pb2.TestServiceStub(channel)
+ call_credentials = None
+ if args.use_tls:
+ if args.use_test_ca:
+ root_certificates = resources.test_root_certificates()
+ else:
+ root_certificates = None # will load default roots.
+
+ channel_credentials = grpc.ssl_channel_credentials(root_certificates)
+ if call_credentials is not None:
+ channel_credentials = grpc.composite_channel_credentials(
+ channel_credentials, call_credentials)
+
+ channel = grpc.secure_channel(target, channel_credentials, ((
+ 'grpc.ssl_target_name_override',
+ args.server_host_override,),))
+ else:
+ channel = grpc.insecure_channel(target)
+ if args.test_case == "unimplemented_service":
+ return test_pb2.UnimplementedServiceStub(channel)
+ else:
+ return test_pb2.TestServiceStub(channel)
def _test_case_from_arg(test_case_arg):
- for test_case in methods.TestCase:
- if test_case_arg == test_case.value:
- return test_case
- else:
- raise ValueError('No test case "%s"!' % test_case_arg)
+ for test_case in methods.TestCase:
+ if test_case_arg == test_case.value:
+ return test_case
+ else:
+ raise ValueError('No test case "%s"!' % test_case_arg)
def test_interoperability():
- args = _args()
- stub = _stub(args)
- test_case = _test_case_from_arg(args.test_case)
- test_case.test_interoperability(stub, args)
+ args = _args()
+ stub = _stub(args)
+ test_case = _test_case_from_arg(args.test_case)
+ test_case.test_interoperability(stub, args)
if __name__ == '__main__':
- test_interoperability()
+ test_interoperability()
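
When TLS is requested, _stub above layers three public grpc APIs: ssl_channel_credentials for the channel, composite_channel_credentials to fold in any per-call credentials, and a channel-options tuple carrying grpc.ssl_target_name_override (the doubly nested tuple literal is just a one-element tuple of (name, value) pairs). A condensed sketch of that path, with the token, target, and override values as placeholders:

import grpc

# Placeholders: a real run supplies its own token, target, and host override.
access_token = 'example-oauth2-access-token'
target = 'interop-server.example.com:443'
host_override = 'foo.test.google.fr'

call_credentials = grpc.access_token_call_credentials(access_token)
channel_credentials = grpc.composite_channel_credentials(
    grpc.ssl_channel_credentials(None),  # None loads the default root certs
    call_credentials)

# The options argument is a tuple of (name, value) pairs, which is what the
# nested tuple literal in _stub() spells out.
channel = grpc.secure_channel(
    target, channel_credentials,
    (('grpc.ssl_target_name_override', host_override),))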
diff --git a/src/python/grpcio_tests/tests/interop/methods.py b/src/python/grpcio_tests/tests/interop/methods.py
index 9038ae5751..e1f8722168 100644
--- a/src/python/grpcio_tests/tests/interop/methods.py
+++ b/src/python/grpcio_tests/tests/interop/methods.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Implementations of interoperability test methods."""
import enum
@@ -46,463 +45,483 @@ from src.proto.grpc.testing import test_pb2
_INITIAL_METADATA_KEY = "x-grpc-test-echo-initial"
_TRAILING_METADATA_KEY = "x-grpc-test-echo-trailing-bin"
+
def _maybe_echo_metadata(servicer_context):
- """Copies metadata from request to response if it is present."""
- invocation_metadata = dict(servicer_context.invocation_metadata())
- if _INITIAL_METADATA_KEY in invocation_metadata:
- initial_metadatum = (
- _INITIAL_METADATA_KEY, invocation_metadata[_INITIAL_METADATA_KEY])
- servicer_context.send_initial_metadata((initial_metadatum,))
- if _TRAILING_METADATA_KEY in invocation_metadata:
- trailing_metadatum = (
- _TRAILING_METADATA_KEY, invocation_metadata[_TRAILING_METADATA_KEY])
- servicer_context.set_trailing_metadata((trailing_metadatum,))
+ """Copies metadata from request to response if it is present."""
+ invocation_metadata = dict(servicer_context.invocation_metadata())
+ if _INITIAL_METADATA_KEY in invocation_metadata:
+ initial_metadatum = (_INITIAL_METADATA_KEY,
+ invocation_metadata[_INITIAL_METADATA_KEY])
+ servicer_context.send_initial_metadata((initial_metadatum,))
+ if _TRAILING_METADATA_KEY in invocation_metadata:
+ trailing_metadatum = (_TRAILING_METADATA_KEY,
+ invocation_metadata[_TRAILING_METADATA_KEY])
+ servicer_context.set_trailing_metadata((trailing_metadatum,))
+
def _maybe_echo_status_and_message(request, servicer_context):
- """Sets the response context code and details if the request asks for them"""
- if request.HasField('response_status'):
- servicer_context.set_code(request.response_status.code)
- servicer_context.set_details(request.response_status.message)
+ """Sets the response context code and details if the request asks for them"""
+ if request.HasField('response_status'):
+ servicer_context.set_code(request.response_status.code)
+ servicer_context.set_details(request.response_status.message)
+
class TestService(test_pb2.TestServiceServicer):
- def EmptyCall(self, request, context):
- _maybe_echo_metadata(context)
- return empty_pb2.Empty()
+ def EmptyCall(self, request, context):
+ _maybe_echo_metadata(context)
+ return empty_pb2.Empty()
- def UnaryCall(self, request, context):
- _maybe_echo_metadata(context)
- _maybe_echo_status_and_message(request, context)
- return messages_pb2.SimpleResponse(
- payload=messages_pb2.Payload(
+ def UnaryCall(self, request, context):
+ _maybe_echo_metadata(context)
+ _maybe_echo_status_and_message(request, context)
+ return messages_pb2.SimpleResponse(payload=messages_pb2.Payload(
type=messages_pb2.COMPRESSABLE,
body=b'\x00' * request.response_size))
- def StreamingOutputCall(self, request, context):
- _maybe_echo_status_and_message(request, context)
- for response_parameters in request.response_parameters:
- yield messages_pb2.StreamingOutputCallResponse(
- payload=messages_pb2.Payload(
- type=request.response_type,
- body=b'\x00' * response_parameters.size))
-
- def StreamingInputCall(self, request_iterator, context):
- aggregate_size = 0
- for request in request_iterator:
- if request.payload is not None and request.payload.body:
- aggregate_size += len(request.payload.body)
- return messages_pb2.StreamingInputCallResponse(
- aggregated_payload_size=aggregate_size)
-
- def FullDuplexCall(self, request_iterator, context):
- _maybe_echo_metadata(context)
- for request in request_iterator:
- _maybe_echo_status_and_message(request, context)
- for response_parameters in request.response_parameters:
- yield messages_pb2.StreamingOutputCallResponse(
- payload=messages_pb2.Payload(
- type=request.payload.type,
- body=b'\x00' * response_parameters.size))
-
- # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
- # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
- def HalfDuplexCall(self, request_iterator, context):
- return self.FullDuplexCall(request_iterator, context)
+ def StreamingOutputCall(self, request, context):
+ _maybe_echo_status_and_message(request, context)
+ for response_parameters in request.response_parameters:
+ yield messages_pb2.StreamingOutputCallResponse(
+ payload=messages_pb2.Payload(
+ type=request.response_type,
+ body=b'\x00' * response_parameters.size))
+
+ def StreamingInputCall(self, request_iterator, context):
+ aggregate_size = 0
+ for request in request_iterator:
+ if request.payload is not None and request.payload.body:
+ aggregate_size += len(request.payload.body)
+ return messages_pb2.StreamingInputCallResponse(
+ aggregated_payload_size=aggregate_size)
+
+ def FullDuplexCall(self, request_iterator, context):
+ _maybe_echo_metadata(context)
+ for request in request_iterator:
+ _maybe_echo_status_and_message(request, context)
+ for response_parameters in request.response_parameters:
+ yield messages_pb2.StreamingOutputCallResponse(
+ payload=messages_pb2.Payload(
+ type=request.payload.type,
+ body=b'\x00' * response_parameters.size))
+
+ # NOTE(nathaniel): Apparently this is the same as the full-duplex call?
+ # NOTE(atash): It isn't even called in the interop spec (Oct 22 2015)...
+ def HalfDuplexCall(self, request_iterator, context):
+ return self.FullDuplexCall(request_iterator, context)
def _expect_status_code(call, expected_code):
- if call.code() != expected_code:
- raise ValueError(
- 'expected code %s, got %s' % (expected_code, call.code()))
+ if call.code() != expected_code:
+ raise ValueError('expected code %s, got %s' %
+ (expected_code, call.code()))
def _expect_status_details(call, expected_details):
- if call.details() != expected_details:
- raise ValueError(
- 'expected message %s, got %s' % (expected_details, call.details()))
+ if call.details() != expected_details:
+ raise ValueError('expected message %s, got %s' %
+ (expected_details, call.details()))
def _validate_status_code_and_details(call, expected_code, expected_details):
- _expect_status_code(call, expected_code)
- _expect_status_details(call, expected_details)
+ _expect_status_code(call, expected_code)
+ _expect_status_details(call, expected_details)
def _validate_payload_type_and_length(response, expected_type, expected_length):
- if response.payload.type is not expected_type:
- raise ValueError(
- 'expected payload type %s, got %s' %
- (expected_type, type(response.payload.type)))
- elif len(response.payload.body) != expected_length:
- raise ValueError(
- 'expected payload body size %d, got %d' %
- (expected_length, len(response.payload.body)))
-
-
-def _large_unary_common_behavior(
- stub, fill_username, fill_oauth_scope, call_credentials):
- size = 314159
- request = messages_pb2.SimpleRequest(
- response_type=messages_pb2.COMPRESSABLE, response_size=size,
- payload=messages_pb2.Payload(body=b'\x00' * 271828),
- fill_username=fill_username, fill_oauth_scope=fill_oauth_scope)
- response_future = stub.UnaryCall.future(
- request, credentials=call_credentials)
- response = response_future.result()
- _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
- return response
+ if response.payload.type is not expected_type:
+ raise ValueError('expected payload type %s, got %s' %
+ (expected_type, type(response.payload.type)))
+ elif len(response.payload.body) != expected_length:
+ raise ValueError('expected payload body size %d, got %d' %
+ (expected_length, len(response.payload.body)))
+
+
+def _large_unary_common_behavior(stub, fill_username, fill_oauth_scope,
+ call_credentials):
+ size = 314159
+ request = messages_pb2.SimpleRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_size=size,
+ payload=messages_pb2.Payload(body=b'\x00' * 271828),
+ fill_username=fill_username,
+ fill_oauth_scope=fill_oauth_scope)
+ response_future = stub.UnaryCall.future(
+ request, credentials=call_credentials)
+ response = response_future.result()
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE, size)
+ return response
def _empty_unary(stub):
- response = stub.EmptyCall(empty_pb2.Empty())
- if not isinstance(response, empty_pb2.Empty):
- raise TypeError(
- 'response is of type "%s", not empty_pb2.Empty!', type(response))
+ response = stub.EmptyCall(empty_pb2.Empty())
+ if not isinstance(response, empty_pb2.Empty):
+ raise TypeError('response is of type "%s", not empty_pb2.Empty!',
+ type(response))
def _large_unary(stub):
- _large_unary_common_behavior(stub, False, False, None)
+ _large_unary_common_behavior(stub, False, False, None)
def _client_streaming(stub):
- payload_body_sizes = (27182, 8, 1828, 45904,)
- payloads = (
- messages_pb2.Payload(body=b'\x00' * size)
- for size in payload_body_sizes)
- requests = (
- messages_pb2.StreamingInputCallRequest(payload=payload)
- for payload in payloads)
- response = stub.StreamingInputCall(requests)
- if response.aggregated_payload_size != 74922:
- raise ValueError(
- 'incorrect size %d!' % response.aggregated_payload_size)
+ payload_body_sizes = (
+ 27182,
+ 8,
+ 1828,
+ 45904,)
+ payloads = (messages_pb2.Payload(body=b'\x00' * size)
+ for size in payload_body_sizes)
+ requests = (messages_pb2.StreamingInputCallRequest(payload=payload)
+ for payload in payloads)
+ response = stub.StreamingInputCall(requests)
+ if response.aggregated_payload_size != 74922:
+ raise ValueError('incorrect size %d!' %
+ response.aggregated_payload_size)
def _server_streaming(stub):
- sizes = (31415, 9, 2653, 58979,)
-
- request = messages_pb2.StreamingOutputCallRequest(
- response_type=messages_pb2.COMPRESSABLE,
- response_parameters=(
- messages_pb2.ResponseParameters(size=sizes[0]),
- messages_pb2.ResponseParameters(size=sizes[1]),
- messages_pb2.ResponseParameters(size=sizes[2]),
- messages_pb2.ResponseParameters(size=sizes[3]),
- )
- )
- response_iterator = stub.StreamingOutputCall(request)
- for index, response in enumerate(response_iterator):
- _validate_payload_type_and_length(
- response, messages_pb2.COMPRESSABLE, sizes[index])
+ sizes = (
+ 31415,
+ 9,
+ 2653,
+ 58979,)
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(
+ messages_pb2.ResponseParameters(size=sizes[0]),
+ messages_pb2.ResponseParameters(size=sizes[1]),
+ messages_pb2.ResponseParameters(size=sizes[2]),
+ messages_pb2.ResponseParameters(size=sizes[3]),))
+ response_iterator = stub.StreamingOutputCall(request)
+ for index, response in enumerate(response_iterator):
+ _validate_payload_type_and_length(response, messages_pb2.COMPRESSABLE,
+ sizes[index])
class _Pipe(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._values = []
- self._open = True
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._values = []
+ self._open = True
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- return self.next()
+ def __next__(self):
+ return self.next()
- def next(self):
- with self._condition:
- while not self._values and self._open:
- self._condition.wait()
- if self._values:
- return self._values.pop(0)
- else:
- raise StopIteration()
+ def next(self):
+ with self._condition:
+ while not self._values and self._open:
+ self._condition.wait()
+ if self._values:
+ return self._values.pop(0)
+ else:
+ raise StopIteration()
- def add(self, value):
- with self._condition:
- self._values.append(value)
- self._condition.notify()
+ def add(self, value):
+ with self._condition:
+ self._values.append(value)
+ self._condition.notify()
- def close(self):
- with self._condition:
- self._open = False
- self._condition.notify()
+ def close(self):
+ with self._condition:
+ self._open = False
+ self._condition.notify()
- def __enter__(self):
- return self
+ def __enter__(self):
+ return self
- def __exit__(self, type, value, traceback):
- self.close()
+ def __exit__(self, type, value, traceback):
+ self.close()
def _ping_pong(stub):
- request_response_sizes = (31415, 9, 2653, 58979,)
- request_payload_sizes = (27182, 8, 1828, 45904,)
-
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe)
- for response_size, payload_size in zip(
- request_response_sizes, request_payload_sizes):
- request = messages_pb2.StreamingOutputCallRequest(
- response_type=messages_pb2.COMPRESSABLE,
- response_parameters=(
- messages_pb2.ResponseParameters(size=response_size),),
- payload=messages_pb2.Payload(body=b'\x00' * payload_size))
- pipe.add(request)
- response = next(response_iterator)
- _validate_payload_type_and_length(
- response, messages_pb2.COMPRESSABLE, response_size)
+ request_response_sizes = (
+ 31415,
+ 9,
+ 2653,
+ 58979,)
+ request_payload_sizes = (
+ 27182,
+ 8,
+ 1828,
+ 45904,)
+
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe)
+ for response_size, payload_size in zip(request_response_sizes,
+ request_payload_sizes):
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(
+ messages_pb2.ResponseParameters(size=response_size),),
+ payload=messages_pb2.Payload(body=b'\x00' * payload_size))
+ pipe.add(request)
+ response = next(response_iterator)
+ _validate_payload_type_and_length(
+ response, messages_pb2.COMPRESSABLE, response_size)
def _cancel_after_begin(stub):
- with _Pipe() as pipe:
- response_future = stub.StreamingInputCall.future(pipe)
- response_future.cancel()
- if not response_future.cancelled():
- raise ValueError('expected cancelled method to return True')
- if response_future.code() is not grpc.StatusCode.CANCELLED:
- raise ValueError('expected status code CANCELLED')
+ with _Pipe() as pipe:
+ response_future = stub.StreamingInputCall.future(pipe)
+ response_future.cancel()
+ if not response_future.cancelled():
+ raise ValueError('expected cancelled method to return True')
+ if response_future.code() is not grpc.StatusCode.CANCELLED:
+ raise ValueError('expected status code CANCELLED')
def _cancel_after_first_response(stub):
- request_response_sizes = (31415, 9, 2653, 58979,)
- request_payload_sizes = (27182, 8, 1828, 45904,)
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe)
-
- response_size = request_response_sizes[0]
- payload_size = request_payload_sizes[0]
- request = messages_pb2.StreamingOutputCallRequest(
- response_type=messages_pb2.COMPRESSABLE,
- response_parameters=(
- messages_pb2.ResponseParameters(size=response_size),),
- payload=messages_pb2.Payload(body=b'\x00' * payload_size))
- pipe.add(request)
- response = next(response_iterator)
- # We test the contents of `response` in the Ping Pong test - don't check
- # them here.
- response_iterator.cancel()
-
- try:
- next(response_iterator)
- except grpc.RpcError as rpc_error:
- if rpc_error.code() is not grpc.StatusCode.CANCELLED:
- raise
- else:
- raise ValueError('expected call to be cancelled')
+ request_response_sizes = (
+ 31415,
+ 9,
+ 2653,
+ 58979,)
+ request_payload_sizes = (
+ 27182,
+ 8,
+ 1828,
+ 45904,)
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe)
+
+ response_size = request_response_sizes[0]
+ payload_size = request_payload_sizes[0]
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(
+ messages_pb2.ResponseParameters(size=response_size),),
+ payload=messages_pb2.Payload(body=b'\x00' * payload_size))
+ pipe.add(request)
+ response = next(response_iterator)
+ # We test the contents of `response` in the Ping Pong test - don't check
+ # them here.
+ response_iterator.cancel()
+
+ try:
+ next(response_iterator)
+ except grpc.RpcError as rpc_error:
+ if rpc_error.code() is not grpc.StatusCode.CANCELLED:
+ raise
+ else:
+ raise ValueError('expected call to be cancelled')
def _timeout_on_sleeping_server(stub):
- request_payload_size = 27182
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)
-
- request = messages_pb2.StreamingOutputCallRequest(
- response_type=messages_pb2.COMPRESSABLE,
- payload=messages_pb2.Payload(body=b'\x00' * request_payload_size))
- pipe.add(request)
- try:
- next(response_iterator)
- except grpc.RpcError as rpc_error:
- if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
- raise
- else:
- raise ValueError('expected call to exceed deadline')
+ request_payload_size = 27182
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe, timeout=0.001)
+
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ payload=messages_pb2.Payload(body=b'\x00' * request_payload_size))
+ pipe.add(request)
+ try:
+ next(response_iterator)
+ except grpc.RpcError as rpc_error:
+ if rpc_error.code() is not grpc.StatusCode.DEADLINE_EXCEEDED:
+ raise
+ else:
+ raise ValueError('expected call to exceed deadline')
def _empty_stream(stub):
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe)
- pipe.close()
- try:
- next(response_iterator)
- raise ValueError('expected exactly 0 responses')
- except StopIteration:
- pass
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe)
+ pipe.close()
+ try:
+ next(response_iterator)
+ raise ValueError('expected exactly 0 responses')
+ except StopIteration:
+ pass
def _status_code_and_message(stub):
- details = 'test status message'
- code = 2
- status = grpc.StatusCode.UNKNOWN # code = 2
-
- # Test with a UnaryCall
- request = messages_pb2.SimpleRequest(
- response_type=messages_pb2.COMPRESSABLE,
- response_size=1,
- payload=messages_pb2.Payload(body=b'\x00'),
- response_status=messages_pb2.EchoStatus(code=code, message=details)
- )
- response_future = stub.UnaryCall.future(request)
- _validate_status_code_and_details(response_future, status, details)
-
- # Test with a FullDuplexCall
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe)
- request = messages_pb2.StreamingOutputCallRequest(
+ details = 'test status message'
+ code = 2
+ status = grpc.StatusCode.UNKNOWN # code = 2
+
+ # Test with a UnaryCall
+ request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE,
- response_parameters=(
- messages_pb2.ResponseParameters(size=1),),
+ response_size=1,
payload=messages_pb2.Payload(body=b'\x00'),
- response_status=messages_pb2.EchoStatus(code=code, message=details))
- pipe.add(request) # sends the initial request.
- # Dropping out of with block closes the pipe
- _validate_status_code_and_details(response_iterator, status, details)
+ response_status=messages_pb2.EchoStatus(
+ code=code, message=details))
+ response_future = stub.UnaryCall.future(request)
+ _validate_status_code_and_details(response_future, status, details)
+
+ # Test with a FullDuplexCall
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe)
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(messages_pb2.ResponseParameters(size=1),),
+ payload=messages_pb2.Payload(body=b'\x00'),
+ response_status=messages_pb2.EchoStatus(
+ code=code, message=details))
+ pipe.add(request) # sends the initial request.
+ # Dropping out of with block closes the pipe
+ _validate_status_code_and_details(response_iterator, status, details)
def _unimplemented_method(test_service_stub):
- response_future = (
- test_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
- _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
+ response_future = (
+ test_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
+ _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
def _unimplemented_service(unimplemented_service_stub):
- response_future = (
- unimplemented_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
- _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
+ response_future = (
+ unimplemented_service_stub.UnimplementedCall.future(empty_pb2.Empty()))
+ _expect_status_code(response_future, grpc.StatusCode.UNIMPLEMENTED)
def _custom_metadata(stub):
- initial_metadata_value = "test_initial_metadata_value"
- trailing_metadata_value = "\x0a\x0b\x0a\x0b\x0a\x0b"
- metadata = (
- (_INITIAL_METADATA_KEY, initial_metadata_value),
- (_TRAILING_METADATA_KEY, trailing_metadata_value))
-
- def _validate_metadata(response):
- initial_metadata = dict(response.initial_metadata())
- if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
- raise ValueError(
- 'expected initial metadata %s, got %s' % (
- initial_metadata_value, initial_metadata[_INITIAL_METADATA_KEY]))
- trailing_metadata = dict(response.trailing_metadata())
- if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
- raise ValueError(
- 'expected trailing metadata %s, got %s' % (
- trailing_metadata_value, initial_metadata[_TRAILING_METADATA_KEY]))
-
- # Testing with UnaryCall
- request = messages_pb2.SimpleRequest(
- response_type=messages_pb2.COMPRESSABLE,
- response_size=1,
- payload=messages_pb2.Payload(body=b'\x00'))
- response_future = stub.UnaryCall.future(request, metadata=metadata)
- _validate_metadata(response_future)
-
- # Testing with FullDuplexCall
- with _Pipe() as pipe:
- response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
- request = messages_pb2.StreamingOutputCallRequest(
+ initial_metadata_value = "test_initial_metadata_value"
+ trailing_metadata_value = "\x0a\x0b\x0a\x0b\x0a\x0b"
+ metadata = ((_INITIAL_METADATA_KEY, initial_metadata_value),
+ (_TRAILING_METADATA_KEY, trailing_metadata_value))
+
+ def _validate_metadata(response):
+ initial_metadata = dict(response.initial_metadata())
+ if initial_metadata[_INITIAL_METADATA_KEY] != initial_metadata_value:
+ raise ValueError('expected initial metadata %s, got %s' %
+ (initial_metadata_value,
+ initial_metadata[_INITIAL_METADATA_KEY]))
+ trailing_metadata = dict(response.trailing_metadata())
+ if trailing_metadata[_TRAILING_METADATA_KEY] != trailing_metadata_value:
+ raise ValueError('expected trailing metadata %s, got %s' %
+ (trailing_metadata_value,
+ initial_metadata[_TRAILING_METADATA_KEY]))
+
+ # Testing with UnaryCall
+ request = messages_pb2.SimpleRequest(
response_type=messages_pb2.COMPRESSABLE,
- response_parameters=(
- messages_pb2.ResponseParameters(size=1),))
- pipe.add(request) # Sends the request
- next(response_iterator) # Causes server to send trailing metadata
- # Dropping out of the with block closes the pipe
- _validate_metadata(response_iterator)
+ response_size=1,
+ payload=messages_pb2.Payload(body=b'\x00'))
+ response_future = stub.UnaryCall.future(request, metadata=metadata)
+ _validate_metadata(response_future)
+
+ # Testing with FullDuplexCall
+ with _Pipe() as pipe:
+ response_iterator = stub.FullDuplexCall(pipe, metadata=metadata)
+ request = messages_pb2.StreamingOutputCallRequest(
+ response_type=messages_pb2.COMPRESSABLE,
+ response_parameters=(messages_pb2.ResponseParameters(size=1),))
+ pipe.add(request) # Sends the request
+ next(response_iterator) # Causes server to send trailing metadata
+ # Dropping out of the with block closes the pipe
+ _validate_metadata(response_iterator)
+
def _compute_engine_creds(stub, args):
- response = _large_unary_common_behavior(stub, True, True, None)
- if args.default_service_account != response.username:
- raise ValueError(
- 'expected username %s, got %s' % (
- args.default_service_account, response.username))
+ response = _large_unary_common_behavior(stub, True, True, None)
+ if args.default_service_account != response.username:
+ raise ValueError('expected username %s, got %s' %
+ (args.default_service_account, response.username))
def _oauth2_auth_token(stub, args):
- json_key_filename = os.environ[
- oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
- wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
- response = _large_unary_common_behavior(stub, True, True, None)
- if wanted_email != response.username:
- raise ValueError(
- 'expected username %s, got %s' % (wanted_email, response.username))
- if args.oauth_scope.find(response.oauth_scope) == -1:
- raise ValueError(
- 'expected to find oauth scope "{}" in received "{}"'.format(
- response.oauth_scope, args.oauth_scope))
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ response = _large_unary_common_behavior(stub, True, True, None)
+ if wanted_email != response.username:
+ raise ValueError('expected username %s, got %s' %
+ (wanted_email, response.username))
+ if args.oauth_scope.find(response.oauth_scope) == -1:
+ raise ValueError('expected to find oauth scope "{}" in received "{}"'.
+ format(response.oauth_scope, args.oauth_scope))
def _jwt_token_creds(stub, args):
- json_key_filename = os.environ[
- oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
- wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
- response = _large_unary_common_behavior(stub, True, False, None)
- if wanted_email != response.username:
- raise ValueError(
- 'expected username %s, got %s' % (wanted_email, response.username))
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ response = _large_unary_common_behavior(stub, True, False, None)
+ if wanted_email != response.username:
+ raise ValueError('expected username %s, got %s' %
+ (wanted_email, response.username))
def _per_rpc_creds(stub, args):
- json_key_filename = os.environ[
- oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
- wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
- credentials = oauth2client_client.GoogleCredentials.get_application_default()
- scoped_credentials = credentials.create_scoped([args.oauth_scope])
- # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
- # remaining use of the Beta API.
- call_credentials = implementations.google_call_credentials(
- scoped_credentials)
- response = _large_unary_common_behavior(stub, True, False, call_credentials)
- if wanted_email != response.username:
- raise ValueError(
- 'expected username %s, got %s' % (wanted_email, response.username))
+ json_key_filename = os.environ[
+ oauth2client_client.GOOGLE_APPLICATION_CREDENTIALS]
+ wanted_email = json.load(open(json_key_filename, 'rb'))['client_email']
+ credentials = oauth2client_client.GoogleCredentials.get_application_default(
+ )
+ scoped_credentials = credentials.create_scoped([args.oauth_scope])
+ # TODO(https://github.com/grpc/grpc/issues/6799): Eliminate this last
+ # remaining use of the Beta API.
+ call_credentials = implementations.google_call_credentials(
+ scoped_credentials)
+ response = _large_unary_common_behavior(stub, True, False, call_credentials)
+ if wanted_email != response.username:
+ raise ValueError('expected username %s, got %s' %
+ (wanted_email, response.username))
@enum.unique
class TestCase(enum.Enum):
- EMPTY_UNARY = 'empty_unary'
- LARGE_UNARY = 'large_unary'
- SERVER_STREAMING = 'server_streaming'
- CLIENT_STREAMING = 'client_streaming'
- PING_PONG = 'ping_pong'
- CANCEL_AFTER_BEGIN = 'cancel_after_begin'
- CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
- EMPTY_STREAM = 'empty_stream'
- STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
- UNIMPLEMENTED_METHOD = 'unimplemented_method'
- UNIMPLEMENTED_SERVICE = 'unimplemented_service'
- CUSTOM_METADATA = "custom_metadata"
- COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
- OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
- JWT_TOKEN_CREDS = 'jwt_token_creds'
- PER_RPC_CREDS = 'per_rpc_creds'
- TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
-
- def test_interoperability(self, stub, args):
- if self is TestCase.EMPTY_UNARY:
- _empty_unary(stub)
- elif self is TestCase.LARGE_UNARY:
- _large_unary(stub)
- elif self is TestCase.SERVER_STREAMING:
- _server_streaming(stub)
- elif self is TestCase.CLIENT_STREAMING:
- _client_streaming(stub)
- elif self is TestCase.PING_PONG:
- _ping_pong(stub)
- elif self is TestCase.CANCEL_AFTER_BEGIN:
- _cancel_after_begin(stub)
- elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
- _cancel_after_first_response(stub)
- elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
- _timeout_on_sleeping_server(stub)
- elif self is TestCase.EMPTY_STREAM:
- _empty_stream(stub)
- elif self is TestCase.STATUS_CODE_AND_MESSAGE:
- _status_code_and_message(stub)
- elif self is TestCase.UNIMPLEMENTED_METHOD:
- _unimplemented_method(stub)
- elif self is TestCase.UNIMPLEMENTED_SERVICE:
- _unimplemented_service(stub)
- elif self is TestCase.CUSTOM_METADATA:
- _custom_metadata(stub)
- elif self is TestCase.COMPUTE_ENGINE_CREDS:
- _compute_engine_creds(stub, args)
- elif self is TestCase.OAUTH2_AUTH_TOKEN:
- _oauth2_auth_token(stub, args)
- elif self is TestCase.JWT_TOKEN_CREDS:
- _jwt_token_creds(stub, args)
- elif self is TestCase.PER_RPC_CREDS:
- _per_rpc_creds(stub, args)
- else:
- raise NotImplementedError('Test case "%s" not implemented!' % self.name)
+ EMPTY_UNARY = 'empty_unary'
+ LARGE_UNARY = 'large_unary'
+ SERVER_STREAMING = 'server_streaming'
+ CLIENT_STREAMING = 'client_streaming'
+ PING_PONG = 'ping_pong'
+ CANCEL_AFTER_BEGIN = 'cancel_after_begin'
+ CANCEL_AFTER_FIRST_RESPONSE = 'cancel_after_first_response'
+ EMPTY_STREAM = 'empty_stream'
+ STATUS_CODE_AND_MESSAGE = 'status_code_and_message'
+ UNIMPLEMENTED_METHOD = 'unimplemented_method'
+ UNIMPLEMENTED_SERVICE = 'unimplemented_service'
+    CUSTOM_METADATA = 'custom_metadata'
+ COMPUTE_ENGINE_CREDS = 'compute_engine_creds'
+ OAUTH2_AUTH_TOKEN = 'oauth2_auth_token'
+ JWT_TOKEN_CREDS = 'jwt_token_creds'
+ PER_RPC_CREDS = 'per_rpc_creds'
+ TIMEOUT_ON_SLEEPING_SERVER = 'timeout_on_sleeping_server'
+
+ def test_interoperability(self, stub, args):
+ if self is TestCase.EMPTY_UNARY:
+ _empty_unary(stub)
+ elif self is TestCase.LARGE_UNARY:
+ _large_unary(stub)
+ elif self is TestCase.SERVER_STREAMING:
+ _server_streaming(stub)
+ elif self is TestCase.CLIENT_STREAMING:
+ _client_streaming(stub)
+ elif self is TestCase.PING_PONG:
+ _ping_pong(stub)
+ elif self is TestCase.CANCEL_AFTER_BEGIN:
+ _cancel_after_begin(stub)
+ elif self is TestCase.CANCEL_AFTER_FIRST_RESPONSE:
+ _cancel_after_first_response(stub)
+ elif self is TestCase.TIMEOUT_ON_SLEEPING_SERVER:
+ _timeout_on_sleeping_server(stub)
+ elif self is TestCase.EMPTY_STREAM:
+ _empty_stream(stub)
+ elif self is TestCase.STATUS_CODE_AND_MESSAGE:
+ _status_code_and_message(stub)
+ elif self is TestCase.UNIMPLEMENTED_METHOD:
+ _unimplemented_method(stub)
+ elif self is TestCase.UNIMPLEMENTED_SERVICE:
+ _unimplemented_service(stub)
+ elif self is TestCase.CUSTOM_METADATA:
+ _custom_metadata(stub)
+ elif self is TestCase.COMPUTE_ENGINE_CREDS:
+ _compute_engine_creds(stub, args)
+ elif self is TestCase.OAUTH2_AUTH_TOKEN:
+ _oauth2_auth_token(stub, args)
+ elif self is TestCase.JWT_TOKEN_CREDS:
+ _jwt_token_creds(stub, args)
+ elif self is TestCase.PER_RPC_CREDS:
+ _per_rpc_creds(stub, args)
+ else:
+ raise NotImplementedError('Test case "%s" not implemented!' %
+ self.name)
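The reformatted methods.py above keeps the interop dispatch as an enum whose members call module-level helpers. A minimal sketch of how a caller could drive that dispatch by wire name, assuming it runs alongside the TestCase enum above and that `stub` and `args` come from the interop client's own channel construction and flag parsing (both hypothetical here):

    # Sketch only: TestCase is the enum defined above; `stub` and `args`
    # are assumed to be built elsewhere (e.g. by the interop client).
    def run_named_case(name, stub, args):
        try:
            case = TestCase(name)  # by-value lookup, e.g. 'empty_unary'
        except ValueError:
            raise ValueError('unknown interop test case: %s' % name)
        case.test_interoperability(stub, args)

    all_case_names = [case.value for case in TestCase]  # e.g. for help text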
diff --git a/src/python/grpcio_tests/tests/interop/resources.py b/src/python/grpcio_tests/tests/interop/resources.py
index c424385cf6..2ec2eb92b4 100644
--- a/src/python/grpcio_tests/tests/interop/resources.py
+++ b/src/python/grpcio_tests/tests/interop/resources.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Constants and functions for data used in interoperability testing."""
import argparse
@@ -40,22 +39,22 @@ _CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem'
def test_root_certificates():
- return pkg_resources.resource_string(
- __name__, _ROOT_CERTIFICATES_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__,
+ _ROOT_CERTIFICATES_RESOURCE_PATH)
def private_key():
- return pkg_resources.resource_string(__name__, _PRIVATE_KEY_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__, _PRIVATE_KEY_RESOURCE_PATH)
def certificate_chain():
- return pkg_resources.resource_string(
- __name__, _CERTIFICATE_CHAIN_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__,
+ _CERTIFICATE_CHAIN_RESOURCE_PATH)
def parse_bool(value):
- if value == 'true':
- return True
- if value == 'false':
- return False
- raise argparse.ArgumentTypeError('Only true/false allowed')
+ if value == 'true':
+ return True
+ if value == 'false':
+ return False
+ raise argparse.ArgumentTypeError('Only true/false allowed')
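resources.parse_bool above is meant to be passed as an argparse `type=` callable (server.py below does exactly that), so only the literal strings 'true' and 'false' are accepted. A self-contained sketch of the same pattern, re-declaring an equivalent parser rather than importing the test package:

    import argparse

    def parse_bool(value):
        # Mirrors resources.parse_bool above: only the exact strings pass.
        if value == 'true':
            return True
        if value == 'false':
            return False
        raise argparse.ArgumentTypeError('Only true/false allowed')

    parser = argparse.ArgumentParser()
    parser.add_argument('--use_tls', default=False, type=parse_bool)
    print(parser.parse_args(['--use_tls', 'true']).use_tls)   # True
    print(parser.parse_args(['--use_tls', 'false']).use_tls)  # False
    # '--use_tls yes' would fail with the usage error raised above.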
diff --git a/src/python/grpcio_tests/tests/interop/server.py b/src/python/grpcio_tests/tests/interop/server.py
index 1ae83bc57d..65f1604eb8 100644
--- a/src/python/grpcio_tests/tests/interop/server.py
+++ b/src/python/grpcio_tests/tests/interop/server.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""The Python implementation of the GRPC interoperability test server."""
import argparse
@@ -44,34 +43,36 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
def serve():
- parser = argparse.ArgumentParser()
- parser.add_argument(
- '--port', help='the port on which to serve', type=int)
- parser.add_argument(
- '--use_tls', help='require a secure connection',
- default=False, type=resources.parse_bool)
- args = parser.parse_args()
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--port', help='the port on which to serve', type=int)
+ parser.add_argument(
+ '--use_tls',
+ help='require a secure connection',
+ default=False,
+ type=resources.parse_bool)
+ args = parser.parse_args()
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
+ test_pb2.add_TestServiceServicer_to_server(methods.TestService(), server)
+ if args.use_tls:
+ private_key = resources.private_key()
+ certificate_chain = resources.certificate_chain()
+ credentials = grpc.ssl_server_credentials((
+ (private_key, certificate_chain),))
+ server.add_secure_port('[::]:{}'.format(args.port), credentials)
+ else:
+ server.add_insecure_port('[::]:{}'.format(args.port))
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
- test_pb2.add_TestServiceServicer_to_server(methods.TestService(), server)
- if args.use_tls:
- private_key = resources.private_key()
- certificate_chain = resources.certificate_chain()
- credentials = grpc.ssl_server_credentials(
- ((private_key, certificate_chain),))
- server.add_secure_port('[::]:{}'.format(args.port), credentials)
- else:
- server.add_insecure_port('[::]:{}'.format(args.port))
+ server.start()
+ logging.info('Server serving.')
+ try:
+ while True:
+ time.sleep(_ONE_DAY_IN_SECONDS)
+ except BaseException as e:
+ logging.info('Caught exception "%s"; stopping server...', e)
+ server.stop(None)
+ logging.info('Server stopped; exiting.')
- server.start()
- logging.info('Server serving.')
- try:
- while True:
- time.sleep(_ONE_DAY_IN_SECONDS)
- except BaseException as e:
- logging.info('Caught exception "%s"; stopping server...', e)
- server.stop(None)
- logging.info('Server stopped; exiting.')
if __name__ == '__main__':
- serve()
+ serve()
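serve() above builds either a TLS or a plaintext listener from the same parsed flag. A minimal standalone sketch of the TLS branch, assuming `key_bytes` and `cert_bytes` are PEM-encoded bytes supplied by the caller (the real code reads them from package resources):

    from concurrent import futures

    import grpc

    def start_tls_server(port, key_bytes, cert_bytes):
        # key_bytes / cert_bytes are assumed PEM contents, analogous to
        # resources.private_key() and resources.certificate_chain() above.
        server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        credentials = grpc.ssl_server_credentials(((key_bytes, cert_bytes),))
        server.add_secure_port('[::]:{}'.format(port), credentials)
        server.start()
        return server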
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
index 7ca2bcff38..ae5da2c3db 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/_python_plugin_test.py
@@ -58,436 +58,440 @@ ADD_SERVICER_TO_SERVER_IDENTIFIER = 'add_TestServiceServicer_to_server'
class _ServicerMethods(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._paused = False
- self._fail = False
-
- @contextlib.contextmanager
- def pause(self): # pylint: disable=invalid-name
- with self._condition:
- self._paused = True
- yield
- with self._condition:
- self._paused = False
- self._condition.notify_all()
-
- @contextlib.contextmanager
- def fail(self): # pylint: disable=invalid-name
- with self._condition:
- self._fail = True
- yield
- with self._condition:
- self._fail = False
-
- def _control(self): # pylint: disable=invalid-name
- with self._condition:
- if self._fail:
- raise ValueError()
- while self._paused:
- self._condition.wait()
-
- def UnaryCall(self, request, unused_rpc_context):
- response = response_pb2.SimpleResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * request.response_size
- self._control()
- return response
-
- def StreamingOutputCall(self, request, unused_rpc_context):
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
- self._control()
- yield response
-
- def StreamingInputCall(self, request_iter, unused_rpc_context):
- response = response_pb2.StreamingInputCallResponse()
- aggregated_payload_size = 0
- for request in request_iter:
- aggregated_payload_size += len(request.payload.payload_compressable)
- response.aggregated_payload_size = aggregated_payload_size
- self._control()
- return response
-
- def FullDuplexCall(self, request_iter, unused_rpc_context):
- for request in request_iter:
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
- self._control()
- yield response
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._paused = False
+ self._fail = False
+
+ @contextlib.contextmanager
+ def pause(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._paused = True
+ yield
+ with self._condition:
+ self._paused = False
+ self._condition.notify_all()
- def HalfDuplexCall(self, request_iter, unused_rpc_context):
- responses = []
- for request in request_iter:
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
+ @contextlib.contextmanager
+ def fail(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._fail = True
+ yield
+ with self._condition:
+ self._fail = False
+
+ def _control(self): # pylint: disable=invalid-name
+ with self._condition:
+ if self._fail:
+ raise ValueError()
+ while self._paused:
+ self._condition.wait()
+
+ def UnaryCall(self, request, unused_rpc_context):
+ response = response_pb2.SimpleResponse()
response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
+ response.payload.payload_compressable = 'a' * request.response_size
+ self._control()
+ return response
+
+ def StreamingOutputCall(self, request, unused_rpc_context):
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def StreamingInputCall(self, request_iter, unused_rpc_context):
+ response = response_pb2.StreamingInputCallResponse()
+ aggregated_payload_size = 0
+ for request in request_iter:
+ aggregated_payload_size += len(request.payload.payload_compressable)
+ response.aggregated_payload_size = aggregated_payload_size
self._control()
- responses.append(response)
- for response in responses:
- yield response
+ return response
+
+ def FullDuplexCall(self, request_iter, unused_rpc_context):
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def HalfDuplexCall(self, request_iter, unused_rpc_context):
+ responses = []
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ responses.append(response)
+ for response in responses:
+ yield response
class _Service(
- collections.namedtuple(
- '_Service', ('servicer_methods', 'server', 'stub',))):
- """A live and running service.
+ collections.namedtuple('_Service', (
+ 'servicer_methods',
+ 'server',
+ 'stub',))):
+ """A live and running service.
Attributes:
servicer_methods: The _ServicerMethods servicing RPCs.
server: The grpc.Server servicing RPCs.
stub: A stub on which to invoke RPCs.
"""
-
+
def _CreateService():
- """Provides a servicer backend and a stub.
+ """Provides a servicer backend and a stub.
Returns:
A _Service with which to test RPCs.
"""
- servicer_methods = _ServicerMethods()
+ servicer_methods = _ServicerMethods()
- class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
- def UnaryCall(self, request, context):
- return servicer_methods.UnaryCall(request, context)
+ def UnaryCall(self, request, context):
+ return servicer_methods.UnaryCall(request, context)
- def StreamingOutputCall(self, request, context):
- return servicer_methods.StreamingOutputCall(request, context)
+ def StreamingOutputCall(self, request, context):
+ return servicer_methods.StreamingOutputCall(request, context)
- def StreamingInputCall(self, request_iter, context):
- return servicer_methods.StreamingInputCall(request_iter, context)
+ def StreamingInputCall(self, request_iter, context):
+ return servicer_methods.StreamingInputCall(request_iter, context)
- def FullDuplexCall(self, request_iter, context):
- return servicer_methods.FullDuplexCall(request_iter, context)
+ def FullDuplexCall(self, request_iter, context):
+ return servicer_methods.FullDuplexCall(request_iter, context)
- def HalfDuplexCall(self, request_iter, context):
- return servicer_methods.HalfDuplexCall(request_iter, context)
+ def HalfDuplexCall(self, request_iter, context):
+ return servicer_methods.HalfDuplexCall(request_iter, context)
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
- getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
- return _Service(servicer_methods, server, stub)
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(servicer_methods, server, stub)
def _CreateIncompleteService():
- """Provides a servicer backend that fails to implement methods and its stub.
+ """Provides a servicer backend that fails to implement methods and its stub.
Returns:
A _Service with which to test RPCs. The returned _Service's
servicer_methods implements none of the methods required of it.
"""
- class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
- pass
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ pass
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
- getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
- return _Service(None, server, stub)
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER)(Servicer(), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = getattr(service_pb2, STUB_IDENTIFIER)(channel)
+ return _Service(None, server, stub)
def _streaming_input_request_iterator():
- for _ in range(3):
- request = request_pb2.StreamingInputCallRequest()
- request.payload.payload_type = payload_pb2.COMPRESSABLE
- request.payload.payload_compressable = 'a'
- yield request
+ for _ in range(3):
+ request = request_pb2.StreamingInputCallRequest()
+ request.payload.payload_type = payload_pb2.COMPRESSABLE
+ request.payload.payload_compressable = 'a'
+ yield request
def _streaming_output_request():
- request = request_pb2.StreamingOutputCallRequest()
- sizes = [1, 2, 3]
- request.response_parameters.add(size=sizes[0], interval_us=0)
- request.response_parameters.add(size=sizes[1], interval_us=0)
- request.response_parameters.add(size=sizes[2], interval_us=0)
- return request
+ request = request_pb2.StreamingOutputCallRequest()
+ sizes = [1, 2, 3]
+ request.response_parameters.add(size=sizes[0], interval_us=0)
+ request.response_parameters.add(size=sizes[1], interval_us=0)
+ request.response_parameters.add(size=sizes[2], interval_us=0)
+ return request
def _full_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=2, interval_us=0)
- request.response_parameters.add(size=3, interval_us=0)
- yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
class PythonPluginTest(unittest.TestCase):
- """Test case for the gRPC Python protoc-plugin.
+ """Test case for the gRPC Python protoc-plugin.
While reading these tests, remember that the futures API
(`stub.method.future()`) only gives futures for the *response-unary*
methods and does not exist for response-streaming methods.
"""
- def testImportAttributes(self):
- # check that we can access the generated module and its members.
- self.assertIsNotNone(
- getattr(service_pb2, STUB_IDENTIFIER, None))
- self.assertIsNotNone(
- getattr(service_pb2, SERVICER_IDENTIFIER, None))
- self.assertIsNotNone(
- getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))
-
- def testUpDown(self):
- service = _CreateService()
- self.assertIsNotNone(service.servicer_methods)
- self.assertIsNotNone(service.server)
- self.assertIsNotNone(service.stub)
-
- def testIncompleteServicer(self):
- service = _CreateIncompleteService()
- request = request_pb2.SimpleRequest(response_size=13)
- with self.assertRaises(grpc.RpcError) as exception_context:
- service.stub.UnaryCall(request)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
-
- def testUnaryCall(self):
- service = _CreateService()
- request = request_pb2.SimpleRequest(response_size=13)
- response = service.stub.UnaryCall(request)
- expected_response = service.servicer_methods.UnaryCall(
- request, 'not a real context!')
- self.assertEqual(expected_response, response)
-
- def testUnaryCallFuture(self):
- service = _CreateService()
- request = request_pb2.SimpleRequest(response_size=13)
- # Check that the call does not block waiting for the server to respond.
- with service.servicer_methods.pause():
- response_future = service.stub.UnaryCall.future(request)
- response = response_future.result()
- expected_response = service.servicer_methods.UnaryCall(
- request, 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testUnaryCallFutureExpired(self):
- service = _CreateService()
- request = request_pb2.SimpleRequest(response_size=13)
- with service.servicer_methods.pause():
- response_future = service.stub.UnaryCall.future(
- request, timeout=test_constants.SHORT_TIMEOUT)
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
- self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
-
- def testUnaryCallFutureCancelled(self):
- service = _CreateService()
- request = request_pb2.SimpleRequest(response_size=13)
- with service.servicer_methods.pause():
- response_future = service.stub.UnaryCall.future(request)
- response_future.cancel()
- self.assertTrue(response_future.cancelled())
- self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)
-
- def testUnaryCallFutureFailed(self):
- service = _CreateService()
- request = request_pb2.SimpleRequest(response_size=13)
- with service.servicer_methods.fail():
- response_future = service.stub.UnaryCall.future(request)
- self.assertIsNotNone(response_future.exception())
- self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
-
- def testStreamingOutputCall(self):
- service = _CreateService()
- request = _streaming_output_request()
- responses = service.stub.StreamingOutputCall(request)
- expected_responses = service.servicer_methods.StreamingOutputCall(
- request, 'not a real RpcContext!')
- for expected_response, response in moves.zip_longest(
- expected_responses, responses):
- self.assertEqual(expected_response, response)
-
- def testStreamingOutputCallExpired(self):
- service = _CreateService()
- request = _streaming_output_request()
- with service.servicer_methods.pause():
- responses = service.stub.StreamingOutputCall(
- request, timeout=test_constants.SHORT_TIMEOUT)
- with self.assertRaises(grpc.RpcError) as exception_context:
- list(responses)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
-
- def testStreamingOutputCallCancelled(self):
- service = _CreateService()
- request = _streaming_output_request()
- responses = service.stub.StreamingOutputCall(request)
- next(responses)
- responses.cancel()
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(responses)
- self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)
-
- def testStreamingOutputCallFailed(self):
- service = _CreateService()
- request = _streaming_output_request()
- with service.servicer_methods.fail():
- responses = service.stub.StreamingOutputCall(request)
- self.assertIsNotNone(responses)
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(responses)
- self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
-
- def testStreamingInputCall(self):
- service = _CreateService()
- response = service.stub.StreamingInputCall(
- _streaming_input_request_iterator())
- expected_response = service.servicer_methods.StreamingInputCall(
- _streaming_input_request_iterator(),
- 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testStreamingInputCallFuture(self):
- service = _CreateService()
- with service.servicer_methods.pause():
- response_future = service.stub.StreamingInputCall.future(
- _streaming_input_request_iterator())
- response = response_future.result()
- expected_response = service.servicer_methods.StreamingInputCall(
- _streaming_input_request_iterator(),
- 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testStreamingInputCallFutureExpired(self):
- service = _CreateService()
- with service.servicer_methods.pause():
- response_future = service.stub.StreamingInputCall.future(
- _streaming_input_request_iterator(),
- timeout=test_constants.SHORT_TIMEOUT)
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIsInstance(response_future.exception(), grpc.RpcError)
- self.assertIs(
- response_future.exception().code(), grpc.StatusCode.DEADLINE_EXCEEDED)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
-
- def testStreamingInputCallFutureCancelled(self):
- service = _CreateService()
- with service.servicer_methods.pause():
- response_future = service.stub.StreamingInputCall.future(
- _streaming_input_request_iterator())
- response_future.cancel()
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.result()
-
- def testStreamingInputCallFutureFailed(self):
- service = _CreateService()
- with service.servicer_methods.fail():
- response_future = service.stub.StreamingInputCall.future(
- _streaming_input_request_iterator())
- self.assertIsNotNone(response_future.exception())
- self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
-
- def testFullDuplexCall(self):
- service = _CreateService()
- responses = service.stub.FullDuplexCall(
- _full_duplex_request_iterator())
- expected_responses = service.servicer_methods.FullDuplexCall(
- _full_duplex_request_iterator(),
- 'not a real RpcContext!')
- for expected_response, response in moves.zip_longest(
- expected_responses, responses):
- self.assertEqual(expected_response, response)
-
- def testFullDuplexCallExpired(self):
- request_iterator = _full_duplex_request_iterator()
- service = _CreateService()
- with service.servicer_methods.pause():
- responses = service.stub.FullDuplexCall(
- request_iterator, timeout=test_constants.SHORT_TIMEOUT)
- with self.assertRaises(grpc.RpcError) as exception_context:
- list(responses)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
-
- def testFullDuplexCallCancelled(self):
- service = _CreateService()
- request_iterator = _full_duplex_request_iterator()
- responses = service.stub.FullDuplexCall(request_iterator)
- next(responses)
- responses.cancel()
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(responses)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.CANCELLED)
-
- def testFullDuplexCallFailed(self):
- request_iterator = _full_duplex_request_iterator()
- service = _CreateService()
- with service.servicer_methods.fail():
- responses = service.stub.FullDuplexCall(request_iterator)
- with self.assertRaises(grpc.RpcError) as exception_context:
+ def testImportAttributes(self):
+ # check that we can access the generated module and its members.
+ self.assertIsNotNone(getattr(service_pb2, STUB_IDENTIFIER, None))
+ self.assertIsNotNone(getattr(service_pb2, SERVICER_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(service_pb2, ADD_SERVICER_TO_SERVER_IDENTIFIER, None))
+
+ def testUpDown(self):
+ service = _CreateService()
+ self.assertIsNotNone(service.servicer_methods)
+ self.assertIsNotNone(service.server)
+ self.assertIsNotNone(service.stub)
+
+ def testIncompleteServicer(self):
+ service = _CreateIncompleteService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ service.stub.UnaryCall(request)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.UNIMPLEMENTED)
+
+ def testUnaryCall(self):
+ service = _CreateService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ response = service.stub.UnaryCall(request)
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real context!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFuture(self):
+ service = _CreateService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ # Check that the call does not block waiting for the server to respond.
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response = response_future.result()
+ expected_response = service.servicer_methods.UnaryCall(
+ request, 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testUnaryCallFutureExpired(self):
+ service = _CreateService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(response_future.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testUnaryCallFutureCancelled(self):
+ service = _CreateService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.pause():
+ response_future = service.stub.UnaryCall.future(request)
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ self.assertIs(response_future.code(), grpc.StatusCode.CANCELLED)
+
+ def testUnaryCallFutureFailed(self):
+ service = _CreateService()
+ request = request_pb2.SimpleRequest(response_size=13)
+ with service.servicer_methods.fail():
+ response_future = service.stub.UnaryCall.future(request)
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testStreamingOutputCall(self):
+ service = _CreateService()
+ request = _streaming_output_request()
+ responses = service.stub.StreamingOutputCall(request)
+ expected_responses = service.servicer_methods.StreamingOutputCall(
+ request, 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(expected_responses,
+ responses):
+ self.assertEqual(expected_response, response)
+
+ def testStreamingOutputCallExpired(self):
+ service = _CreateService()
+ request = _streaming_output_request()
+ with service.servicer_methods.pause():
+ responses = service.stub.StreamingOutputCall(
+ request, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingOutputCallCancelled(self):
+ service = _CreateService()
+ request = _streaming_output_request()
+ responses = service.stub.StreamingOutputCall(request)
next(responses)
- self.assertIs(exception_context.exception.code(), grpc.StatusCode.UNKNOWN)
-
- def testHalfDuplexCall(self):
- service = _CreateService()
- def half_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=2, interval_us=0)
- request.response_parameters.add(size=3, interval_us=0)
- yield request
- responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
- expected_responses = service.servicer_methods.HalfDuplexCall(
- half_duplex_request_iterator(), 'not a real RpcContext!')
- for expected_response, response in moves.zip_longest(
- expected_responses, responses):
- self.assertEqual(expected_response, response)
-
- def testHalfDuplexCallWedged(self):
- condition = threading.Condition()
- wait_cell = [False]
- @contextlib.contextmanager
- def wait(): # pylint: disable=invalid-name
- # Where's Python 3's 'nonlocal' statement when you need it?
- with condition:
- wait_cell[0] = True
- yield
- with condition:
- wait_cell[0] = False
- condition.notify_all()
- def half_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- with condition:
- while wait_cell[0]:
- condition.wait()
- service = _CreateService()
- with wait():
- responses = service.stub.HalfDuplexCall(
- half_duplex_request_iterator(), timeout=test_constants.SHORT_TIMEOUT)
- # half-duplex waits for the client to send all info
- with self.assertRaises(grpc.RpcError) as exception_context:
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(responses.code(), grpc.StatusCode.CANCELLED)
+
+ def testStreamingOutputCallFailed(self):
+ service = _CreateService()
+ request = _streaming_output_request()
+ with service.servicer_methods.fail():
+ responses = service.stub.StreamingOutputCall(request)
+ self.assertIsNotNone(responses)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.UNKNOWN)
+
+ def testStreamingInputCall(self):
+ service = _CreateService()
+ response = service.stub.StreamingInputCall(
+ _streaming_input_request_iterator())
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(), 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFuture(self):
+ service = _CreateService()
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator())
+ response = response_future.result()
+ expected_response = service.servicer_methods.StreamingInputCall(
+ _streaming_input_request_iterator(), 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFutureExpired(self):
+ service = _CreateService()
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(),
+ timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIs(response_future.exception().code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testStreamingInputCallFutureCancelled(self):
+ service = _CreateService()
+ with service.servicer_methods.pause():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator())
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+
+ def testStreamingInputCallFutureFailed(self):
+ service = _CreateService()
+ with service.servicer_methods.fail():
+ response_future = service.stub.StreamingInputCall.future(
+ _streaming_input_request_iterator())
+ self.assertIsNotNone(response_future.exception())
+ self.assertIs(response_future.code(), grpc.StatusCode.UNKNOWN)
+
+ def testFullDuplexCall(self):
+ service = _CreateService()
+ responses = service.stub.FullDuplexCall(_full_duplex_request_iterator())
+ expected_responses = service.servicer_methods.FullDuplexCall(
+ _full_duplex_request_iterator(), 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(expected_responses,
+ responses):
+ self.assertEqual(expected_response, response)
+
+ def testFullDuplexCallExpired(self):
+ request_iterator = _full_duplex_request_iterator()
+ service = _CreateService()
+ with service.servicer_methods.pause():
+ responses = service.stub.FullDuplexCall(
+ request_iterator, timeout=test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ list(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
+
+ def testFullDuplexCallCancelled(self):
+ service = _CreateService()
+ request_iterator = _full_duplex_request_iterator()
+ responses = service.stub.FullDuplexCall(request_iterator)
next(responses)
- self.assertIs(
- exception_context.exception.code(), grpc.StatusCode.DEADLINE_EXCEEDED)
+ responses.cancel()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.CANCELLED)
+
+ def testFullDuplexCallFailed(self):
+ request_iterator = _full_duplex_request_iterator()
+ service = _CreateService()
+ with service.servicer_methods.fail():
+ responses = service.stub.FullDuplexCall(request_iterator)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.UNKNOWN)
+
+ def testHalfDuplexCall(self):
+ service = _CreateService()
+
+ def half_duplex_request_iterator():
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+
+ responses = service.stub.HalfDuplexCall(half_duplex_request_iterator())
+ expected_responses = service.servicer_methods.HalfDuplexCall(
+ half_duplex_request_iterator(), 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(expected_responses,
+ responses):
+ self.assertEqual(expected_response, response)
+
+ def testHalfDuplexCallWedged(self):
+ condition = threading.Condition()
+ wait_cell = [False]
+
+ @contextlib.contextmanager
+ def wait(): # pylint: disable=invalid-name
+ # Where's Python 3's 'nonlocal' statement when you need it?
+ with condition:
+ wait_cell[0] = True
+ yield
+ with condition:
+ wait_cell[0] = False
+ condition.notify_all()
+
+ def half_duplex_request_iterator():
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ with condition:
+ while wait_cell[0]:
+ condition.wait()
+
+ service = _CreateService()
+ with wait():
+ responses = service.stub.HalfDuplexCall(
+ half_duplex_request_iterator(),
+ timeout=test_constants.SHORT_TIMEOUT)
+ # half-duplex waits for the client to send all info
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(responses)
+ self.assertIs(exception_context.exception.code(),
+ grpc.StatusCode.DEADLINE_EXCEEDED)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
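Several of the tests above combine the servicer's pause() context manager with the stub's `.future()` call form to assert deadline behavior deterministically. A minimal sketch of that pattern against any response-unary method, assuming `pause` behaves like _ServicerMethods.pause() above and `stub_method` is a response-unary multi-callable such as service.stub.UnaryCall:

    import grpc

    def assert_deadline_exceeded(test_case, stub_method, request, pause, timeout):
        # Wedge the servicer, issue the call with a short deadline, then
        # check both the raised error and the future's reported code.
        with pause():
            response_future = stub_method.future(request, timeout=timeout)
            with test_case.assertRaises(grpc.RpcError) as exception_context:
                response_future.result()
        test_case.assertIs(exception_context.exception.code(),
                           grpc.StatusCode.DEADLINE_EXCEEDED)
        test_case.assertIs(response_future.code(),
                           grpc.StatusCode.DEADLINE_EXCEEDED)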
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
index f8ae05bb7a..bcc01f3978 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/_split_definitions_test.py
@@ -49,256 +49,264 @@ from tests.unit.framework.common import test_constants
_MESSAGES_IMPORT = b'import "messages.proto";'
+
@contextlib.contextmanager
def _system_path(path):
- old_system_path = sys.path[:]
- sys.path = sys.path[0:1] + path + sys.path[1:]
- yield
- sys.path = old_system_path
+ old_system_path = sys.path[:]
+ sys.path = sys.path[0:1] + path + sys.path[1:]
+ yield
+ sys.path = old_system_path
class DummySplitServicer(object):
- def __init__(self, request_class, response_class):
- self.request_class = request_class
- self.response_class = response_class
+ def __init__(self, request_class, response_class):
+ self.request_class = request_class
+ self.response_class = response_class
- def Call(self, request, context):
- return self.response_class()
+ def Call(self, request, context):
+ return self.response_class()
class SeparateTestMixin(object):
- def testImportAttributes(self):
- with _system_path([self.python_out_directory]):
- pb2 = importlib.import_module(self.pb2_import)
- pb2.Request
- pb2.Response
- if self.should_find_services_in_pb2:
- pb2.TestServiceServicer
- else:
- with self.assertRaises(AttributeError):
- pb2.TestServiceServicer
-
- with _system_path([self.grpc_python_out_directory]):
- pb2_grpc = importlib.import_module(self.pb2_grpc_import)
- pb2_grpc.TestServiceServicer
- with self.assertRaises(AttributeError):
- pb2_grpc.Request
- with self.assertRaises(AttributeError):
- pb2_grpc.Response
-
- def testCall(self):
- with _system_path([self.python_out_directory]):
- pb2 = importlib.import_module(self.pb2_import)
- with _system_path([self.grpc_python_out_directory]):
- pb2_grpc = importlib.import_module(self.pb2_grpc_import)
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
- pb2_grpc.add_TestServiceServicer_to_server(
- DummySplitServicer(
- pb2.Request, pb2.Response), server)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- stub = pb2_grpc.TestServiceStub(channel)
- request = pb2.Request()
- expected_response = pb2.Response()
- response = stub.Call(request)
- self.assertEqual(expected_response, response)
+ def testImportAttributes(self):
+ with _system_path([self.python_out_directory]):
+ pb2 = importlib.import_module(self.pb2_import)
+ pb2.Request
+ pb2.Response
+ if self.should_find_services_in_pb2:
+ pb2.TestServiceServicer
+ else:
+ with self.assertRaises(AttributeError):
+ pb2.TestServiceServicer
+
+ with _system_path([self.grpc_python_out_directory]):
+ pb2_grpc = importlib.import_module(self.pb2_grpc_import)
+ pb2_grpc.TestServiceServicer
+ with self.assertRaises(AttributeError):
+ pb2_grpc.Request
+ with self.assertRaises(AttributeError):
+ pb2_grpc.Response
+
+ def testCall(self):
+ with _system_path([self.python_out_directory]):
+ pb2 = importlib.import_module(self.pb2_import)
+ with _system_path([self.grpc_python_out_directory]):
+ pb2_grpc = importlib.import_module(self.pb2_grpc_import)
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ pb2_grpc.add_TestServiceServicer_to_server(
+ DummySplitServicer(pb2.Request, pb2.Response), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = pb2_grpc.TestServiceStub(channel)
+ request = pb2.Request()
+ expected_response = pb2.Response()
+ response = stub.Call(request)
+ self.assertEqual(expected_response, response)
class CommonTestMixin(object):
- def testImportAttributes(self):
- with _system_path([self.python_out_directory]):
- pb2 = importlib.import_module(self.pb2_import)
- pb2.Request
- pb2.Response
- if self.should_find_services_in_pb2:
- pb2.TestServiceServicer
- else:
- with self.assertRaises(AttributeError):
- pb2.TestServiceServicer
-
- with _system_path([self.grpc_python_out_directory]):
- pb2_grpc = importlib.import_module(self.pb2_grpc_import)
- pb2_grpc.TestServiceServicer
- with self.assertRaises(AttributeError):
- pb2_grpc.Request
- with self.assertRaises(AttributeError):
- pb2_grpc.Response
-
- def testCall(self):
- with _system_path([self.python_out_directory]):
- pb2 = importlib.import_module(self.pb2_import)
- with _system_path([self.grpc_python_out_directory]):
- pb2_grpc = importlib.import_module(self.pb2_grpc_import)
- server = grpc.server(
- futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
- pb2_grpc.add_TestServiceServicer_to_server(
- DummySplitServicer(
- pb2.Request, pb2.Response), server)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- stub = pb2_grpc.TestServiceStub(channel)
- request = pb2.Request()
- expected_response = pb2.Response()
- response = stub.Call(request)
- self.assertEqual(expected_response, response)
+ def testImportAttributes(self):
+ with _system_path([self.python_out_directory]):
+ pb2 = importlib.import_module(self.pb2_import)
+ pb2.Request
+ pb2.Response
+ if self.should_find_services_in_pb2:
+ pb2.TestServiceServicer
+ else:
+ with self.assertRaises(AttributeError):
+ pb2.TestServiceServicer
+
+ with _system_path([self.grpc_python_out_directory]):
+ pb2_grpc = importlib.import_module(self.pb2_grpc_import)
+ pb2_grpc.TestServiceServicer
+ with self.assertRaises(AttributeError):
+ pb2_grpc.Request
+ with self.assertRaises(AttributeError):
+ pb2_grpc.Response
+
+ def testCall(self):
+ with _system_path([self.python_out_directory]):
+ pb2 = importlib.import_module(self.pb2_import)
+ with _system_path([self.grpc_python_out_directory]):
+ pb2_grpc = importlib.import_module(self.pb2_grpc_import)
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=test_constants.POOL_SIZE))
+ pb2_grpc.add_TestServiceServicer_to_server(
+ DummySplitServicer(pb2.Request, pb2.Response), server)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ stub = pb2_grpc.TestServiceStub(channel)
+ request = pb2.Request()
+ expected_response = pb2.Response()
+ response = stub.Call(request)
+ self.assertEqual(expected_response, response)
class SameSeparateTest(unittest.TestCase, SeparateTestMixin):
- def setUp(self):
- same_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing', 'same.proto')
- self.directory = tempfile.mkdtemp(suffix='same_separate', dir='.')
- self.proto_directory = os.path.join(self.directory, 'proto_path')
- self.python_out_directory = os.path.join(self.directory, 'python_out')
- self.grpc_python_out_directory = os.path.join(self.directory, 'grpc_python_out')
- os.makedirs(self.proto_directory)
- os.makedirs(self.python_out_directory)
- os.makedirs(self.grpc_python_out_directory)
- same_proto_file = os.path.join(self.proto_directory, 'same_separate.proto')
- open(same_proto_file, 'wb').write(same_proto_contents)
- protoc_result = protoc.main([
- '',
- '--proto_path={}'.format(self.proto_directory),
- '--python_out={}'.format(self.python_out_directory),
- '--grpc_python_out=grpc_2_0:{}'.format(self.grpc_python_out_directory),
- same_proto_file,
- ])
- if protoc_result != 0:
- raise Exception("unexpected protoc error")
- open(os.path.join(self.grpc_python_out_directory, '__init__.py'), 'w').write('')
- open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
- self.pb2_import = 'same_separate_pb2'
- self.pb2_grpc_import = 'same_separate_pb2_grpc'
- self.should_find_services_in_pb2 = False
-
- def tearDown(self):
- shutil.rmtree(self.directory)
+ def setUp(self):
+ same_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing', 'same.proto')
+ self.directory = tempfile.mkdtemp(suffix='same_separate', dir='.')
+ self.proto_directory = os.path.join(self.directory, 'proto_path')
+ self.python_out_directory = os.path.join(self.directory, 'python_out')
+ self.grpc_python_out_directory = os.path.join(self.directory,
+ 'grpc_python_out')
+ os.makedirs(self.proto_directory)
+ os.makedirs(self.python_out_directory)
+ os.makedirs(self.grpc_python_out_directory)
+ same_proto_file = os.path.join(self.proto_directory,
+ 'same_separate.proto')
+ open(same_proto_file, 'wb').write(same_proto_contents)
+ protoc_result = protoc.main([
+ '',
+ '--proto_path={}'.format(self.proto_directory),
+ '--python_out={}'.format(self.python_out_directory),
+ '--grpc_python_out=grpc_2_0:{}'.format(
+ self.grpc_python_out_directory),
+ same_proto_file,
+ ])
+ if protoc_result != 0:
+ raise Exception("unexpected protoc error")
+ open(os.path.join(self.grpc_python_out_directory, '__init__.py'),
+ 'w').write('')
+ open(os.path.join(self.python_out_directory, '__init__.py'),
+ 'w').write('')
+ self.pb2_import = 'same_separate_pb2'
+ self.pb2_grpc_import = 'same_separate_pb2_grpc'
+ self.should_find_services_in_pb2 = False
+
+ def tearDown(self):
+ shutil.rmtree(self.directory)
class SameCommonTest(unittest.TestCase, CommonTestMixin):
- def setUp(self):
- same_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing', 'same.proto')
- self.directory = tempfile.mkdtemp(suffix='same_common', dir='.')
- self.proto_directory = os.path.join(self.directory, 'proto_path')
- self.python_out_directory = os.path.join(self.directory, 'python_out')
- self.grpc_python_out_directory = self.python_out_directory
- os.makedirs(self.proto_directory)
- os.makedirs(self.python_out_directory)
- same_proto_file = os.path.join(self.proto_directory, 'same_common.proto')
- open(same_proto_file, 'wb').write(same_proto_contents)
- protoc_result = protoc.main([
- '',
- '--proto_path={}'.format(self.proto_directory),
- '--python_out={}'.format(self.python_out_directory),
- '--grpc_python_out={}'.format(self.grpc_python_out_directory),
- same_proto_file,
- ])
- if protoc_result != 0:
- raise Exception("unexpected protoc error")
- open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
- self.pb2_import = 'same_common_pb2'
- self.pb2_grpc_import = 'same_common_pb2_grpc'
- self.should_find_services_in_pb2 = True
-
- def tearDown(self):
- shutil.rmtree(self.directory)
+ def setUp(self):
+ same_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing', 'same.proto')
+ self.directory = tempfile.mkdtemp(suffix='same_common', dir='.')
+ self.proto_directory = os.path.join(self.directory, 'proto_path')
+ self.python_out_directory = os.path.join(self.directory, 'python_out')
+ self.grpc_python_out_directory = self.python_out_directory
+ os.makedirs(self.proto_directory)
+ os.makedirs(self.python_out_directory)
+ same_proto_file = os.path.join(self.proto_directory,
+ 'same_common.proto')
+ open(same_proto_file, 'wb').write(same_proto_contents)
+ protoc_result = protoc.main([
+ '',
+ '--proto_path={}'.format(self.proto_directory),
+ '--python_out={}'.format(self.python_out_directory),
+ '--grpc_python_out={}'.format(self.grpc_python_out_directory),
+ same_proto_file,
+ ])
+ if protoc_result != 0:
+ raise Exception("unexpected protoc error")
+ open(os.path.join(self.python_out_directory, '__init__.py'),
+ 'w').write('')
+ self.pb2_import = 'same_common_pb2'
+ self.pb2_grpc_import = 'same_common_pb2_grpc'
+ self.should_find_services_in_pb2 = True
+
+ def tearDown(self):
+ shutil.rmtree(self.directory)
class SplitCommonTest(unittest.TestCase, CommonTestMixin):
- def setUp(self):
- services_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing.split_services',
- 'services.proto')
- messages_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing.split_messages',
- 'messages.proto')
- self.directory = tempfile.mkdtemp(suffix='split_common', dir='.')
- self.proto_directory = os.path.join(self.directory, 'proto_path')
- self.python_out_directory = os.path.join(self.directory, 'python_out')
- self.grpc_python_out_directory = self.python_out_directory
- os.makedirs(self.proto_directory)
- os.makedirs(self.python_out_directory)
- services_proto_file = os.path.join(self.proto_directory,
- 'split_common_services.proto')
- messages_proto_file = os.path.join(self.proto_directory,
- 'split_common_messages.proto')
- open(services_proto_file, 'wb').write(services_proto_contents.replace(
- _MESSAGES_IMPORT,
- b'import "split_common_messages.proto";'
- ))
- open(messages_proto_file, 'wb').write(messages_proto_contents)
- protoc_result = protoc.main([
- '',
- '--proto_path={}'.format(self.proto_directory),
- '--python_out={}'.format(self.python_out_directory),
- '--grpc_python_out={}'.format(self.grpc_python_out_directory),
- services_proto_file,
- messages_proto_file,
- ])
- if protoc_result != 0:
- raise Exception("unexpected protoc error")
- open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
- self.pb2_import = 'split_common_messages_pb2'
- self.pb2_grpc_import = 'split_common_services_pb2_grpc'
- self.should_find_services_in_pb2 = False
-
- def tearDown(self):
- shutil.rmtree(self.directory)
+ def setUp(self):
+ services_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing.split_services',
+ 'services.proto')
+ messages_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing.split_messages',
+ 'messages.proto')
+ self.directory = tempfile.mkdtemp(suffix='split_common', dir='.')
+ self.proto_directory = os.path.join(self.directory, 'proto_path')
+ self.python_out_directory = os.path.join(self.directory, 'python_out')
+ self.grpc_python_out_directory = self.python_out_directory
+ os.makedirs(self.proto_directory)
+ os.makedirs(self.python_out_directory)
+ services_proto_file = os.path.join(self.proto_directory,
+ 'split_common_services.proto')
+ messages_proto_file = os.path.join(self.proto_directory,
+ 'split_common_messages.proto')
+ open(services_proto_file, 'wb').write(
+ services_proto_contents.replace(
+ _MESSAGES_IMPORT, b'import "split_common_messages.proto";'))
+ open(messages_proto_file, 'wb').write(messages_proto_contents)
+ protoc_result = protoc.main([
+ '',
+ '--proto_path={}'.format(self.proto_directory),
+ '--python_out={}'.format(self.python_out_directory),
+ '--grpc_python_out={}'.format(self.grpc_python_out_directory),
+ services_proto_file,
+ messages_proto_file,
+ ])
+ if protoc_result != 0:
+ raise Exception("unexpected protoc error")
+ open(os.path.join(self.python_out_directory, '__init__.py'),
+ 'w').write('')
+ self.pb2_import = 'split_common_messages_pb2'
+ self.pb2_grpc_import = 'split_common_services_pb2_grpc'
+ self.should_find_services_in_pb2 = False
+
+ def tearDown(self):
+ shutil.rmtree(self.directory)
class SplitSeparateTest(unittest.TestCase, SeparateTestMixin):
- def setUp(self):
- services_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing.split_services',
- 'services.proto')
- messages_proto_contents = pkgutil.get_data(
- 'tests.protoc_plugin.protos.invocation_testing.split_messages',
- 'messages.proto')
- self.directory = tempfile.mkdtemp(suffix='split_separate', dir='.')
- self.proto_directory = os.path.join(self.directory, 'proto_path')
- self.python_out_directory = os.path.join(self.directory, 'python_out')
- self.grpc_python_out_directory = os.path.join(self.directory, 'grpc_python_out')
- os.makedirs(self.proto_directory)
- os.makedirs(self.python_out_directory)
- os.makedirs(self.grpc_python_out_directory)
- services_proto_file = os.path.join(self.proto_directory,
- 'split_separate_services.proto')
- messages_proto_file = os.path.join(self.proto_directory,
- 'split_separate_messages.proto')
- open(services_proto_file, 'wb').write(services_proto_contents.replace(
- _MESSAGES_IMPORT,
- b'import "split_separate_messages.proto";'
- ))
- open(messages_proto_file, 'wb').write(messages_proto_contents)
- protoc_result = protoc.main([
- '',
- '--proto_path={}'.format(self.proto_directory),
- '--python_out={}'.format(self.python_out_directory),
- '--grpc_python_out=grpc_2_0:{}'.format(self.grpc_python_out_directory),
- services_proto_file,
- messages_proto_file,
- ])
- if protoc_result != 0:
- raise Exception("unexpected protoc error")
- open(os.path.join(self.python_out_directory, '__init__.py'), 'w').write('')
- self.pb2_import = 'split_separate_messages_pb2'
- self.pb2_grpc_import = 'split_separate_services_pb2_grpc'
- self.should_find_services_in_pb2 = False
-
- def tearDown(self):
- shutil.rmtree(self.directory)
+ def setUp(self):
+ services_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing.split_services',
+ 'services.proto')
+ messages_proto_contents = pkgutil.get_data(
+ 'tests.protoc_plugin.protos.invocation_testing.split_messages',
+ 'messages.proto')
+ self.directory = tempfile.mkdtemp(suffix='split_separate', dir='.')
+ self.proto_directory = os.path.join(self.directory, 'proto_path')
+ self.python_out_directory = os.path.join(self.directory, 'python_out')
+ self.grpc_python_out_directory = os.path.join(self.directory,
+ 'grpc_python_out')
+ os.makedirs(self.proto_directory)
+ os.makedirs(self.python_out_directory)
+ os.makedirs(self.grpc_python_out_directory)
+ services_proto_file = os.path.join(self.proto_directory,
+ 'split_separate_services.proto')
+ messages_proto_file = os.path.join(self.proto_directory,
+ 'split_separate_messages.proto')
+ open(services_proto_file, 'wb').write(
+ services_proto_contents.replace(
+ _MESSAGES_IMPORT, b'import "split_separate_messages.proto";'))
+ open(messages_proto_file, 'wb').write(messages_proto_contents)
+ protoc_result = protoc.main([
+ '',
+ '--proto_path={}'.format(self.proto_directory),
+ '--python_out={}'.format(self.python_out_directory),
+ '--grpc_python_out=grpc_2_0:{}'.format(
+ self.grpc_python_out_directory),
+ services_proto_file,
+ messages_proto_file,
+ ])
+ if protoc_result != 0:
+ raise Exception("unexpected protoc error")
+ open(os.path.join(self.python_out_directory, '__init__.py'),
+ 'w').write('')
+ self.pb2_import = 'split_separate_messages_pb2'
+ self.pb2_grpc_import = 'split_separate_services_pb2_grpc'
+ self.should_find_services_in_pb2 = False
+
+ def tearDown(self):
+ shutil.rmtree(self.directory)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
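_system_path above temporarily splices the generated-code directories into sys.path so the freshly protoc'd modules can be imported by name. A standalone sketch of the same idea (restoring the path in a finally block, which the helper above omits); the './generated' directory and module name are hypothetical:

    import contextlib
    import importlib
    import sys

    @contextlib.contextmanager
    def _system_path(path):
        # Splice the extra entries in after sys.path[0], restore on exit.
        old_system_path = sys.path[:]
        sys.path = sys.path[0:1] + path + sys.path[1:]
        try:
            yield
        finally:
            sys.path = old_system_path

    with _system_path(['./generated']):
        pb2 = importlib.import_module('same_separate_pb2')  # assumed to exist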
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
index 1eba9c9354..f64f4e962b 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/beta_python_plugin_test.py
@@ -64,84 +64,84 @@ STUB_FACTORY_IDENTIFIER = 'beta_create_TestService_stub'
class _ServicerMethods(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._paused = False
- self._fail = False
-
- @contextlib.contextmanager
- def pause(self): # pylint: disable=invalid-name
- with self._condition:
- self._paused = True
- yield
- with self._condition:
- self._paused = False
- self._condition.notify_all()
-
- @contextlib.contextmanager
- def fail(self): # pylint: disable=invalid-name
- with self._condition:
- self._fail = True
- yield
- with self._condition:
- self._fail = False
-
- def _control(self): # pylint: disable=invalid-name
- with self._condition:
- if self._fail:
- raise ValueError()
- while self._paused:
- self._condition.wait()
-
- def UnaryCall(self, request, unused_rpc_context):
- response = response_pb2.SimpleResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * request.response_size
- self._control()
- return response
-
- def StreamingOutputCall(self, request, unused_rpc_context):
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
- self._control()
- yield response
-
- def StreamingInputCall(self, request_iter, unused_rpc_context):
- response = response_pb2.StreamingInputCallResponse()
- aggregated_payload_size = 0
- for request in request_iter:
- aggregated_payload_size += len(request.payload.payload_compressable)
- response.aggregated_payload_size = aggregated_payload_size
- self._control()
- return response
-
- def FullDuplexCall(self, request_iter, unused_rpc_context):
- for request in request_iter:
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
- response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
- self._control()
- yield response
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._paused = False
+ self._fail = False
+
+ @contextlib.contextmanager
+ def pause(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._paused = True
+ yield
+ with self._condition:
+ self._paused = False
+ self._condition.notify_all()
- def HalfDuplexCall(self, request_iter, unused_rpc_context):
- responses = []
- for request in request_iter:
- for parameter in request.response_parameters:
- response = response_pb2.StreamingOutputCallResponse()
+ @contextlib.contextmanager
+ def fail(self): # pylint: disable=invalid-name
+ with self._condition:
+ self._fail = True
+ yield
+ with self._condition:
+ self._fail = False
+
+ def _control(self): # pylint: disable=invalid-name
+ with self._condition:
+ if self._fail:
+ raise ValueError()
+ while self._paused:
+ self._condition.wait()
+
+ def UnaryCall(self, request, unused_rpc_context):
+ response = response_pb2.SimpleResponse()
response.payload.payload_type = payload_pb2.COMPRESSABLE
- response.payload.payload_compressable = 'a' * parameter.size
+ response.payload.payload_compressable = 'a' * request.response_size
+ self._control()
+ return response
+
+ def StreamingOutputCall(self, request, unused_rpc_context):
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def StreamingInputCall(self, request_iter, unused_rpc_context):
+ response = response_pb2.StreamingInputCallResponse()
+ aggregated_payload_size = 0
+ for request in request_iter:
+ aggregated_payload_size += len(request.payload.payload_compressable)
+ response.aggregated_payload_size = aggregated_payload_size
self._control()
- responses.append(response)
- for response in responses:
- yield response
+ return response
+
+ def FullDuplexCall(self, request_iter, unused_rpc_context):
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ yield response
+
+ def HalfDuplexCall(self, request_iter, unused_rpc_context):
+ responses = []
+ for request in request_iter:
+ for parameter in request.response_parameters:
+ response = response_pb2.StreamingOutputCallResponse()
+ response.payload.payload_type = payload_pb2.COMPRESSABLE
+ response.payload.payload_compressable = 'a' * parameter.size
+ self._control()
+ responses.append(response)
+ for response in responses:
+ yield response
@contextlib.contextmanager
def _CreateService():
- """Provides a servicer backend and a stub.
+ """Provides a servicer backend and a stub.
The servicer is just the implementation of the actual servicer passed to the
face layer of the python RPC implementation; the two are detached.
@@ -151,38 +151,38 @@ def _CreateService():
the service bound to the stub and stub is the stub on which to invoke
RPCs.
"""
- servicer_methods = _ServicerMethods()
+ servicer_methods = _ServicerMethods()
- class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
- def UnaryCall(self, request, context):
- return servicer_methods.UnaryCall(request, context)
+ def UnaryCall(self, request, context):
+ return servicer_methods.UnaryCall(request, context)
- def StreamingOutputCall(self, request, context):
- return servicer_methods.StreamingOutputCall(request, context)
+ def StreamingOutputCall(self, request, context):
+ return servicer_methods.StreamingOutputCall(request, context)
- def StreamingInputCall(self, request_iter, context):
- return servicer_methods.StreamingInputCall(request_iter, context)
+ def StreamingInputCall(self, request_iter, context):
+ return servicer_methods.StreamingInputCall(request_iter, context)
- def FullDuplexCall(self, request_iter, context):
- return servicer_methods.FullDuplexCall(request_iter, context)
+ def FullDuplexCall(self, request_iter, context):
+ return servicer_methods.FullDuplexCall(request_iter, context)
- def HalfDuplexCall(self, request_iter, context):
- return servicer_methods.HalfDuplexCall(request_iter, context)
+ def HalfDuplexCall(self, request_iter, context):
+ return servicer_methods.HalfDuplexCall(request_iter, context)
- servicer = Servicer()
- server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = implementations.insecure_channel('localhost', port)
- stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
- yield (servicer_methods, stub)
- server.stop(0)
+ servicer = Servicer()
+ server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = implementations.insecure_channel('localhost', port)
+ stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
+ yield (servicer_methods, stub)
+ server.stop(0)
@contextlib.contextmanager
def _CreateIncompleteService():
- """Provides a servicer backend that fails to implement methods and its stub.
+ """Provides a servicer backend that fails to implement methods and its stub.
The servicer is just the implementation of the actual servicer passed to the
face layer of the python RPC implementation; the two are detached.
@@ -194,297 +194,297 @@ def _CreateIncompleteService():
RPCs.
"""
- class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
- pass
+ class Servicer(getattr(service_pb2, SERVICER_IDENTIFIER)):
+ pass
- servicer = Servicer()
- server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = implementations.insecure_channel('localhost', port)
- stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
- yield None, stub
- server.stop(0)
+ servicer = Servicer()
+ server = getattr(service_pb2, SERVER_FACTORY_IDENTIFIER)(servicer)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = implementations.insecure_channel('localhost', port)
+ stub = getattr(service_pb2, STUB_FACTORY_IDENTIFIER)(channel)
+ yield None, stub
+ server.stop(0)
def _streaming_input_request_iterator():
- for _ in range(3):
- request = request_pb2.StreamingInputCallRequest()
- request.payload.payload_type = payload_pb2.COMPRESSABLE
- request.payload.payload_compressable = 'a'
- yield request
+ for _ in range(3):
+ request = request_pb2.StreamingInputCallRequest()
+ request.payload.payload_type = payload_pb2.COMPRESSABLE
+ request.payload.payload_compressable = 'a'
+ yield request
def _streaming_output_request():
- request = request_pb2.StreamingOutputCallRequest()
- sizes = [1, 2, 3]
- request.response_parameters.add(size=sizes[0], interval_us=0)
- request.response_parameters.add(size=sizes[1], interval_us=0)
- request.response_parameters.add(size=sizes[2], interval_us=0)
- return request
+ request = request_pb2.StreamingOutputCallRequest()
+ sizes = [1, 2, 3]
+ request.response_parameters.add(size=sizes[0], interval_us=0)
+ request.response_parameters.add(size=sizes[1], interval_us=0)
+ request.response_parameters.add(size=sizes[2], interval_us=0)
+ return request
def _full_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=2, interval_us=0)
- request.response_parameters.add(size=3, interval_us=0)
- yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
class PythonPluginTest(unittest.TestCase):
- """Test case for the gRPC Python protoc-plugin.
+ """Test case for the gRPC Python protoc-plugin.
While reading these tests, remember that the futures API
(`stub.method.future()`) only gives futures for the *response-unary*
methods and does not exist for response-streaming methods.
"""
- def testImportAttributes(self):
- # check that we can access the generated module and its members.
- self.assertIsNotNone(
- getattr(service_pb2, SERVICER_IDENTIFIER, None))
- self.assertIsNotNone(
- getattr(service_pb2, STUB_IDENTIFIER, None))
- self.assertIsNotNone(
- getattr(service_pb2, SERVER_FACTORY_IDENTIFIER, None))
- self.assertIsNotNone(
- getattr(service_pb2, STUB_FACTORY_IDENTIFIER, None))
-
- def testUpDown(self):
- with _CreateService():
- request_pb2.SimpleRequest(response_size=13)
-
- def testIncompleteServicer(self):
- with _CreateIncompleteService() as (_, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- try:
- stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
- except face.AbortionError as error:
- self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED, error.code)
-
- def testUnaryCall(self):
- with _CreateService() as (methods, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- response = stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
- expected_response = methods.UnaryCall(request, 'not a real context!')
- self.assertEqual(expected_response, response)
-
- def testUnaryCallFuture(self):
- with _CreateService() as (methods, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- # Check that the call does not block waiting for the server to respond.
- with methods.pause():
- response_future = stub.UnaryCall.future(
- request, test_constants.LONG_TIMEOUT)
- response = response_future.result()
- expected_response = methods.UnaryCall(request, 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testUnaryCallFutureExpired(self):
- with _CreateService() as (methods, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- with methods.pause():
- response_future = stub.UnaryCall.future(
- request, test_constants.SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
-
- def testUnaryCallFutureCancelled(self):
- with _CreateService() as (methods, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- with methods.pause():
- response_future = stub.UnaryCall.future(request, 1)
- response_future.cancel()
- self.assertTrue(response_future.cancelled())
-
- def testUnaryCallFutureFailed(self):
- with _CreateService() as (methods, stub):
- request = request_pb2.SimpleRequest(response_size=13)
- with methods.fail():
- response_future = stub.UnaryCall.future(
- request, test_constants.LONG_TIMEOUT)
- self.assertIsNotNone(response_future.exception())
-
- def testStreamingOutputCall(self):
- with _CreateService() as (methods, stub):
- request = _streaming_output_request()
- responses = stub.StreamingOutputCall(
- request, test_constants.LONG_TIMEOUT)
- expected_responses = methods.StreamingOutputCall(
- request, 'not a real RpcContext!')
- for expected_response, response in moves.zip_longest(
- expected_responses, responses):
+ def testImportAttributes(self):
+ # check that we can access the generated module and its members.
+ self.assertIsNotNone(getattr(service_pb2, SERVICER_IDENTIFIER, None))
+ self.assertIsNotNone(getattr(service_pb2, STUB_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(service_pb2, SERVER_FACTORY_IDENTIFIER, None))
+ self.assertIsNotNone(
+ getattr(service_pb2, STUB_FACTORY_IDENTIFIER, None))
+
+ def testUpDown(self):
+ with _CreateService():
+ request_pb2.SimpleRequest(response_size=13)
+
+ def testIncompleteServicer(self):
+ with _CreateIncompleteService() as (_, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ try:
+ stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
+ except face.AbortionError as error:
+ self.assertEqual(interfaces.StatusCode.UNIMPLEMENTED,
+ error.code)
+
+ def testUnaryCall(self):
+ with _CreateService() as (methods, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ response = stub.UnaryCall(request, test_constants.LONG_TIMEOUT)
+ expected_response = methods.UnaryCall(request, 'not a real context!')
self.assertEqual(expected_response, response)
- def testStreamingOutputCallExpired(self):
- with _CreateService() as (methods, stub):
- request = _streaming_output_request()
- with methods.pause():
- responses = stub.StreamingOutputCall(
- request, test_constants.SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- list(responses)
-
- def testStreamingOutputCallCancelled(self):
- with _CreateService() as (methods, stub):
- request = _streaming_output_request()
- responses = stub.StreamingOutputCall(
- request, test_constants.LONG_TIMEOUT)
- next(responses)
- responses.cancel()
- with self.assertRaises(face.CancellationError):
- next(responses)
-
- def testStreamingOutputCallFailed(self):
- with _CreateService() as (methods, stub):
- request = _streaming_output_request()
- with methods.fail():
- responses = stub.StreamingOutputCall(request, 1)
- self.assertIsNotNone(responses)
- with self.assertRaises(face.RemoteError):
- next(responses)
-
- def testStreamingInputCall(self):
- with _CreateService() as (methods, stub):
- response = stub.StreamingInputCall(
- _streaming_input_request_iterator(),
- test_constants.LONG_TIMEOUT)
- expected_response = methods.StreamingInputCall(
- _streaming_input_request_iterator(),
- 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testStreamingInputCallFuture(self):
- with _CreateService() as (methods, stub):
- with methods.pause():
- response_future = stub.StreamingInputCall.future(
- _streaming_input_request_iterator(),
- test_constants.LONG_TIMEOUT)
- response = response_future.result()
- expected_response = methods.StreamingInputCall(
- _streaming_input_request_iterator(),
- 'not a real RpcContext!')
- self.assertEqual(expected_response, response)
-
- def testStreamingInputCallFutureExpired(self):
- with _CreateService() as (methods, stub):
- with methods.pause():
- response_future = stub.StreamingInputCall.future(
- _streaming_input_request_iterator(),
- test_constants.SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
- self.assertIsInstance(
- response_future.exception(), face.ExpirationError)
-
- def testStreamingInputCallFutureCancelled(self):
- with _CreateService() as (methods, stub):
- with methods.pause():
- response_future = stub.StreamingInputCall.future(
- _streaming_input_request_iterator(),
- test_constants.LONG_TIMEOUT)
- response_future.cancel()
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(future.CancelledError):
- response_future.result()
-
- def testStreamingInputCallFutureFailed(self):
- with _CreateService() as (methods, stub):
- with methods.fail():
- response_future = stub.StreamingInputCall.future(
- _streaming_input_request_iterator(),
- test_constants.LONG_TIMEOUT)
- self.assertIsNotNone(response_future.exception())
-
- def testFullDuplexCall(self):
- with _CreateService() as (methods, stub):
- responses = stub.FullDuplexCall(
- _full_duplex_request_iterator(),
- test_constants.LONG_TIMEOUT)
- expected_responses = methods.FullDuplexCall(
- _full_duplex_request_iterator(),
- 'not a real RpcContext!')
- for expected_response, response in moves.zip_longest(
- expected_responses, responses):
+ def testUnaryCallFuture(self):
+ with _CreateService() as (methods, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ # Check that the call does not block waiting for the server to respond.
+ with methods.pause():
+ response_future = stub.UnaryCall.future(
+ request, test_constants.LONG_TIMEOUT)
+ response = response_future.result()
+ expected_response = methods.UnaryCall(request, 'not a real RpcContext!')
self.assertEqual(expected_response, response)
- def testFullDuplexCallExpired(self):
- request_iterator = _full_duplex_request_iterator()
- with _CreateService() as (methods, stub):
- with methods.pause():
- responses = stub.FullDuplexCall(
- request_iterator, test_constants.SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- list(responses)
-
- def testFullDuplexCallCancelled(self):
- with _CreateService() as (methods, stub):
- request_iterator = _full_duplex_request_iterator()
- responses = stub.FullDuplexCall(
- request_iterator, test_constants.LONG_TIMEOUT)
- next(responses)
- responses.cancel()
- with self.assertRaises(face.CancellationError):
- next(responses)
-
- def testFullDuplexCallFailed(self):
- request_iterator = _full_duplex_request_iterator()
- with _CreateService() as (methods, stub):
- with methods.fail():
- responses = stub.FullDuplexCall(
- request_iterator, test_constants.LONG_TIMEOUT)
- self.assertIsNotNone(responses)
- with self.assertRaises(face.RemoteError):
- next(responses)
-
- def testHalfDuplexCall(self):
- with _CreateService() as (methods, stub):
- def half_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=2, interval_us=0)
- request.response_parameters.add(size=3, interval_us=0)
- yield request
- responses = stub.HalfDuplexCall(
- half_duplex_request_iterator(), test_constants.LONG_TIMEOUT)
- expected_responses = methods.HalfDuplexCall(
- half_duplex_request_iterator(), 'not a real RpcContext!')
- for check in moves.zip_longest(expected_responses, responses):
- expected_response, response = check
+ def testUnaryCallFutureExpired(self):
+ with _CreateService() as (methods, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ with methods.pause():
+ response_future = stub.UnaryCall.future(
+ request, test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+
+ def testUnaryCallFutureCancelled(self):
+ with _CreateService() as (methods, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ with methods.pause():
+ response_future = stub.UnaryCall.future(request, 1)
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+
+ def testUnaryCallFutureFailed(self):
+ with _CreateService() as (methods, stub):
+ request = request_pb2.SimpleRequest(response_size=13)
+ with methods.fail():
+ response_future = stub.UnaryCall.future(
+ request, test_constants.LONG_TIMEOUT)
+ self.assertIsNotNone(response_future.exception())
+
+ def testStreamingOutputCall(self):
+ with _CreateService() as (methods, stub):
+ request = _streaming_output_request()
+ responses = stub.StreamingOutputCall(request,
+ test_constants.LONG_TIMEOUT)
+ expected_responses = methods.StreamingOutputCall(
+ request, 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testStreamingOutputCallExpired(self):
+ with _CreateService() as (methods, stub):
+ request = _streaming_output_request()
+ with methods.pause():
+ responses = stub.StreamingOutputCall(
+ request, test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ list(responses)
+
+ def testStreamingOutputCallCancelled(self):
+ with _CreateService() as (methods, stub):
+ request = _streaming_output_request()
+ responses = stub.StreamingOutputCall(request,
+ test_constants.LONG_TIMEOUT)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(face.CancellationError):
+ next(responses)
+
+ def testStreamingOutputCallFailed(self):
+ with _CreateService() as (methods, stub):
+ request = _streaming_output_request()
+ with methods.fail():
+ responses = stub.StreamingOutputCall(request, 1)
+ self.assertIsNotNone(responses)
+ with self.assertRaises(face.RemoteError):
+ next(responses)
+
+ def testStreamingInputCall(self):
+ with _CreateService() as (methods, stub):
+ response = stub.StreamingInputCall(
+ _streaming_input_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ expected_response = methods.StreamingInputCall(
+ _streaming_input_request_iterator(), 'not a real RpcContext!')
self.assertEqual(expected_response, response)
- def testHalfDuplexCallWedged(self):
- condition = threading.Condition()
- wait_cell = [False]
- @contextlib.contextmanager
- def wait(): # pylint: disable=invalid-name
- # Where's Python 3's 'nonlocal' statement when you need it?
- with condition:
- wait_cell[0] = True
- yield
- with condition:
- wait_cell[0] = False
- condition.notify_all()
- def half_duplex_request_iterator():
- request = request_pb2.StreamingOutputCallRequest()
- request.response_parameters.add(size=1, interval_us=0)
- yield request
- with condition:
- while wait_cell[0]:
- condition.wait()
- with _CreateService() as (methods, stub):
- with wait():
- responses = stub.HalfDuplexCall(
- half_duplex_request_iterator(), test_constants.SHORT_TIMEOUT)
- # half-duplex waits for the client to send all info
- with self.assertRaises(face.ExpirationError):
- next(responses)
+ def testStreamingInputCallFuture(self):
+ with _CreateService() as (methods, stub):
+ with methods.pause():
+ response_future = stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ response = response_future.result()
+ expected_response = methods.StreamingInputCall(
+ _streaming_input_request_iterator(), 'not a real RpcContext!')
+ self.assertEqual(expected_response, response)
+
+ def testStreamingInputCallFutureExpired(self):
+ with _CreateService() as (methods, stub):
+ with methods.pause():
+ response_future = stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(),
+ test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+ self.assertIsInstance(response_future.exception(),
+ face.ExpirationError)
+
+ def testStreamingInputCallFutureCancelled(self):
+ with _CreateService() as (methods, stub):
+ with methods.pause():
+ response_future = stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ response_future.cancel()
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(future.CancelledError):
+ response_future.result()
+
+ def testStreamingInputCallFutureFailed(self):
+ with _CreateService() as (methods, stub):
+ with methods.fail():
+ response_future = stub.StreamingInputCall.future(
+ _streaming_input_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ self.assertIsNotNone(response_future.exception())
+
+ def testFullDuplexCall(self):
+ with _CreateService() as (methods, stub):
+ responses = stub.FullDuplexCall(_full_duplex_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ expected_responses = methods.FullDuplexCall(
+ _full_duplex_request_iterator(), 'not a real RpcContext!')
+ for expected_response, response in moves.zip_longest(
+ expected_responses, responses):
+ self.assertEqual(expected_response, response)
+
+ def testFullDuplexCallExpired(self):
+ request_iterator = _full_duplex_request_iterator()
+ with _CreateService() as (methods, stub):
+ with methods.pause():
+ responses = stub.FullDuplexCall(request_iterator,
+ test_constants.SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ list(responses)
+
+ def testFullDuplexCallCancelled(self):
+ with _CreateService() as (methods, stub):
+ request_iterator = _full_duplex_request_iterator()
+ responses = stub.FullDuplexCall(request_iterator,
+ test_constants.LONG_TIMEOUT)
+ next(responses)
+ responses.cancel()
+ with self.assertRaises(face.CancellationError):
+ next(responses)
+
+ def testFullDuplexCallFailed(self):
+ request_iterator = _full_duplex_request_iterator()
+ with _CreateService() as (methods, stub):
+ with methods.fail():
+ responses = stub.FullDuplexCall(request_iterator,
+ test_constants.LONG_TIMEOUT)
+ self.assertIsNotNone(responses)
+ with self.assertRaises(face.RemoteError):
+ next(responses)
+
+ def testHalfDuplexCall(self):
+ with _CreateService() as (methods, stub):
+
+ def half_duplex_request_iterator():
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=2, interval_us=0)
+ request.response_parameters.add(size=3, interval_us=0)
+ yield request
+
+ responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
+ test_constants.LONG_TIMEOUT)
+ expected_responses = methods.HalfDuplexCall(
+ half_duplex_request_iterator(), 'not a real RpcContext!')
+ for check in moves.zip_longest(expected_responses, responses):
+ expected_response, response = check
+ self.assertEqual(expected_response, response)
+
+ def testHalfDuplexCallWedged(self):
+ condition = threading.Condition()
+ wait_cell = [False]
+
+ @contextlib.contextmanager
+ def wait(): # pylint: disable=invalid-name
+ # Where's Python 3's 'nonlocal' statement when you need it?
+ with condition:
+ wait_cell[0] = True
+ yield
+ with condition:
+ wait_cell[0] = False
+ condition.notify_all()
+
+ def half_duplex_request_iterator():
+ request = request_pb2.StreamingOutputCallRequest()
+ request.response_parameters.add(size=1, interval_us=0)
+ yield request
+ with condition:
+ while wait_cell[0]:
+ condition.wait()
+
+ with _CreateService() as (methods, stub):
+ with wait():
+ responses = stub.HalfDuplexCall(half_duplex_request_iterator(),
+ test_constants.SHORT_TIMEOUT)
+ # half-duplex waits for the client to send all info
+ with self.assertRaises(face.ExpirationError):
+ next(responses)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
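Note: the pause()/fail() context managers above are what make these tests deterministic: while paused the servicer parks on a condition variable (so SHORT_TIMEOUT calls expire), and while failing it raises (so calls surface remote errors). Below is a standalone sketch of the pause half of that pattern, with illustrative names that are not part of the test module; fail() is analogous, with a flag that triggers a raise instead of a wait.

    import contextlib
    import threading
    import time

    class ControllableHandler(object):
        """Toy stand-in for the servicer: callers block while it is paused."""

        def __init__(self):
            self._condition = threading.Condition()
            self._paused = False

        @contextlib.contextmanager
        def pause(self):
            with self._condition:
                self._paused = True
            yield
            with self._condition:
                self._paused = False
                self._condition.notify_all()

        def handle(self, request):
            # Park on the condition variable until any active pause() exits.
            with self._condition:
                while self._paused:
                    self._condition.wait()
            return request

    # A worker thread stays blocked inside handle() for the whole pause() block,
    # which is how the tests above force SHORT_TIMEOUT calls to expire.
    handler = ControllableHandler()
    with handler.pause():
        worker = threading.Thread(target=handler.handle, args=('ping',))
        worker.start()
        time.sleep(0.1)  # the worker is now waiting on the condition variable
    worker.join()        # exiting pause() released it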
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py
index 2f88fa0412..100a624dc9 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py
index 2f88fa0412..100a624dc9 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_messages/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py
index 2f88fa0412..100a624dc9 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/invocation_testing/split_services/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/payload/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/requests/r/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/responses/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py b/src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py
+++ b/src/python/grpcio_tests/tests/protoc_plugin/protos/service/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_client.py b/src/python/grpcio_tests/tests/qps/benchmark_client.py
index 650e4756e7..2e8afc8e7f 100644
--- a/src/python/grpcio_tests/tests/qps/benchmark_client.py
+++ b/src/python/grpcio_tests/tests/qps/benchmark_client.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Defines test client behaviors (UNARY/STREAMING) (SYNC/ASYNC)."""
import abc
@@ -47,165 +46,168 @@ _TIMEOUT = 60 * 60 * 24
class GenericStub(object):
- def __init__(self, channel):
- self.UnaryCall = channel.unary_unary(
- '/grpc.testing.BenchmarkService/UnaryCall')
- self.StreamingCall = channel.stream_stream(
- '/grpc.testing.BenchmarkService/StreamingCall')
+ def __init__(self, channel):
+ self.UnaryCall = channel.unary_unary(
+ '/grpc.testing.BenchmarkService/UnaryCall')
+ self.StreamingCall = channel.stream_stream(
+ '/grpc.testing.BenchmarkService/StreamingCall')
class BenchmarkClient:
- """Benchmark client interface that exposes a non-blocking send_request()."""
-
- __metaclass__ = abc.ABCMeta
-
- def __init__(self, server, config, hist):
- # Create the stub
- if config.HasField('security_params'):
- creds = grpc.ssl_channel_credentials(resources.test_root_certificates())
- channel = test_common.test_secure_channel(
- server, creds, config.security_params.server_host_override)
- else:
- channel = grpc.insecure_channel(server)
-
- # waits for the channel to be ready before we start sending messages
- grpc.channel_ready_future(channel).result()
-
- if config.payload_config.WhichOneof('payload') == 'simple_params':
- self._generic = False
- self._stub = services_pb2.BenchmarkServiceStub(channel)
- payload = messages_pb2.Payload(
- body='\0' * config.payload_config.simple_params.req_size)
- self._request = messages_pb2.SimpleRequest(
- payload=payload,
- response_size=config.payload_config.simple_params.resp_size)
- else:
- self._generic = True
- self._stub = GenericStub(channel)
- self._request = '\0' * config.payload_config.bytebuf_params.req_size
-
- self._hist = hist
- self._response_callbacks = []
-
- def add_response_callback(self, callback):
- """callback will be invoked as callback(client, query_time)"""
- self._response_callbacks.append(callback)
-
- @abc.abstractmethod
- def send_request(self):
- """Non-blocking wrapper for a client's request operation."""
- raise NotImplementedError()
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def _handle_response(self, client, query_time):
- self._hist.add(query_time * 1e9) # Report times in nanoseconds
- for callback in self._response_callbacks:
- callback(client, query_time)
+ """Benchmark client interface that exposes a non-blocking send_request()."""
+
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, server, config, hist):
+ # Create the stub
+ if config.HasField('security_params'):
+ creds = grpc.ssl_channel_credentials(
+ resources.test_root_certificates())
+ channel = test_common.test_secure_channel(
+ server, creds, config.security_params.server_host_override)
+ else:
+ channel = grpc.insecure_channel(server)
+
+ # waits for the channel to be ready before we start sending messages
+ grpc.channel_ready_future(channel).result()
+
+ if config.payload_config.WhichOneof('payload') == 'simple_params':
+ self._generic = False
+ self._stub = services_pb2.BenchmarkServiceStub(channel)
+ payload = messages_pb2.Payload(
+ body='\0' * config.payload_config.simple_params.req_size)
+ self._request = messages_pb2.SimpleRequest(
+ payload=payload,
+ response_size=config.payload_config.simple_params.resp_size)
+ else:
+ self._generic = True
+ self._stub = GenericStub(channel)
+ self._request = '\0' * config.payload_config.bytebuf_params.req_size
+
+ self._hist = hist
+ self._response_callbacks = []
+
+ def add_response_callback(self, callback):
+ """callback will be invoked as callback(client, query_time)"""
+ self._response_callbacks.append(callback)
+
+ @abc.abstractmethod
+ def send_request(self):
+ """Non-blocking wrapper for a client's request operation."""
+ raise NotImplementedError()
+
+ def start(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def _handle_response(self, client, query_time):
+ self._hist.add(query_time * 1e9) # Report times in nanoseconds
+ for callback in self._response_callbacks:
+ callback(client, query_time)
class UnarySyncBenchmarkClient(BenchmarkClient):
- def __init__(self, server, config, hist):
- super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
- self._pool = futures.ThreadPoolExecutor(
- max_workers=config.outstanding_rpcs_per_channel)
+ def __init__(self, server, config, hist):
+ super(UnarySyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel)
- def send_request(self):
-    # Send requests in separate threads to support multiple outstanding rpcs
- # (See src/proto/grpc/testing/control.proto)
- self._pool.submit(self._dispatch_request)
+ def send_request(self):
+        # Send requests in separate threads to support multiple outstanding rpcs
+ # (See src/proto/grpc/testing/control.proto)
+ self._pool.submit(self._dispatch_request)
- def stop(self):
- self._pool.shutdown(wait=True)
- self._stub = None
+ def stop(self):
+ self._pool.shutdown(wait=True)
+ self._stub = None
- def _dispatch_request(self):
- start_time = time.time()
- self._stub.UnaryCall(self._request, _TIMEOUT)
- end_time = time.time()
- self._handle_response(self, end_time - start_time)
+ def _dispatch_request(self):
+ start_time = time.time()
+ self._stub.UnaryCall(self._request, _TIMEOUT)
+ end_time = time.time()
+ self._handle_response(self, end_time - start_time)
class UnaryAsyncBenchmarkClient(BenchmarkClient):
- def send_request(self):
- # Use the Future callback api to support multiple outstanding rpcs
- start_time = time.time()
- response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
- response_future.add_done_callback(
- lambda resp: self._response_received(start_time, resp))
+ def send_request(self):
+ # Use the Future callback api to support multiple outstanding rpcs
+ start_time = time.time()
+ response_future = self._stub.UnaryCall.future(self._request, _TIMEOUT)
+ response_future.add_done_callback(
+ lambda resp: self._response_received(start_time, resp))
- def _response_received(self, start_time, resp):
- resp.result()
- end_time = time.time()
- self._handle_response(self, end_time - start_time)
+ def _response_received(self, start_time, resp):
+ resp.result()
+ end_time = time.time()
+ self._handle_response(self, end_time - start_time)
- def stop(self):
- self._stub = None
+ def stop(self):
+ self._stub = None
class _SyncStream(object):
- def __init__(self, stub, generic, request, handle_response):
- self._stub = stub
- self._generic = generic
- self._request = request
- self._handle_response = handle_response
- self._is_streaming = False
- self._request_queue = queue.Queue()
- self._send_time_queue = queue.Queue()
-
- def send_request(self):
- self._send_time_queue.put(time.time())
- self._request_queue.put(self._request)
-
- def start(self):
- self._is_streaming = True
- response_stream = self._stub.StreamingCall(
- self._request_generator(), _TIMEOUT)
- for _ in response_stream:
- self._handle_response(
- self, time.time() - self._send_time_queue.get_nowait())
-
- def stop(self):
- self._is_streaming = False
-
- def _request_generator(self):
- while self._is_streaming:
- try:
- request = self._request_queue.get(block=True, timeout=1.0)
- yield request
- except queue.Empty:
- pass
+ def __init__(self, stub, generic, request, handle_response):
+ self._stub = stub
+ self._generic = generic
+ self._request = request
+ self._handle_response = handle_response
+ self._is_streaming = False
+ self._request_queue = queue.Queue()
+ self._send_time_queue = queue.Queue()
+
+ def send_request(self):
+ self._send_time_queue.put(time.time())
+ self._request_queue.put(self._request)
+
+ def start(self):
+ self._is_streaming = True
+ response_stream = self._stub.StreamingCall(self._request_generator(),
+ _TIMEOUT)
+ for _ in response_stream:
+ self._handle_response(
+ self, time.time() - self._send_time_queue.get_nowait())
+
+ def stop(self):
+ self._is_streaming = False
+
+ def _request_generator(self):
+ while self._is_streaming:
+ try:
+ request = self._request_queue.get(block=True, timeout=1.0)
+ yield request
+ except queue.Empty:
+ pass
class StreamingSyncBenchmarkClient(BenchmarkClient):
- def __init__(self, server, config, hist):
- super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
- self._pool = futures.ThreadPoolExecutor(
- max_workers=config.outstanding_rpcs_per_channel)
- self._streams = [_SyncStream(self._stub, self._generic,
- self._request, self._handle_response)
- for _ in xrange(config.outstanding_rpcs_per_channel)]
- self._curr_stream = 0
-
- def send_request(self):
- # Use a round_robin scheduler to determine what stream to send on
- self._streams[self._curr_stream].send_request()
- self._curr_stream = (self._curr_stream + 1) % len(self._streams)
-
- def start(self):
- for stream in self._streams:
- self._pool.submit(stream.start)
-
- def stop(self):
- for stream in self._streams:
- stream.stop()
- self._pool.shutdown(wait=True)
- self._stub = None
+ def __init__(self, server, config, hist):
+ super(StreamingSyncBenchmarkClient, self).__init__(server, config, hist)
+ self._pool = futures.ThreadPoolExecutor(
+ max_workers=config.outstanding_rpcs_per_channel)
+ self._streams = [
+ _SyncStream(self._stub, self._generic, self._request,
+ self._handle_response)
+ for _ in xrange(config.outstanding_rpcs_per_channel)
+ ]
+ self._curr_stream = 0
+
+ def send_request(self):
+ # Use a round_robin scheduler to determine what stream to send on
+ self._streams[self._curr_stream].send_request()
+ self._curr_stream = (self._curr_stream + 1) % len(self._streams)
+
+ def start(self):
+ for stream in self._streams:
+ self._pool.submit(stream.start)
+
+ def stop(self):
+ for stream in self._streams:
+ stream.stop()
+ self._pool.shutdown(wait=True)
+ self._stub = None
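Note: GenericStub above leans on gRPC's generic multicallables, which pass request and response bytes through untouched when no serializers are supplied. The sketch below shows the same blocking and future-based calls outside the benchmark harness; the localhost:50051 address and a generic (bytes-level) benchmark server behind it are assumptions.

    import time

    import grpc

    channel = grpc.insecure_channel('localhost:50051')
    unary_call = channel.unary_unary('/grpc.testing.BenchmarkService/UnaryCall')

    response_bytes = unary_call(b'\0' * 16, timeout=10)  # blocking form

    # Non-blocking form, mirroring UnaryAsyncBenchmarkClient: time the RPC from
    # a completion callback instead of blocking the sender.
    start_time = time.time()
    response_future = unary_call.future(b'\0' * 16, timeout=10)

    def _on_done(completed_future):
        completed_future.result()  # re-raises if the RPC failed
        print('round trip took %.6f seconds' % (time.time() - start_time))

    response_future.add_done_callback(_on_done)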
diff --git a/src/python/grpcio_tests/tests/qps/benchmark_server.py b/src/python/grpcio_tests/tests/qps/benchmark_server.py
index 2b76b810cd..423d03b804 100644
--- a/src/python/grpcio_tests/tests/qps/benchmark_server.py
+++ b/src/python/grpcio_tests/tests/qps/benchmark_server.py
@@ -32,27 +32,27 @@ from src.proto.grpc.testing import services_pb2
class BenchmarkServer(services_pb2.BenchmarkServiceServicer):
- """Synchronous Server implementation for the Benchmark service."""
+ """Synchronous Server implementation for the Benchmark service."""
- def UnaryCall(self, request, context):
- payload = messages_pb2.Payload(body='\0' * request.response_size)
- return messages_pb2.SimpleResponse(payload=payload)
+ def UnaryCall(self, request, context):
+ payload = messages_pb2.Payload(body='\0' * request.response_size)
+ return messages_pb2.SimpleResponse(payload=payload)
- def StreamingCall(self, request_iterator, context):
- for request in request_iterator:
- payload = messages_pb2.Payload(body='\0' * request.response_size)
- yield messages_pb2.SimpleResponse(payload=payload)
+ def StreamingCall(self, request_iterator, context):
+ for request in request_iterator:
+ payload = messages_pb2.Payload(body='\0' * request.response_size)
+ yield messages_pb2.SimpleResponse(payload=payload)
class GenericBenchmarkServer(services_pb2.BenchmarkServiceServicer):
- """Generic Server implementation for the Benchmark service."""
+ """Generic Server implementation for the Benchmark service."""
- def __init__(self, resp_size):
- self._response = '\0' * resp_size
+ def __init__(self, resp_size):
+ self._response = '\0' * resp_size
- def UnaryCall(self, request, context):
- return self._response
+ def UnaryCall(self, request, context):
+ return self._response
- def StreamingCall(self, request_iterator, context):
- for request in request_iterator:
- yield self._response
+ def StreamingCall(self, request_iterator, context):
+ for request in request_iterator:
+ yield self._response
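Note: these servicers are wired into a server the same way worker_server.py does further down. A minimal standalone sketch, assuming the grpc test tree is importable:

    from concurrent import futures

    import grpc

    from src.proto.grpc.testing import services_pb2
    from tests.qps import benchmark_server

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=8))
    services_pb2.add_BenchmarkServiceServicer_to_server(
        benchmark_server.BenchmarkServer(), server)
    port = server.add_insecure_port('[::]:0')  # port 0 lets the OS pick one
    server.start()
    print('benchmark server listening on port %d' % port)
    # ... drive traffic at it, then shut down:
    server.stop(0)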
diff --git a/src/python/grpcio_tests/tests/qps/client_runner.py b/src/python/grpcio_tests/tests/qps/client_runner.py
index 1fd58687ad..037092313c 100644
--- a/src/python/grpcio_tests/tests/qps/client_runner.py
+++ b/src/python/grpcio_tests/tests/qps/client_runner.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Defines behavior for WHEN clients send requests.
Each client exposes a non-blocking send_request() method that the
@@ -39,68 +38,68 @@ import time
class ClientRunner:
- """Abstract interface for sending requests from clients."""
+ """Abstract interface for sending requests from clients."""
- __metaclass__ = abc.ABCMeta
+ __metaclass__ = abc.ABCMeta
- def __init__(self, client):
- self._client = client
+ def __init__(self, client):
+ self._client = client
- @abc.abstractmethod
- def start(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def start(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def stop(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def stop(self):
+ raise NotImplementedError()
class OpenLoopClientRunner(ClientRunner):
- def __init__(self, client, interval_generator):
- super(OpenLoopClientRunner, self).__init__(client)
- self._is_running = False
- self._interval_generator = interval_generator
- self._dispatch_thread = threading.Thread(
- target=self._dispatch_requests, args=())
-
- def start(self):
- self._is_running = True
- self._client.start()
- self._dispatch_thread.start()
-
- def stop(self):
- self._is_running = False
- self._client.stop()
- self._dispatch_thread.join()
- self._client = None
-
- def _dispatch_requests(self):
- while self._is_running:
- self._client.send_request()
- time.sleep(next(self._interval_generator))
+ def __init__(self, client, interval_generator):
+ super(OpenLoopClientRunner, self).__init__(client)
+ self._is_running = False
+ self._interval_generator = interval_generator
+ self._dispatch_thread = threading.Thread(
+ target=self._dispatch_requests, args=())
+
+ def start(self):
+ self._is_running = True
+ self._client.start()
+ self._dispatch_thread.start()
+
+ def stop(self):
+ self._is_running = False
+ self._client.stop()
+ self._dispatch_thread.join()
+ self._client = None
+
+ def _dispatch_requests(self):
+ while self._is_running:
+ self._client.send_request()
+ time.sleep(next(self._interval_generator))
class ClosedLoopClientRunner(ClientRunner):
- def __init__(self, client, request_count):
- super(ClosedLoopClientRunner, self).__init__(client)
- self._is_running = False
- self._request_count = request_count
- # Send a new request on each response for closed loop
- self._client.add_response_callback(self._send_request)
-
- def start(self):
- self._is_running = True
- self._client.start()
- for _ in xrange(self._request_count):
- self._client.send_request()
-
- def stop(self):
- self._is_running = False
- self._client.stop()
- self._client = None
-
- def _send_request(self, client, response_time):
- if self._is_running:
- client.send_request()
+ def __init__(self, client, request_count):
+ super(ClosedLoopClientRunner, self).__init__(client)
+ self._is_running = False
+ self._request_count = request_count
+ # Send a new request on each response for closed loop
+ self._client.add_response_callback(self._send_request)
+
+ def start(self):
+ self._is_running = True
+ self._client.start()
+ for _ in xrange(self._request_count):
+ self._client.send_request()
+
+ def stop(self):
+ self._is_running = False
+ self._client.stop()
+ self._client = None
+
+ def _send_request(self, client, response_time):
+ if self._is_running:
+ client.send_request()
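Note: the runners only depend on the BenchmarkClient surface shown earlier (start/stop/send_request/add_response_callback), so the closed-loop wiring can be seen with a toy client. FakeClient below is purely illustrative and not part of the test code.

    from tests.qps import client_runner

    class FakeClient(object):
        def __init__(self):
            self._callbacks = []
            self.sent = 0

        def add_response_callback(self, callback):
            self._callbacks.append(callback)

        def start(self):
            pass

        def stop(self):
            pass

        def send_request(self):
            # Pretend the response arrived instantly and report a 1ms query time.
            self.sent += 1
            if self.sent < 10:  # cap the loop, since responses here are synchronous
                for callback in self._callbacks:
                    callback(self, 0.001)

    client = FakeClient()
    runner = client_runner.ClosedLoopClientRunner(client, request_count=2)
    runner.start()  # each simulated response triggers the next send_request()
    runner.stop()
    print('requests sent: %d' % client.sent)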
diff --git a/src/python/grpcio_tests/tests/qps/histogram.py b/src/python/grpcio_tests/tests/qps/histogram.py
index 9a7b5eb2ba..61040b6f3b 100644
--- a/src/python/grpcio_tests/tests/qps/histogram.py
+++ b/src/python/grpcio_tests/tests/qps/histogram.py
@@ -34,52 +34,52 @@ from src.proto.grpc.testing import stats_pb2
class Histogram(object):
- """Histogram class used for recording performance testing data.
+ """Histogram class used for recording performance testing data.
This class is thread safe.
"""
- def __init__(self, resolution, max_possible):
- self._lock = threading.Lock()
- self._resolution = resolution
- self._max_possible = max_possible
- self._sum = 0
- self._sum_of_squares = 0
- self.multiplier = 1.0 + self._resolution
- self._count = 0
- self._min = self._max_possible
- self._max = 0
- self._buckets = [0] * (self._bucket_for(self._max_possible) + 1)
+ def __init__(self, resolution, max_possible):
+ self._lock = threading.Lock()
+ self._resolution = resolution
+ self._max_possible = max_possible
+ self._sum = 0
+ self._sum_of_squares = 0
+ self.multiplier = 1.0 + self._resolution
+ self._count = 0
+ self._min = self._max_possible
+ self._max = 0
+ self._buckets = [0] * (self._bucket_for(self._max_possible) + 1)
- def reset(self):
- with self._lock:
- self._sum = 0
- self._sum_of_squares = 0
- self._count = 0
- self._min = self._max_possible
- self._max = 0
- self._buckets = [0] * (self._bucket_for(self._max_possible) + 1)
+ def reset(self):
+ with self._lock:
+ self._sum = 0
+ self._sum_of_squares = 0
+ self._count = 0
+ self._min = self._max_possible
+ self._max = 0
+ self._buckets = [0] * (self._bucket_for(self._max_possible) + 1)
- def add(self, val):
- with self._lock:
- self._sum += val
- self._sum_of_squares += val * val
- self._count += 1
- self._min = min(self._min, val)
- self._max = max(self._max, val)
- self._buckets[self._bucket_for(val)] += 1
+ def add(self, val):
+ with self._lock:
+ self._sum += val
+ self._sum_of_squares += val * val
+ self._count += 1
+ self._min = min(self._min, val)
+ self._max = max(self._max, val)
+ self._buckets[self._bucket_for(val)] += 1
- def get_data(self):
- with self._lock:
- data = stats_pb2.HistogramData()
- data.bucket.extend(self._buckets)
- data.min_seen = self._min
- data.max_seen = self._max
- data.sum = self._sum
- data.sum_of_squares = self._sum_of_squares
- data.count = self._count
- return data
+ def get_data(self):
+ with self._lock:
+ data = stats_pb2.HistogramData()
+ data.bucket.extend(self._buckets)
+ data.min_seen = self._min
+ data.max_seen = self._max
+ data.sum = self._sum
+ data.sum_of_squares = self._sum_of_squares
+ data.count = self._count
+ return data
- def _bucket_for(self, val):
- val = min(val, self._max_possible)
- return int(math.log(val, self.multiplier))
+ def _bucket_for(self, val):
+ val = min(val, self._max_possible)
+ return int(math.log(val, self.multiplier))
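Note: the histogram is the worker's only aggregation structure: clients report per-RPC latencies in nanoseconds and the driver later reads them back as a HistogramData proto. A small usage sketch follows; the resolution and max_possible values are illustrative, not mandated by this code.

    from tests.qps import histogram

    # Record a few latencies in nanoseconds, matching _handle_response() in
    # benchmark_client.py, then read back the aggregated proto.
    hist = histogram.Histogram(resolution=0.01, max_possible=60e9)
    for latency_seconds in (0.0004, 0.0011, 0.0028):
        hist.add(latency_seconds * 1e9)

    data = hist.get_data()  # a stats_pb2.HistogramData message
    print('count=%d min=%gns max=%gns' % (data.count, data.min_seen, data.max_seen))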
diff --git a/src/python/grpcio_tests/tests/qps/qps_worker.py b/src/python/grpcio_tests/tests/qps/qps_worker.py
index 2371ff0956..025dfb9d4a 100644
--- a/src/python/grpcio_tests/tests/qps/qps_worker.py
+++ b/src/python/grpcio_tests/tests/qps/qps_worker.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""The entry point for the qps worker."""
import argparse
@@ -40,22 +39,23 @@ from tests.qps import worker_server
def run_worker_server(port):
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
- servicer = worker_server.WorkerServer()
- services_pb2.add_WorkerServiceServicer_to_server(servicer, server)
- server.add_insecure_port('[::]:{}'.format(port))
- server.start()
- servicer.wait_for_quit()
- server.stop(0)
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=5))
+ servicer = worker_server.WorkerServer()
+ services_pb2.add_WorkerServiceServicer_to_server(servicer, server)
+ server.add_insecure_port('[::]:{}'.format(port))
+ server.start()
+ servicer.wait_for_quit()
+ server.stop(0)
if __name__ == '__main__':
- parser = argparse.ArgumentParser(
- description='gRPC Python performance testing worker')
- parser.add_argument('--driver_port',
- type=int,
- dest='port',
- help='The port the worker should listen on')
- args = parser.parse_args()
-
- run_worker_server(args.port)
+ parser = argparse.ArgumentParser(
+ description='gRPC Python performance testing worker')
+ parser.add_argument(
+ '--driver_port',
+ type=int,
+ dest='port',
+ help='The port the worker should listen on')
+ args = parser.parse_args()
+
+ run_worker_server(args.port)
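Note: the worker above is normally spawned by the performance test driver, but it can also be started by hand for local experiments. A sketch, with an arbitrary port:

    # Either via the CLI defined by the argparse block above, e.g.
    #   python -m tests.qps.qps_worker --driver_port 10000
    # or programmatically:
    from tests.qps import qps_worker

    qps_worker.run_worker_server(10000)  # blocks until a QuitWorker RPC arrives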
diff --git a/src/python/grpcio_tests/tests/qps/worker_server.py b/src/python/grpcio_tests/tests/qps/worker_server.py
index 46d542940f..1deb7ed698 100644
--- a/src/python/grpcio_tests/tests/qps/worker_server.py
+++ b/src/python/grpcio_tests/tests/qps/worker_server.py
@@ -46,149 +46,156 @@ from tests.unit import resources
class WorkerServer(services_pb2.WorkerServiceServicer):
- """Python Worker Server implementation."""
-
- def __init__(self):
- self._quit_event = threading.Event()
-
- def RunServer(self, request_iterator, context):
- config = next(request_iterator).setup
- server, port = self._create_server(config)
- cores = multiprocessing.cpu_count()
- server.start()
- start_time = time.time()
- yield self._get_server_status(start_time, start_time, port, cores)
-
- for request in request_iterator:
- end_time = time.time()
- status = self._get_server_status(start_time, end_time, port, cores)
- if request.mark.reset:
- start_time = end_time
- yield status
- server.stop(None)
-
- def _get_server_status(self, start_time, end_time, port, cores):
- end_time = time.time()
- elapsed_time = end_time - start_time
- stats = stats_pb2.ServerStats(time_elapsed=elapsed_time,
- time_user=elapsed_time,
- time_system=elapsed_time)
- return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)
-
- def _create_server(self, config):
- if config.async_server_threads == 0:
- # This is the default concurrent.futures thread pool size, but
- # None doesn't seem to work
- server_threads = multiprocessing.cpu_count() * 5
- else:
- server_threads = config.async_server_threads
- server = grpc.server(futures.ThreadPoolExecutor(
- max_workers=server_threads))
- if config.server_type == control_pb2.ASYNC_SERVER:
- servicer = benchmark_server.BenchmarkServer()
- services_pb2.add_BenchmarkServiceServicer_to_server(servicer, server)
- elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
- resp_size = config.payload_config.bytebuf_params.resp_size
- servicer = benchmark_server.GenericBenchmarkServer(resp_size)
- method_implementations = {
- 'StreamingCall':
- grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
- 'UnaryCall':
- grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
- }
- handler = grpc.method_handlers_generic_handler(
- 'grpc.testing.BenchmarkService', method_implementations)
- server.add_generic_rpc_handlers((handler,))
- else:
- raise Exception('Unsupported server type {}'.format(config.server_type))
-
- if config.HasField('security_params'): # Use SSL
- server_creds = grpc.ssl_server_credentials(
- ((resources.private_key(), resources.certificate_chain()),))
- port = server.add_secure_port('[::]:{}'.format(config.port), server_creds)
- else:
- port = server.add_insecure_port('[::]:{}'.format(config.port))
-
- return (server, port)
-
- def RunClient(self, request_iterator, context):
- config = next(request_iterator).setup
- client_runners = []
- qps_data = histogram.Histogram(config.histogram_params.resolution,
- config.histogram_params.max_possible)
- start_time = time.time()
-
- # Create a client for each channel
- for i in xrange(config.client_channels):
- server = config.server_targets[i % len(config.server_targets)]
- runner = self._create_client_runner(server, config, qps_data)
- client_runners.append(runner)
- runner.start()
-
- end_time = time.time()
- yield self._get_client_status(start_time, end_time, qps_data)
-
- # Respond to stat requests
- for request in request_iterator:
- end_time = time.time()
- status = self._get_client_status(start_time, end_time, qps_data)
- if request.mark.reset:
- qps_data.reset()
+ """Python Worker Server implementation."""
+
+ def __init__(self):
+ self._quit_event = threading.Event()
+
+ def RunServer(self, request_iterator, context):
+ config = next(request_iterator).setup
+ server, port = self._create_server(config)
+ cores = multiprocessing.cpu_count()
+ server.start()
start_time = time.time()
- yield status
-
- # Cleanup the clients
- for runner in client_runners:
- runner.stop()
-
- def _get_client_status(self, start_time, end_time, qps_data):
- latencies = qps_data.get_data()
- end_time = time.time()
- elapsed_time = end_time - start_time
- stats = stats_pb2.ClientStats(latencies=latencies,
- time_elapsed=elapsed_time,
- time_user=elapsed_time,
- time_system=elapsed_time)
- return control_pb2.ClientStatus(stats=stats)
-
- def _create_client_runner(self, server, config, qps_data):
- if config.client_type == control_pb2.SYNC_CLIENT:
- if config.rpc_type == control_pb2.UNARY:
- client = benchmark_client.UnarySyncBenchmarkClient(
- server, config, qps_data)
- elif config.rpc_type == control_pb2.STREAMING:
- client = benchmark_client.StreamingSyncBenchmarkClient(
- server, config, qps_data)
- elif config.client_type == control_pb2.ASYNC_CLIENT:
- if config.rpc_type == control_pb2.UNARY:
- client = benchmark_client.UnaryAsyncBenchmarkClient(
- server, config, qps_data)
- else:
- raise Exception('Async streaming client not supported')
- else:
- raise Exception('Unsupported client type {}'.format(config.client_type))
-
- # In multi-channel tests, we split the load across all channels
- load_factor = float(config.client_channels)
- if config.load_params.WhichOneof('load') == 'closed_loop':
- runner = client_runner.ClosedLoopClientRunner(
- client, config.outstanding_rpcs_per_channel)
- else: # Open loop Poisson
- alpha = config.load_params.poisson.offered_load / load_factor
- def poisson():
- while True:
- yield random.expovariate(alpha)
-
- runner = client_runner.OpenLoopClientRunner(client, poisson())
-
- return runner
-
- def CoreCount(self, request, context):
- return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())
-
- def QuitWorker(self, request, context):
- self._quit_event.set()
- return control_pb2.Void()
-
- def wait_for_quit(self):
- self._quit_event.wait()
+ yield self._get_server_status(start_time, start_time, port, cores)
+
+ for request in request_iterator:
+ end_time = time.time()
+ status = self._get_server_status(start_time, end_time, port, cores)
+ if request.mark.reset:
+ start_time = end_time
+ yield status
+ server.stop(None)
+
+ def _get_server_status(self, start_time, end_time, port, cores):
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ stats = stats_pb2.ServerStats(
+ time_elapsed=elapsed_time,
+ time_user=elapsed_time,
+ time_system=elapsed_time)
+ return control_pb2.ServerStatus(stats=stats, port=port, cores=cores)
+
+ def _create_server(self, config):
+ if config.async_server_threads == 0:
+ # This is the default concurrent.futures thread pool size, but
+ # None doesn't seem to work
+ server_threads = multiprocessing.cpu_count() * 5
+ else:
+ server_threads = config.async_server_threads
+ server = grpc.server(
+ futures.ThreadPoolExecutor(max_workers=server_threads))
+ if config.server_type == control_pb2.ASYNC_SERVER:
+ servicer = benchmark_server.BenchmarkServer()
+ services_pb2.add_BenchmarkServiceServicer_to_server(servicer,
+ server)
+ elif config.server_type == control_pb2.ASYNC_GENERIC_SERVER:
+ resp_size = config.payload_config.bytebuf_params.resp_size
+ servicer = benchmark_server.GenericBenchmarkServer(resp_size)
+ method_implementations = {
+ 'StreamingCall':
+ grpc.stream_stream_rpc_method_handler(servicer.StreamingCall),
+ 'UnaryCall':
+ grpc.unary_unary_rpc_method_handler(servicer.UnaryCall),
+ }
+ handler = grpc.method_handlers_generic_handler(
+ 'grpc.testing.BenchmarkService', method_implementations)
+ server.add_generic_rpc_handlers((handler,))
+ else:
+ raise Exception('Unsupported server type {}'.format(
+ config.server_type))
+
+ if config.HasField('security_params'): # Use SSL
+ server_creds = grpc.ssl_server_credentials((
+ (resources.private_key(), resources.certificate_chain()),))
+ port = server.add_secure_port('[::]:{}'.format(config.port),
+ server_creds)
+ else:
+ port = server.add_insecure_port('[::]:{}'.format(config.port))
+
+ return (server, port)
+
+ def RunClient(self, request_iterator, context):
+ config = next(request_iterator).setup
+ client_runners = []
+ qps_data = histogram.Histogram(config.histogram_params.resolution,
+ config.histogram_params.max_possible)
+ start_time = time.time()
+
+ # Create a client for each channel
+ for i in xrange(config.client_channels):
+ server = config.server_targets[i % len(config.server_targets)]
+ runner = self._create_client_runner(server, config, qps_data)
+ client_runners.append(runner)
+ runner.start()
+
+ end_time = time.time()
+ yield self._get_client_status(start_time, end_time, qps_data)
+
+ # Respond to stat requests
+ for request in request_iterator:
+ end_time = time.time()
+ status = self._get_client_status(start_time, end_time, qps_data)
+ if request.mark.reset:
+ qps_data.reset()
+ start_time = time.time()
+ yield status
+
+        # Clean up the clients
+ for runner in client_runners:
+ runner.stop()
+
+ def _get_client_status(self, start_time, end_time, qps_data):
+ latencies = qps_data.get_data()
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ stats = stats_pb2.ClientStats(
+ latencies=latencies,
+ time_elapsed=elapsed_time,
+ time_user=elapsed_time,
+ time_system=elapsed_time)
+ return control_pb2.ClientStatus(stats=stats)
+
+ def _create_client_runner(self, server, config, qps_data):
+ if config.client_type == control_pb2.SYNC_CLIENT:
+ if config.rpc_type == control_pb2.UNARY:
+ client = benchmark_client.UnarySyncBenchmarkClient(
+ server, config, qps_data)
+ elif config.rpc_type == control_pb2.STREAMING:
+ client = benchmark_client.StreamingSyncBenchmarkClient(
+ server, config, qps_data)
+ elif config.client_type == control_pb2.ASYNC_CLIENT:
+ if config.rpc_type == control_pb2.UNARY:
+ client = benchmark_client.UnaryAsyncBenchmarkClient(
+ server, config, qps_data)
+ else:
+ raise Exception('Async streaming client not supported')
+ else:
+ raise Exception('Unsupported client type {}'.format(
+ config.client_type))
+
+ # In multi-channel tests, we split the load across all channels
+ load_factor = float(config.client_channels)
+ if config.load_params.WhichOneof('load') == 'closed_loop':
+ runner = client_runner.ClosedLoopClientRunner(
+ client, config.outstanding_rpcs_per_channel)
+ else: # Open loop Poisson
+ alpha = config.load_params.poisson.offered_load / load_factor
+
+ def poisson():
+ while True:
+ yield random.expovariate(alpha)
+
+ runner = client_runner.OpenLoopClientRunner(client, poisson())
+
+ return runner
+
+ def CoreCount(self, request, context):
+ return control_pb2.CoreResponse(cores=multiprocessing.cpu_count())
+
+ def QuitWorker(self, request, context):
+ self._quit_event.set()
+ return control_pb2.Void()
+
+ def wait_for_quit(self):
+ self._quit_event.wait()
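
The open-loop path in _create_client_runner above draws inter-arrival times from an exponential distribution, so request starts form a Poisson process whose rate is the offered load split evenly across channels. A minimal standalone sketch of that generator (the load and channel-count values are illustrative, not taken from any config):

import random

def poisson_intervals(offered_load_rps, client_channels):
    # Exponential inter-arrival times with rate alpha yield a Poisson
    # arrival process; each channel carries offered_load / channels RPS.
    alpha = float(offered_load_rps) / float(client_channels)
    while True:
        yield random.expovariate(alpha)

# Example: the mean interval for 1000 RPS spread over 4 channels is ~0.004 s.
gen = poisson_intervals(1000.0, 4)
sample = [next(gen) for _ in range(10000)]
print(sum(sample) / len(sample))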
diff --git a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
index 43d6c971b5..76e89ca039 100644
--- a/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
+++ b/src/python/grpcio_tests/tests/reflection/_reflection_servicer_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc_reflection.v1alpha.reflection."""
import unittest
@@ -45,141 +44,112 @@ from tests.unit.framework.common import test_constants
_EMPTY_PROTO_FILE_NAME = 'src/proto/grpc/testing/empty.proto'
_EMPTY_PROTO_SYMBOL_NAME = 'grpc.testing.Empty'
-_SERVICE_NAMES = (
- 'Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman', 'Galilei')
+_SERVICE_NAMES = ('Angstrom', 'Bohr', 'Curie', 'Dyson', 'Einstein', 'Feynman',
+ 'Galilei')
+
def _file_descriptor_to_proto(descriptor):
- proto = descriptor_pb2.FileDescriptorProto()
- descriptor.CopyToProto(proto)
- return proto.SerializeToString()
+ proto = descriptor_pb2.FileDescriptorProto()
+ descriptor.CopyToProto(proto)
+ return proto.SerializeToString()
+
class ReflectionServicerTest(unittest.TestCase):
- def setUp(self):
- servicer = reflection.ReflectionServicer(service_names=_SERVICE_NAMES)
- server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(server_pool)
- port = self._server.add_insecure_port('[::]:0')
- reflection_pb2.add_ServerReflectionServicer_to_server(servicer, self._server)
- self._server.start()
-
- channel = grpc.insecure_channel('localhost:%d' % port)
- self._stub = reflection_pb2.ServerReflectionStub(channel)
-
- def testFileByName(self):
- requests = (
- reflection_pb2.ServerReflectionRequest(
- file_by_filename=_EMPTY_PROTO_FILE_NAME
- ),
- reflection_pb2.ServerReflectionRequest(
- file_by_filename='i-donut-exist'
- ),
- )
- responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
- expected_responses = (
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- file_descriptor_response=reflection_pb2.FileDescriptorResponse(
- file_descriptor_proto=(
- _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
- )
- )
- ),
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- error_response=reflection_pb2.ErrorResponse(
- error_code=grpc.StatusCode.NOT_FOUND.value[0],
- error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
- )
- ),
- )
- self.assertSequenceEqual(expected_responses, responses)
-
- def testFileBySymbol(self):
- requests = (
- reflection_pb2.ServerReflectionRequest(
- file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME
- ),
- reflection_pb2.ServerReflectionRequest(
- file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'
- ),
- )
- responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
- expected_responses = (
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- file_descriptor_response=reflection_pb2.FileDescriptorResponse(
- file_descriptor_proto=(
- _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),
- )
- )
- ),
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- error_response=reflection_pb2.ErrorResponse(
- error_code=grpc.StatusCode.NOT_FOUND.value[0],
- error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
- )
- ),
- )
- self.assertSequenceEqual(expected_responses, responses)
-
- @unittest.skip('TODO(atash): implement file-containing-extension reflection '
- '(see https://github.com/google/protobuf/issues/2248)')
- def testFileContainingExtension(self):
- requests = (
- reflection_pb2.ServerReflectionRequest(
- file_containing_extension=reflection_pb2.ExtensionRequest(
- containing_type='grpc.testing.proto2.Empty',
- extension_number=125,
- ),
- ),
- reflection_pb2.ServerReflectionRequest(
- file_containing_extension=reflection_pb2.ExtensionRequest(
- containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
- extension_number=55,
- ),
- ),
- )
- responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
- expected_responses = (
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- file_descriptor_response=reflection_pb2.FileDescriptorResponse(
- file_descriptor_proto=(
- _file_descriptor_to_proto(empty_extensions_pb2.DESCRIPTOR),
- )
- )
- ),
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- error_response=reflection_pb2.ErrorResponse(
- error_code=grpc.StatusCode.NOT_FOUND.value[0],
- error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
- )
- ),
- )
- self.assertSequenceEqual(expected_responses, responses)
-
- def testListServices(self):
- requests = (
- reflection_pb2.ServerReflectionRequest(
- list_services='',
- ),
- )
- responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
- expected_responses = (
- reflection_pb2.ServerReflectionResponse(
- valid_host='',
- list_services_response=reflection_pb2.ListServiceResponse(
- service=tuple(
- reflection_pb2.ServiceResponse(name=name)
- for name in _SERVICE_NAMES
- )
- )
- ),
- )
- self.assertSequenceEqual(expected_responses, responses)
+ def setUp(self):
+ servicer = reflection.ReflectionServicer(service_names=_SERVICE_NAMES)
+ server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ reflection_pb2.add_ServerReflectionServicer_to_server(servicer,
+ self._server)
+ self._server.start()
+
+ channel = grpc.insecure_channel('localhost:%d' % port)
+ self._stub = reflection_pb2.ServerReflectionStub(channel)
+
+ def testFileByName(self):
+ requests = (
+ reflection_pb2.ServerReflectionRequest(
+ file_by_filename=_EMPTY_PROTO_FILE_NAME),
+ reflection_pb2.ServerReflectionRequest(
+ file_by_filename='i-donut-exist'),)
+ responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
+ expected_responses = (
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ file_descriptor_response=reflection_pb2.FileDescriptorResponse(
+ file_descriptor_proto=(
+ _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ error_response=reflection_pb2.ErrorResponse(
+ error_code=grpc.StatusCode.NOT_FOUND.value[0],
+ error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
+ )),)
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testFileBySymbol(self):
+ requests = (
+ reflection_pb2.ServerReflectionRequest(
+ file_containing_symbol=_EMPTY_PROTO_SYMBOL_NAME),
+ reflection_pb2.ServerReflectionRequest(
+ file_containing_symbol='i.donut.exist.co.uk.org.net.me.name.foo'
+ ),)
+ responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
+ expected_responses = (
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ file_descriptor_response=reflection_pb2.FileDescriptorResponse(
+ file_descriptor_proto=(
+ _file_descriptor_to_proto(empty_pb2.DESCRIPTOR),))),
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ error_response=reflection_pb2.ErrorResponse(
+ error_code=grpc.StatusCode.NOT_FOUND.value[0],
+ error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
+ )),)
+ self.assertSequenceEqual(expected_responses, responses)
+
+ @unittest.skip(
+ 'TODO(atash): implement file-containing-extension reflection '
+ '(see https://github.com/google/protobuf/issues/2248)')
+ def testFileContainingExtension(self):
+ requests = (
+ reflection_pb2.ServerReflectionRequest(
+ file_containing_extension=reflection_pb2.ExtensionRequest(
+ containing_type='grpc.testing.proto2.Empty',
+ extension_number=125,),),
+ reflection_pb2.ServerReflectionRequest(
+ file_containing_extension=reflection_pb2.ExtensionRequest(
+ containing_type='i.donut.exist.co.uk.org.net.me.name.foo',
+ extension_number=55,),),)
+ responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
+ expected_responses = (
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ file_descriptor_response=reflection_pb2.FileDescriptorResponse(
+ file_descriptor_proto=(_file_descriptor_to_proto(
+ empty_extensions_pb2.DESCRIPTOR),))),
+ reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ error_response=reflection_pb2.ErrorResponse(
+ error_code=grpc.StatusCode.NOT_FOUND.value[0],
+ error_message=grpc.StatusCode.NOT_FOUND.value[1].encode(),
+ )),)
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testListServices(self):
+ requests = (reflection_pb2.ServerReflectionRequest(list_services='',),)
+ responses = tuple(self._stub.ServerReflectionInfo(iter(requests)))
+ expected_responses = (reflection_pb2.ServerReflectionResponse(
+ valid_host='',
+ list_services_response=reflection_pb2.ListServiceResponse(
+ service=tuple(
+ reflection_pb2.ServiceResponse(name=name)
+ for name in _SERVICE_NAMES))),)
+ self.assertSequenceEqual(expected_responses, responses)
+
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
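
The reflection API exercised above can be driven the same way outside the test. A minimal sketch that lists a server's services, assuming a reflection-enabled server at the hypothetical address localhost:50051 and assuming the generated module is importable as grpc_reflection.v1alpha.reflection_pb2 (the path used by recent grpcio-reflection releases; it may differ in this tree):

import grpc
from grpc_reflection.v1alpha import reflection_pb2  # assumed import path

channel = grpc.insecure_channel('localhost:50051')  # hypothetical target
stub = reflection_pb2.ServerReflectionStub(channel)

# ServerReflectionInfo is a bidirectional stream; send one list_services request.
requests = (reflection_pb2.ServerReflectionRequest(list_services=''),)
for response in stub.ServerReflectionInfo(iter(requests)):
    for service in response.list_services_response.service:
        print(service.name)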
diff --git a/src/python/grpcio_tests/tests/stress/client.py b/src/python/grpcio_tests/tests/stress/client.py
index b8116729b5..61f9e1c6b1 100644
--- a/src/python/grpcio_tests/tests/stress/client.py
+++ b/src/python/grpcio_tests/tests/stress/client.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Entry point for running stress tests."""
import argparse
@@ -46,118 +45,132 @@ from tests.stress import test_runner
def _args():
- parser = argparse.ArgumentParser(description='gRPC Python stress test client')
- parser.add_argument(
- '--server_addresses',
- help='comma seperated list of hostname:port to run servers on',
- default='localhost:8080', type=str)
- parser.add_argument(
- '--test_cases',
- help='comma seperated list of testcase:weighting of tests to run',
- default='large_unary:100',
- type=str)
- parser.add_argument(
- '--test_duration_secs',
- help='number of seconds to run the stress test',
- default=-1, type=int)
- parser.add_argument(
- '--num_channels_per_server',
- help='number of channels per server',
- default=1, type=int)
- parser.add_argument(
- '--num_stubs_per_channel',
- help='number of stubs to create per channel',
- default=1, type=int)
- parser.add_argument(
- '--metrics_port',
- help='the port to listen for metrics requests on',
- default=8081, type=int)
- parser.add_argument(
- '--use_test_ca',
- help='Whether to use our fake CA. Requires --use_tls=true',
- default=False, type=bool)
- parser.add_argument(
- '--use_tls',
- help='Whether to use TLS', default=False, type=bool)
- parser.add_argument(
- '--server_host_override', default="foo.test.google.fr",
- help='the server host to which to claim to connect', type=str)
- return parser.parse_args()
+ parser = argparse.ArgumentParser(
+ description='gRPC Python stress test client')
+ parser.add_argument(
+ '--server_addresses',
+        help='comma separated list of hostname:port to run servers on',
+ default='localhost:8080',
+ type=str)
+ parser.add_argument(
+ '--test_cases',
+        help='comma separated list of testcase:weighting of tests to run',
+ default='large_unary:100',
+ type=str)
+ parser.add_argument(
+ '--test_duration_secs',
+ help='number of seconds to run the stress test',
+ default=-1,
+ type=int)
+ parser.add_argument(
+ '--num_channels_per_server',
+ help='number of channels per server',
+ default=1,
+ type=int)
+ parser.add_argument(
+ '--num_stubs_per_channel',
+ help='number of stubs to create per channel',
+ default=1,
+ type=int)
+ parser.add_argument(
+ '--metrics_port',
+ help='the port to listen for metrics requests on',
+ default=8081,
+ type=int)
+ parser.add_argument(
+ '--use_test_ca',
+ help='Whether to use our fake CA. Requires --use_tls=true',
+ default=False,
+ type=bool)
+ parser.add_argument(
+ '--use_tls', help='Whether to use TLS', default=False, type=bool)
+ parser.add_argument(
+ '--server_host_override',
+ default="foo.test.google.fr",
+ help='the server host to which to claim to connect',
+ type=str)
+ return parser.parse_args()
def _test_case_from_arg(test_case_arg):
- for test_case in methods.TestCase:
- if test_case_arg == test_case.value:
- return test_case
- else:
- raise ValueError('No test case {}!'.format(test_case_arg))
+ for test_case in methods.TestCase:
+ if test_case_arg == test_case.value:
+ return test_case
+ else:
+ raise ValueError('No test case {}!'.format(test_case_arg))
def _parse_weighted_test_cases(test_case_args):
- weighted_test_cases = {}
- for test_case_arg in test_case_args.split(','):
- name, weight = test_case_arg.split(':', 1)
- test_case = _test_case_from_arg(name)
- weighted_test_cases[test_case] = int(weight)
- return weighted_test_cases
+ weighted_test_cases = {}
+ for test_case_arg in test_case_args.split(','):
+ name, weight = test_case_arg.split(':', 1)
+ test_case = _test_case_from_arg(name)
+ weighted_test_cases[test_case] = int(weight)
+ return weighted_test_cases
+
def _get_channel(target, args):
- if args.use_tls:
- if args.use_test_ca:
- root_certificates = resources.test_root_certificates()
+ if args.use_tls:
+ if args.use_test_ca:
+ root_certificates = resources.test_root_certificates()
+ else:
+ root_certificates = None # will load default roots.
+ channel_credentials = grpc.ssl_channel_credentials(
+ root_certificates=root_certificates)
+ options = ((
+ 'grpc.ssl_target_name_override',
+ args.server_host_override,),)
+ channel = grpc.secure_channel(
+ target, channel_credentials, options=options)
else:
- root_certificates = None # will load default roots.
- channel_credentials = grpc.ssl_channel_credentials(
- root_certificates=root_certificates)
- options = (('grpc.ssl_target_name_override', args.server_host_override,),)
- channel = grpc.secure_channel(target, channel_credentials, options=options)
- else:
- channel = grpc.insecure_channel(target)
-
- # waits for the channel to be ready before we start sending messages
- grpc.channel_ready_future(channel).result()
- return channel
+ channel = grpc.insecure_channel(target)
+
+ # waits for the channel to be ready before we start sending messages
+ grpc.channel_ready_future(channel).result()
+ return channel
+
def run_test(args):
- test_cases = _parse_weighted_test_cases(args.test_cases)
- test_server_targets = args.server_addresses.split(',')
- # Propagate any client exceptions with a queue
- exception_queue = queue.Queue()
- stop_event = threading.Event()
- hist = histogram.Histogram(1, 1)
- runners = []
-
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
- metrics_pb2.add_MetricsServiceServicer_to_server(
- metrics_server.MetricsServer(hist), server)
- server.add_insecure_port('[::]:{}'.format(args.metrics_port))
- server.start()
-
- for test_server_target in test_server_targets:
- for _ in xrange(args.num_channels_per_server):
- channel = _get_channel(test_server_target, args)
- for _ in xrange(args.num_stubs_per_channel):
- stub = test_pb2.TestServiceStub(channel)
- runner = test_runner.TestRunner(stub, test_cases, hist,
- exception_queue, stop_event)
- runners.append(runner)
-
- for runner in runners:
- runner.start()
- try:
- timeout_secs = args.test_duration_secs
- if timeout_secs < 0:
- timeout_secs = None
- raise exception_queue.get(block=True, timeout=timeout_secs)
- except queue.Empty:
- # No exceptions thrown, success
- pass
- finally:
- stop_event.set()
+ test_cases = _parse_weighted_test_cases(args.test_cases)
+ test_server_targets = args.server_addresses.split(',')
+ # Propagate any client exceptions with a queue
+ exception_queue = queue.Queue()
+ stop_event = threading.Event()
+ hist = histogram.Histogram(1, 1)
+ runners = []
+
+ server = grpc.server(futures.ThreadPoolExecutor(max_workers=25))
+ metrics_pb2.add_MetricsServiceServicer_to_server(
+ metrics_server.MetricsServer(hist), server)
+ server.add_insecure_port('[::]:{}'.format(args.metrics_port))
+ server.start()
+
+ for test_server_target in test_server_targets:
+ for _ in xrange(args.num_channels_per_server):
+ channel = _get_channel(test_server_target, args)
+ for _ in xrange(args.num_stubs_per_channel):
+ stub = test_pb2.TestServiceStub(channel)
+ runner = test_runner.TestRunner(stub, test_cases, hist,
+ exception_queue, stop_event)
+ runners.append(runner)
+
for runner in runners:
- runner.join()
- runner = None
- server.stop(None)
+ runner.start()
+ try:
+ timeout_secs = args.test_duration_secs
+ if timeout_secs < 0:
+ timeout_secs = None
+ raise exception_queue.get(block=True, timeout=timeout_secs)
+ except queue.Empty:
+ # No exceptions thrown, success
+ pass
+ finally:
+ stop_event.set()
+ for runner in runners:
+ runner.join()
+ runner = None
+ server.stop(None)
+
if __name__ == '__main__':
- run_test(_args())
+ run_test(_args())
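
The --test_cases flag above encodes a weighting table as comma-separated name:weight pairs; the parsing reduces to the sketch below, written against plain strings instead of the methods.TestCase enum so it runs standalone (the case names are illustrative). As an aside, argparse's type=bool, used for --use_tls and --use_test_ca, converts any non-empty string to True, so '--use_tls=false' still enables TLS.

def parse_weighted_test_cases(test_case_args):
    # 'large_unary:100,empty_stream:20' -> {'large_unary': 100, 'empty_stream': 20}
    weighted_test_cases = {}
    for test_case_arg in test_case_args.split(','):
        name, weight = test_case_arg.split(':', 1)
        weighted_test_cases[name] = int(weight)
    return weighted_test_cases

print(parse_weighted_test_cases('large_unary:100,empty_stream:20'))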
diff --git a/src/python/grpcio_tests/tests/stress/metrics_server.py b/src/python/grpcio_tests/tests/stress/metrics_server.py
index 33dd1d6f2a..3a4cbc27ba 100644
--- a/src/python/grpcio_tests/tests/stress/metrics_server.py
+++ b/src/python/grpcio_tests/tests/stress/metrics_server.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""MetricsService for publishing stress test qps data."""
import time
@@ -38,23 +37,23 @@ GAUGE_NAME = 'python_overall_qps'
class MetricsServer(metrics_pb2.MetricsServiceServicer):
- def __init__(self, histogram):
- self._start_time = time.time()
- self._histogram = histogram
-
- def _get_qps(self):
- count = self._histogram.get_data().count
- delta = time.time() - self._start_time
- self._histogram.reset()
- self._start_time = time.time()
- return int(count/delta)
-
- def GetAllGauges(self, request, context):
- qps = self._get_qps()
- return [metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps)]
-
- def GetGauge(self, request, context):
- if request.name != GAUGE_NAME:
- raise Exception('Gauge {} does not exist'.format(request.name))
- qps = self._get_qps()
- return metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps)
+ def __init__(self, histogram):
+ self._start_time = time.time()
+ self._histogram = histogram
+
+ def _get_qps(self):
+ count = self._histogram.get_data().count
+ delta = time.time() - self._start_time
+ self._histogram.reset()
+ self._start_time = time.time()
+ return int(count / delta)
+
+ def GetAllGauges(self, request, context):
+ qps = self._get_qps()
+ return [metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps)]
+
+ def GetGauge(self, request, context):
+ if request.name != GAUGE_NAME:
+ raise Exception('Gauge {} does not exist'.format(request.name))
+ qps = self._get_qps()
+ return metrics_pb2.GaugeResponse(name=GAUGE_NAME, long_value=qps)
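
The gauge above reports QPS as the histogram count accumulated since the previous read divided by the elapsed wall time, resetting both afterwards. A self-contained sketch of that pattern with a trivial stand-in counter (the Counter and QpsGauge names are illustrative, not part of the test suite):

import time

class Counter(object):
    # Minimal stand-in for the histogram: tracks only a running count.
    def __init__(self):
        self.count = 0

    def add(self, n=1):
        self.count += n

    def reset(self):
        self.count = 0

class QpsGauge(object):
    def __init__(self, counter):
        self._counter = counter
        self._start_time = time.time()

    def read(self):
        count = self._counter.count
        delta = time.time() - self._start_time
        self._counter.reset()
        self._start_time = time.time()
        return int(count / delta)

counter = Counter()
gauge = QpsGauge(counter)
for _ in range(500):
    counter.add()
time.sleep(0.1)
print(gauge.read())  # roughly 500 / 0.1 = 5000, modulo timing jitter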
diff --git a/src/python/grpcio_tests/tests/stress/test_runner.py b/src/python/grpcio_tests/tests/stress/test_runner.py
index 88f13727e3..258abe9c21 100644
--- a/src/python/grpcio_tests/tests/stress/test_runner.py
+++ b/src/python/grpcio_tests/tests/stress/test_runner.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Thread that sends random weighted requests on a TestService stub."""
import random
@@ -36,38 +35,38 @@ import traceback
def _weighted_test_case_generator(weighted_cases):
- weight_sum = sum(weighted_cases.itervalues())
+ weight_sum = sum(weighted_cases.itervalues())
- while True:
- val = random.uniform(0, weight_sum)
- partial_sum = 0
- for case in weighted_cases:
- partial_sum += weighted_cases[case]
- if val <= partial_sum:
- yield case
- break
+ while True:
+ val = random.uniform(0, weight_sum)
+ partial_sum = 0
+ for case in weighted_cases:
+ partial_sum += weighted_cases[case]
+ if val <= partial_sum:
+ yield case
+ break
class TestRunner(threading.Thread):
- def __init__(self, stub, test_cases, hist, exception_queue, stop_event):
- super(TestRunner, self).__init__()
- self._exception_queue = exception_queue
- self._stop_event = stop_event
- self._stub = stub
- self._test_cases = _weighted_test_case_generator(test_cases)
- self._histogram = hist
+ def __init__(self, stub, test_cases, hist, exception_queue, stop_event):
+ super(TestRunner, self).__init__()
+ self._exception_queue = exception_queue
+ self._stop_event = stop_event
+ self._stub = stub
+ self._test_cases = _weighted_test_case_generator(test_cases)
+ self._histogram = hist
- def run(self):
- while not self._stop_event.is_set():
- try:
- test_case = next(self._test_cases)
- start_time = time.time()
- test_case.test_interoperability(self._stub, None)
- end_time = time.time()
- self._histogram.add((end_time - start_time)*1e9)
- except Exception as e:
- traceback.print_exc()
- self._exception_queue.put(
- Exception("An exception occured during test {}"
- .format(test_case), e))
+ def run(self):
+ while not self._stop_event.is_set():
+ try:
+ test_case = next(self._test_cases)
+ start_time = time.time()
+ test_case.test_interoperability(self._stub, None)
+ end_time = time.time()
+ self._histogram.add((end_time - start_time) * 1e9)
+ except Exception as e:
+ traceback.print_exc()
+ self._exception_queue.put(
+ Exception("An exception occured during test {}"
+ .format(test_case), e))
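
The generator above selects the next test case by sampling a uniform value against the cumulative weights, so cases are drawn in proportion to their weight. A standalone sketch with illustrative string labels in place of interop test cases:

import random
from collections import Counter

def weighted_case_generator(weighted_cases):
    # weighted_cases: dict mapping case name -> integer weight.
    weight_sum = sum(weighted_cases.values())
    while True:
        val = random.uniform(0, weight_sum)
        partial_sum = 0
        for case, weight in weighted_cases.items():
            partial_sum += weight
            if val <= partial_sum:
                yield case
                break

gen = weighted_case_generator({'large_unary': 100, 'ping_pong': 20})
print(Counter(next(gen) for _ in range(12000)))  # roughly a 5:1 ratio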
diff --git a/src/python/grpcio_tests/tests/unit/__init__.py b/src/python/grpcio_tests/tests/unit/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/_api_test.py b/src/python/grpcio_tests/tests/unit/_api_test.py
index 51dc425420..5435c5500c 100644
--- a/src/python/grpcio_tests/tests/unit/_api_test.py
+++ b/src/python/grpcio_tests/tests/unit/_api_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test of gRPC Python's application-layer API."""
import unittest
@@ -40,73 +39,71 @@ from tests.unit import _from_grpc_import_star
class AllTest(unittest.TestCase):
- def testAll(self):
- expected_grpc_code_elements = (
- 'FutureTimeoutError',
- 'FutureCancelledError',
- 'Future',
- 'ChannelConnectivity',
- 'StatusCode',
- 'RpcError',
- 'RpcContext',
- 'Call',
- 'ChannelCredentials',
- 'CallCredentials',
- 'AuthMetadataContext',
- 'AuthMetadataPluginCallback',
- 'AuthMetadataPlugin',
- 'ServerCredentials',
- 'UnaryUnaryMultiCallable',
- 'UnaryStreamMultiCallable',
- 'StreamUnaryMultiCallable',
- 'StreamStreamMultiCallable',
- 'Channel',
- 'ServicerContext',
- 'RpcMethodHandler',
- 'HandlerCallDetails',
- 'GenericRpcHandler',
- 'ServiceRpcHandler',
- 'Server',
- 'unary_unary_rpc_method_handler',
- 'unary_stream_rpc_method_handler',
- 'stream_unary_rpc_method_handler',
- 'stream_stream_rpc_method_handler',
- 'method_handlers_generic_handler',
- 'ssl_channel_credentials',
- 'metadata_call_credentials',
- 'access_token_call_credentials',
- 'composite_call_credentials',
- 'composite_channel_credentials',
- 'ssl_server_credentials',
- 'channel_ready_future',
- 'insecure_channel',
- 'secure_channel',
- 'server',
- )
-
- six.assertCountEqual(
- self, expected_grpc_code_elements,
- _from_grpc_import_star.GRPC_ELEMENTS)
+ def testAll(self):
+ expected_grpc_code_elements = (
+ 'FutureTimeoutError',
+ 'FutureCancelledError',
+ 'Future',
+ 'ChannelConnectivity',
+ 'StatusCode',
+ 'RpcError',
+ 'RpcContext',
+ 'Call',
+ 'ChannelCredentials',
+ 'CallCredentials',
+ 'AuthMetadataContext',
+ 'AuthMetadataPluginCallback',
+ 'AuthMetadataPlugin',
+ 'ServerCredentials',
+ 'UnaryUnaryMultiCallable',
+ 'UnaryStreamMultiCallable',
+ 'StreamUnaryMultiCallable',
+ 'StreamStreamMultiCallable',
+ 'Channel',
+ 'ServicerContext',
+ 'RpcMethodHandler',
+ 'HandlerCallDetails',
+ 'GenericRpcHandler',
+ 'ServiceRpcHandler',
+ 'Server',
+ 'unary_unary_rpc_method_handler',
+ 'unary_stream_rpc_method_handler',
+ 'stream_unary_rpc_method_handler',
+ 'stream_stream_rpc_method_handler',
+ 'method_handlers_generic_handler',
+ 'ssl_channel_credentials',
+ 'metadata_call_credentials',
+ 'access_token_call_credentials',
+ 'composite_call_credentials',
+ 'composite_channel_credentials',
+ 'ssl_server_credentials',
+ 'channel_ready_future',
+ 'insecure_channel',
+ 'secure_channel',
+ 'server',)
+
+ six.assertCountEqual(self, expected_grpc_code_elements,
+ _from_grpc_import_star.GRPC_ELEMENTS)
class ChannelConnectivityTest(unittest.TestCase):
- def testChannelConnectivity(self):
- self.assertSequenceEqual(
- (grpc.ChannelConnectivity.IDLE,
- grpc.ChannelConnectivity.CONNECTING,
- grpc.ChannelConnectivity.READY,
- grpc.ChannelConnectivity.TRANSIENT_FAILURE,
- grpc.ChannelConnectivity.SHUTDOWN,),
- tuple(grpc.ChannelConnectivity))
+ def testChannelConnectivity(self):
+ self.assertSequenceEqual((
+ grpc.ChannelConnectivity.IDLE,
+ grpc.ChannelConnectivity.CONNECTING,
+ grpc.ChannelConnectivity.READY,
+ grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ grpc.ChannelConnectivity.SHUTDOWN,),
+ tuple(grpc.ChannelConnectivity))
class ChannelTest(unittest.TestCase):
- def test_secure_channel(self):
- channel_credentials = grpc.ssl_channel_credentials()
- channel = grpc.secure_channel('google.com:443', channel_credentials)
+ def test_secure_channel(self):
+ channel_credentials = grpc.ssl_channel_credentials()
+ channel = grpc.secure_channel('google.com:443', channel_credentials)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_auth_test.py b/src/python/grpcio_tests/tests/unit/_auth_test.py
index c31f7b06f7..52bd1cb7ba 100644
--- a/src/python/grpcio_tests/tests/unit/_auth_test.py
+++ b/src/python/grpcio_tests/tests/unit/_auth_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of standard AuthMetadataPlugins."""
import collections
@@ -38,59 +37,59 @@ from grpc import _auth
class MockGoogleCreds(object):
- def get_access_token(self):
- token = collections.namedtuple('MockAccessTokenInfo',
- ('access_token', 'expires_in'))
- token.access_token = 'token'
- return token
+ def get_access_token(self):
+ token = collections.namedtuple('MockAccessTokenInfo',
+ ('access_token', 'expires_in'))
+ token.access_token = 'token'
+ return token
class MockExceptionGoogleCreds(object):
- def get_access_token(self):
- raise Exception()
+ def get_access_token(self):
+ raise Exception()
class GoogleCallCredentialsTest(unittest.TestCase):
- def test_google_call_credentials_success(self):
- callback_event = threading.Event()
+ def test_google_call_credentials_success(self):
+ callback_event = threading.Event()
- def mock_callback(metadata, error):
- self.assertEqual(metadata, (('authorization', 'Bearer token'),))
- self.assertIsNone(error)
- callback_event.set()
+ def mock_callback(metadata, error):
+ self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+ self.assertIsNone(error)
+ callback_event.set()
- call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
- call_creds(None, mock_callback)
- self.assertTrue(callback_event.wait(1.0))
+ call_creds = _auth.GoogleCallCredentials(MockGoogleCreds())
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
- def test_google_call_credentials_error(self):
- callback_event = threading.Event()
+ def test_google_call_credentials_error(self):
+ callback_event = threading.Event()
- def mock_callback(metadata, error):
- self.assertIsNotNone(error)
- callback_event.set()
+ def mock_callback(metadata, error):
+ self.assertIsNotNone(error)
+ callback_event.set()
- call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
- call_creds(None, mock_callback)
- self.assertTrue(callback_event.wait(1.0))
+ call_creds = _auth.GoogleCallCredentials(MockExceptionGoogleCreds())
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
class AccessTokenCallCredentialsTest(unittest.TestCase):
- def test_google_call_credentials_success(self):
- callback_event = threading.Event()
+ def test_google_call_credentials_success(self):
+ callback_event = threading.Event()
- def mock_callback(metadata, error):
- self.assertEqual(metadata, (('authorization', 'Bearer token'),))
- self.assertIsNone(error)
- callback_event.set()
+ def mock_callback(metadata, error):
+ self.assertEqual(metadata, (('authorization', 'Bearer token'),))
+ self.assertIsNone(error)
+ callback_event.set()
- call_creds = _auth.AccessTokenCallCredentials('token')
- call_creds(None, mock_callback)
- self.assertTrue(callback_event.wait(1.0))
+ call_creds = _auth.AccessTokenCallCredentials('token')
+ call_creds(None, mock_callback)
+ self.assertTrue(callback_event.wait(1.0))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
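
The AuthMetadataPlugin machinery these tests exercise is reachable through the public grpc.metadata_call_credentials helper. A minimal sketch of a plugin that attaches a static bearer token (the token value and commented-out target are illustrative):

import grpc

class StaticTokenAuth(grpc.AuthMetadataPlugin):

    def __init__(self, token):
        self._token = token

    def __call__(self, context, callback):
        # Hand the metadata to gRPC; the second argument is an optional error.
        callback((('authorization', 'Bearer ' + self._token),), None)

call_credentials = grpc.metadata_call_credentials(StaticTokenAuth('token'))
channel_credentials = grpc.composite_channel_credentials(
    grpc.ssl_channel_credentials(), call_credentials)
# channel = grpc.secure_channel('example.com:443', channel_credentials)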
diff --git a/src/python/grpcio_tests/tests/unit/_channel_args_test.py b/src/python/grpcio_tests/tests/unit/_channel_args_test.py
index b46497afd6..845db777a4 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_args_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_args_test.py
@@ -26,17 +26,17 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of Channel Args on client/server side."""
import unittest
import grpc
+
class TestPointerWrapper(object):
- def __int__(self):
- return 123456
+ def __int__(self):
+ return 123456
TEST_CHANNEL_ARGS = (
@@ -44,17 +44,17 @@ TEST_CHANNEL_ARGS = (
('arg2', 'str_val'),
('arg3', 1),
(b'arg4', 'str_val'),
- ('arg6', TestPointerWrapper()),
-)
+ ('arg6', TestPointerWrapper()),)
class ChannelArgsTest(unittest.TestCase):
- def test_client(self):
- grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS)
+ def test_client(self):
+ grpc.insecure_channel('localhost:8080', options=TEST_CHANNEL_ARGS)
+
+ def test_server(self):
+ grpc.server(None, options=TEST_CHANNEL_ARGS)
- def test_server(self):
- grpc.server(None, options=TEST_CHANNEL_ARGS)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
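
Channel args, as passed above, are an iterable of (key, value) pairs whose values may be strings, integers, or objects convertible to an integer (as TestPointerWrapper demonstrates). A minimal sketch using commonly seen, but here merely illustrative, argument names:

import grpc

channel_args = (
    ('grpc.primary_user_agent', 'my-agent/1.0'),            # string-valued
    ('grpc.max_receive_message_length', 16 * 1024 * 1024),  # integer-valued
)
channel = grpc.insecure_channel('localhost:8080', options=channel_args)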
diff --git a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
index 3d9dd17ff6..d67693154b 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_connectivity_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc._channel.Channel connectivity."""
import threading
@@ -39,125 +38,123 @@ from tests.unit import _thread_pool
def _ready_in_connectivities(connectivities):
- return grpc.ChannelConnectivity.READY in connectivities
+ return grpc.ChannelConnectivity.READY in connectivities
def _last_connectivity_is_not_ready(connectivities):
- return connectivities[-1] is not grpc.ChannelConnectivity.READY
+ return connectivities[-1] is not grpc.ChannelConnectivity.READY
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._connectivities = []
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._connectivities = []
- def update(self, connectivity):
- with self._condition:
- self._connectivities.append(connectivity)
- self._condition.notify()
+ def update(self, connectivity):
+ with self._condition:
+ self._connectivities.append(connectivity)
+ self._condition.notify()
- def connectivities(self):
- with self._condition:
- return tuple(self._connectivities)
+ def connectivities(self):
+ with self._condition:
+ return tuple(self._connectivities)
- def block_until_connectivities_satisfy(self, predicate):
- with self._condition:
- while True:
- connectivities = tuple(self._connectivities)
- if predicate(connectivities):
- return connectivities
- else:
- self._condition.wait()
+ def block_until_connectivities_satisfy(self, predicate):
+ with self._condition:
+ while True:
+ connectivities = tuple(self._connectivities)
+ if predicate(connectivities):
+ return connectivities
+ else:
+ self._condition.wait()
class ChannelConnectivityTest(unittest.TestCase):
- def test_lonely_channel_connectivity(self):
- callback = _Callback()
-
- channel = grpc.insecure_channel('localhost:12345')
- channel.subscribe(callback.update, try_to_connect=False)
- first_connectivities = callback.block_until_connectivities_satisfy(bool)
- channel.subscribe(callback.update, try_to_connect=True)
- second_connectivities = callback.block_until_connectivities_satisfy(
- lambda connectivities: 2 <= len(connectivities))
- # Wait for a connection that will never happen.
- time.sleep(test_constants.SHORT_TIMEOUT)
- third_connectivities = callback.connectivities()
- channel.unsubscribe(callback.update)
- fourth_connectivities = callback.connectivities()
- channel.unsubscribe(callback.update)
- fifth_connectivities = callback.connectivities()
-
- self.assertSequenceEqual(
- (grpc.ChannelConnectivity.IDLE,), first_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.READY, second_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.READY, third_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.READY, fourth_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.READY, fifth_connectivities)
-
- def test_immediately_connectable_channel_connectivity(self):
- thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
- server = grpc.server(thread_pool)
- port = server.add_insecure_port('[::]:0')
- server.start()
- first_callback = _Callback()
- second_callback = _Callback()
-
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- channel.subscribe(first_callback.update, try_to_connect=False)
- first_connectivities = first_callback.block_until_connectivities_satisfy(
- bool)
- # Wait for a connection that will never happen because try_to_connect=True
- # has not yet been passed.
- time.sleep(test_constants.SHORT_TIMEOUT)
- second_connectivities = first_callback.connectivities()
- channel.subscribe(second_callback.update, try_to_connect=True)
- third_connectivities = first_callback.block_until_connectivities_satisfy(
- lambda connectivities: 2 <= len(connectivities))
- fourth_connectivities = second_callback.block_until_connectivities_satisfy(
- bool)
- # Wait for a connection that will happen (or may already have happened).
- first_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
- second_callback.block_until_connectivities_satisfy(_ready_in_connectivities)
- del channel
-
- self.assertSequenceEqual(
- (grpc.ChannelConnectivity.IDLE,), first_connectivities)
- self.assertSequenceEqual(
- (grpc.ChannelConnectivity.IDLE,), second_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.TRANSIENT_FAILURE, third_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.SHUTDOWN, third_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.TRANSIENT_FAILURE,
- fourth_connectivities)
- self.assertNotIn(
- grpc.ChannelConnectivity.SHUTDOWN, fourth_connectivities)
- self.assertFalse(thread_pool.was_used())
-
- def test_reachable_then_unreachable_channel_connectivity(self):
- thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
- server = grpc.server(thread_pool)
- port = server.add_insecure_port('[::]:0')
- server.start()
- callback = _Callback()
-
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- channel.subscribe(callback.update, try_to_connect=True)
- callback.block_until_connectivities_satisfy(_ready_in_connectivities)
- # Now take down the server and confirm that channel readiness is repudiated.
- server.stop(None)
- callback.block_until_connectivities_satisfy(_last_connectivity_is_not_ready)
- channel.unsubscribe(callback.update)
- self.assertFalse(thread_pool.was_used())
+ def test_lonely_channel_connectivity(self):
+ callback = _Callback()
+
+ channel = grpc.insecure_channel('localhost:12345')
+ channel.subscribe(callback.update, try_to_connect=False)
+ first_connectivities = callback.block_until_connectivities_satisfy(bool)
+ channel.subscribe(callback.update, try_to_connect=True)
+ second_connectivities = callback.block_until_connectivities_satisfy(
+ lambda connectivities: 2 <= len(connectivities))
+ # Wait for a connection that will never happen.
+ time.sleep(test_constants.SHORT_TIMEOUT)
+ third_connectivities = callback.connectivities()
+ channel.unsubscribe(callback.update)
+ fourth_connectivities = callback.connectivities()
+ channel.unsubscribe(callback.update)
+ fifth_connectivities = callback.connectivities()
+
+ self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
+ first_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.READY, second_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.READY, third_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.READY, fourth_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.READY, fifth_connectivities)
+
+ def test_immediately_connectable_channel_connectivity(self):
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = grpc.server(thread_pool)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ first_callback = _Callback()
+ second_callback = _Callback()
+
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ channel.subscribe(first_callback.update, try_to_connect=False)
+ first_connectivities = first_callback.block_until_connectivities_satisfy(
+ bool)
+ # Wait for a connection that will never happen because try_to_connect=True
+ # has not yet been passed.
+ time.sleep(test_constants.SHORT_TIMEOUT)
+ second_connectivities = first_callback.connectivities()
+ channel.subscribe(second_callback.update, try_to_connect=True)
+ third_connectivities = first_callback.block_until_connectivities_satisfy(
+ lambda connectivities: 2 <= len(connectivities))
+ fourth_connectivities = second_callback.block_until_connectivities_satisfy(
+ bool)
+ # Wait for a connection that will happen (or may already have happened).
+ first_callback.block_until_connectivities_satisfy(
+ _ready_in_connectivities)
+ second_callback.block_until_connectivities_satisfy(
+ _ready_in_connectivities)
+ del channel
+
+ self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
+ first_connectivities)
+ self.assertSequenceEqual((grpc.ChannelConnectivity.IDLE,),
+ second_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ third_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN,
+ third_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.TRANSIENT_FAILURE,
+ fourth_connectivities)
+ self.assertNotIn(grpc.ChannelConnectivity.SHUTDOWN,
+ fourth_connectivities)
+ self.assertFalse(thread_pool.was_used())
+
+ def test_reachable_then_unreachable_channel_connectivity(self):
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = grpc.server(thread_pool)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ callback = _Callback()
+
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ channel.subscribe(callback.update, try_to_connect=True)
+ callback.block_until_connectivities_satisfy(_ready_in_connectivities)
+ # Now take down the server and confirm that channel readiness is repudiated.
+ server.stop(None)
+ callback.block_until_connectivities_satisfy(
+ _last_connectivity_is_not_ready)
+ channel.unsubscribe(callback.update)
+ self.assertFalse(thread_pool.was_used())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
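
The subscription pattern the test relies on, registering a connectivity observer on a channel, works the same outside a test. A minimal sketch, assuming nothing listens on the hypothetical localhost:12345 so the channel never reaches READY:

import time
import grpc

def on_connectivity_change(connectivity):
    # Invoked by gRPC with a grpc.ChannelConnectivity value on each transition.
    print('channel is now {}'.format(connectivity))

channel = grpc.insecure_channel('localhost:12345')  # hypothetical, unreachable
channel.subscribe(on_connectivity_change, try_to_connect=True)
time.sleep(2)  # observe IDLE / CONNECTING / TRANSIENT_FAILURE transitions
channel.unsubscribe(on_connectivity_change)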
diff --git a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
index 46a964db8c..2d1b63e15f 100644
--- a/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
+++ b/src/python/grpcio_tests/tests/unit/_channel_ready_future_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc.channel_ready_future."""
import threading
@@ -39,65 +38,66 @@ from tests.unit import _thread_pool
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._value = None
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
- def accept_value(self, value):
- with self._condition:
- self._value = value
- self._condition.notify_all()
+ def accept_value(self, value):
+ with self._condition:
+ self._value = value
+ self._condition.notify_all()
- def block_until_called(self):
- with self._condition:
- while self._value is None:
- self._condition.wait()
- return self._value
+ def block_until_called(self):
+ with self._condition:
+ while self._value is None:
+ self._condition.wait()
+ return self._value
class ChannelReadyFutureTest(unittest.TestCase):
- def test_lonely_channel_connectivity(self):
- channel = grpc.insecure_channel('localhost:12345')
- callback = _Callback()
-
- ready_future = grpc.channel_ready_future(channel)
- ready_future.add_done_callback(callback.accept_value)
- with self.assertRaises(grpc.FutureTimeoutError):
- ready_future.result(timeout=test_constants.SHORT_TIMEOUT)
- self.assertFalse(ready_future.cancelled())
- self.assertFalse(ready_future.done())
- self.assertTrue(ready_future.running())
- ready_future.cancel()
- value_passed_to_callback = callback.block_until_called()
- self.assertIs(ready_future, value_passed_to_callback)
- self.assertTrue(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
-
- def test_immediately_connectable_channel_connectivity(self):
- thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
- server = grpc.server(thread_pool)
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- callback = _Callback()
-
- ready_future = grpc.channel_ready_future(channel)
- ready_future.add_done_callback(callback.accept_value)
- self.assertIsNone(ready_future.result(timeout=test_constants.LONG_TIMEOUT))
- value_passed_to_callback = callback.block_until_called()
- self.assertIs(ready_future, value_passed_to_callback)
- self.assertFalse(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
- # Cancellation after maturity has no effect.
- ready_future.cancel()
- self.assertFalse(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
- self.assertFalse(thread_pool.was_used())
+ def test_lonely_channel_connectivity(self):
+ channel = grpc.insecure_channel('localhost:12345')
+ callback = _Callback()
+
+ ready_future = grpc.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ with self.assertRaises(grpc.FutureTimeoutError):
+ ready_future.result(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertFalse(ready_future.cancelled())
+ self.assertFalse(ready_future.done())
+ self.assertTrue(ready_future.running())
+ ready_future.cancel()
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertTrue(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+
+ def test_immediately_connectable_channel_connectivity(self):
+ thread_pool = _thread_pool.RecordingThreadPool(max_workers=None)
+ server = grpc.server(thread_pool)
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ callback = _Callback()
+
+ ready_future = grpc.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ self.assertIsNone(
+ ready_future.result(timeout=test_constants.LONG_TIMEOUT))
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+ # Cancellation after maturity has no effect.
+ ready_future.cancel()
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+ self.assertFalse(thread_pool.was_used())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
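
grpc.channel_ready_future, exercised throughout the test above, returns a Future that matures once the channel reaches READY. A minimal blocking-wait sketch, assuming a reachable server at the hypothetical localhost:50051:

import grpc

channel = grpc.insecure_channel('localhost:50051')  # hypothetical target
ready_future = grpc.channel_ready_future(channel)
try:
    # Block until the channel connects, or give up after five seconds.
    ready_future.result(timeout=5)
    print('channel is ready')
except grpc.FutureTimeoutError:
    print('channel did not become ready in time')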
diff --git a/src/python/grpcio_tests/tests/unit/_compression_test.py b/src/python/grpcio_tests/tests/unit/_compression_test.py
index 4d3f02e917..7dd944e600 100644
--- a/src/python/grpcio_tests/tests/unit/_compression_test.py
+++ b/src/python/grpcio_tests/tests/unit/_compression_test.py
@@ -42,93 +42,96 @@ _STREAM_STREAM = '/test/StreamStream'
def handle_unary(request, servicer_context):
- servicer_context.send_initial_metadata([
- ('grpc-internal-encoding-request', 'gzip')])
- return request
+ servicer_context.send_initial_metadata(
+ [('grpc-internal-encoding-request', 'gzip')])
+ return request
def handle_stream(request_iterator, servicer_context):
- # TODO(issue:#6891) We should be able to remove this loop,
- # and replace with return; yield
- servicer_context.send_initial_metadata([
- ('grpc-internal-encoding-request', 'gzip')])
- for request in request_iterator:
- yield request
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ servicer_context.send_initial_metadata(
+ [('grpc-internal-encoding-request', 'gzip')])
+ for request in request_iterator:
+ yield request
class _MethodHandler(grpc.RpcMethodHandler):
- def __init__(self, request_streaming, response_streaming):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = None
- self.response_serializer = None
- self.unary_unary = None
- self.unary_stream = None
- self.stream_unary = None
- self.stream_stream = None
- if self.request_streaming and self.response_streaming:
- self.stream_stream = lambda x, y: handle_stream(x, y)
- elif not self.request_streaming and not self.response_streaming:
- self.unary_unary = lambda x, y: handle_unary(x, y)
+ def __init__(self, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = lambda x, y: handle_stream(x, y)
+ elif not self.request_streaming and not self.response_streaming:
+ self.unary_unary = lambda x, y: handle_unary(x, y)
class _GenericHandler(grpc.GenericRpcHandler):
- def service(self, handler_call_details):
- if handler_call_details.method == _UNARY_UNARY:
- return _MethodHandler(False, False)
- elif handler_call_details.method == _STREAM_STREAM:
- return _MethodHandler(True, True)
- else:
- return None
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True)
+ else:
+ return None
class CompressionTest(unittest.TestCase):
- def setUp(self):
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(
- self._server_pool, handlers=(_GenericHandler(),))
- self._port = self._server.add_insecure_port('[::]:0')
- self._server.start()
-
- def testUnary(self):
- request = b'\x00' * 100
-
- # Client -> server compressed through default client channel compression
- # settings. Server -> client compressed via server-side metadata setting.
- # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
- # literal with proper use of the public API.
- compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
- options=[('grpc.default_compression_algorithm', 1)])
- multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
- response = multi_callable(request)
- self.assertEqual(request, response)
-
- # Client -> server compressed through client metadata setting. Server ->
- # client compressed via server-side metadata setting.
- # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer
- # literal with proper use of the public API.
- uncompressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
- options=[('grpc.default_compression_algorithm', 0)])
- multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
- response = multi_callable(request, metadata=[
- ('grpc-internal-encoding-request', 'gzip')])
- self.assertEqual(request, response)
-
- def testStreaming(self):
- request = b'\x00' * 100
-
- # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
- # literal with proper use of the public API.
- compressed_channel = grpc.insecure_channel('localhost:%d' % self._port,
- options=[('grpc.default_compression_algorithm', 1)])
- multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
- call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))
- for response in call:
- self.assertEqual(request, response)
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ self._server_pool, handlers=(_GenericHandler(),))
+ self._port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ def testUnary(self):
+ request = b'\x00' * 100
+
+ # Client -> server compressed through default client channel compression
+ # settings. Server -> client compressed via server-side metadata setting.
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
+ # literal with proper use of the public API.
+ compressed_channel = grpc.insecure_channel(
+ 'localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 1)])
+ multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
+ response = multi_callable(request)
+ self.assertEqual(request, response)
+
+ # Client -> server compressed through client metadata setting. Server ->
+ # client compressed via server-side metadata setting.
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "0" integer
+ # literal with proper use of the public API.
+ uncompressed_channel = grpc.insecure_channel(
+ 'localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 0)])
+ multi_callable = compressed_channel.unary_unary(_UNARY_UNARY)
+ response = multi_callable(
+ request, metadata=[('grpc-internal-encoding-request', 'gzip')])
+ self.assertEqual(request, response)
+
+ def testStreaming(self):
+ request = b'\x00' * 100
+
+ # TODO(https://github.com/grpc/grpc/issues/4078): replace the "1" integer
+ # literal with proper use of the public API.
+ compressed_channel = grpc.insecure_channel(
+ 'localhost:%d' % self._port,
+ options=[('grpc.default_compression_algorithm', 1)])
+ multi_callable = compressed_channel.stream_stream(_STREAM_STREAM)
+ call = multi_callable(iter([request] * test_constants.STREAM_LENGTH))
+ for response in call:
+ self.assertEqual(request, response)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
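For reference, a minimal sketch of the compression settings this test exercises through the public grpc API; the target address and the method path below are placeholder assumptions, and the integer literal 1 stands in for gzip pending the public API referenced in the test's TODO.

    import grpc

    # Channel-level compression: the "1" literal selects gzip, mirroring the
    # test's 'grpc.default_compression_algorithm' channel option.
    channel = grpc.insecure_channel(
        'localhost:50051',  # placeholder target; a server must be listening
        options=[('grpc.default_compression_algorithm', 1)])

    # Per-call compression: request gzip via the internal encoding metadata
    # key, exactly as the test's client metadata and server handlers do.
    multi_callable = channel.unary_unary('/test/UnaryUnary')
    response = multi_callable(
        b'\x00' * 100,
        metadata=[('grpc-internal-encoding-request', 'gzip')])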
diff --git a/src/python/grpcio_tests/tests/unit/_credentials_test.py b/src/python/grpcio_tests/tests/unit/_credentials_test.py
index 87af85a0b9..21bf29789a 100644
--- a/src/python/grpcio_tests/tests/unit/_credentials_test.py
+++ b/src/python/grpcio_tests/tests/unit/_credentials_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of credentials."""
import unittest
@@ -36,37 +35,38 @@ import grpc
class CredentialsTest(unittest.TestCase):
- def test_call_credentials_composition(self):
- first = grpc.access_token_call_credentials('abc')
- second = grpc.access_token_call_credentials('def')
- third = grpc.access_token_call_credentials('ghi')
+ def test_call_credentials_composition(self):
+ first = grpc.access_token_call_credentials('abc')
+ second = grpc.access_token_call_credentials('def')
+ third = grpc.access_token_call_credentials('ghi')
+
+ first_and_second = grpc.composite_call_credentials(first, second)
+ first_second_and_third = grpc.composite_call_credentials(first, second,
+ third)
- first_and_second = grpc.composite_call_credentials(first, second)
- first_second_and_third = grpc.composite_call_credentials(
- first, second, third)
-
- self.assertIsInstance(first_and_second, grpc.CallCredentials)
- self.assertIsInstance(first_second_and_third, grpc.CallCredentials)
+ self.assertIsInstance(first_and_second, grpc.CallCredentials)
+ self.assertIsInstance(first_second_and_third, grpc.CallCredentials)
- def test_channel_credentials_composition(self):
- first_call_credentials = grpc.access_token_call_credentials('abc')
- second_call_credentials = grpc.access_token_call_credentials('def')
- third_call_credentials = grpc.access_token_call_credentials('ghi')
- channel_credentials = grpc.ssl_channel_credentials()
+ def test_channel_credentials_composition(self):
+ first_call_credentials = grpc.access_token_call_credentials('abc')
+ second_call_credentials = grpc.access_token_call_credentials('def')
+ third_call_credentials = grpc.access_token_call_credentials('ghi')
+ channel_credentials = grpc.ssl_channel_credentials()
- channel_and_first = grpc.composite_channel_credentials(
- channel_credentials, first_call_credentials)
- channel_first_and_second = grpc.composite_channel_credentials(
- channel_credentials, first_call_credentials, second_call_credentials)
- channel_first_second_and_third = grpc.composite_channel_credentials(
- channel_credentials, first_call_credentials, second_call_credentials,
- third_call_credentials)
+ channel_and_first = grpc.composite_channel_credentials(
+ channel_credentials, first_call_credentials)
+ channel_first_and_second = grpc.composite_channel_credentials(
+ channel_credentials, first_call_credentials,
+ second_call_credentials)
+ channel_first_second_and_third = grpc.composite_channel_credentials(
+ channel_credentials, first_call_credentials,
+ second_call_credentials, third_call_credentials)
- self.assertIsInstance(channel_and_first, grpc.ChannelCredentials)
- self.assertIsInstance(channel_first_and_second, grpc.ChannelCredentials)
- self.assertIsInstance(
- channel_first_second_and_third, grpc.ChannelCredentials)
+ self.assertIsInstance(channel_and_first, grpc.ChannelCredentials)
+ self.assertIsInstance(channel_first_and_second, grpc.ChannelCredentials)
+ self.assertIsInstance(channel_first_second_and_third,
+ grpc.ChannelCredentials)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
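For reference, a minimal sketch of the credentials composition exercised above; the token strings and the target address are placeholder assumptions, and actually connecting requires a TLS-enabled server.

    import grpc

    # Compose several call credentials into a single CallCredentials object.
    first = grpc.access_token_call_credentials('abc')
    second = grpc.access_token_call_credentials('def')
    call_credentials = grpc.composite_call_credentials(first, second)

    # Layer the composed call credentials on top of channel-level
    # (transport) credentials.
    channel_credentials = grpc.composite_channel_credentials(
        grpc.ssl_channel_credentials(), call_credentials)

    # The composite credentials are supplied when opening a secure channel.
    channel = grpc.secure_channel('localhost:50051', channel_credentials)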
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
index 20115fb22c..d77f5ecb27 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_cancel_many_calls_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test making many calls and immediately cancelling most of them."""
import threading
@@ -51,173 +50,178 @@ _SUCCESS_CALL_FRACTION = 1.0 / 8.0
class _State(object):
- def __init__(self):
- self.condition = threading.Condition()
- self.handlers_released = False
- self.parked_handlers = 0
- self.handled_rpcs = 0
+ def __init__(self):
+ self.condition = threading.Condition()
+ self.handlers_released = False
+ self.parked_handlers = 0
+ self.handled_rpcs = 0
def _is_cancellation_event(event):
- return (
- event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
- event.batch_operations[0].received_cancelled)
+ return (event.tag is _RECEIVE_CLOSE_ON_SERVER_TAG and
+ event.batch_operations[0].received_cancelled)
class _Handler(object):
- def __init__(self, state, completion_queue, rpc_event):
- self._state = state
- self._lock = threading.Lock()
- self._completion_queue = completion_queue
- self._call = rpc_event.operation_call
-
- def __call__(self):
- with self._state.condition:
- self._state.parked_handlers += 1
- if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
- self._state.condition.notify_all()
- while not self._state.handlers_released:
- self._state.condition.wait()
-
- with self._lock:
- self._call.start_server_batch(
- cygrpc.Operations(
- (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
- _RECEIVE_CLOSE_ON_SERVER_TAG)
- self._call.start_server_batch(
- cygrpc.Operations((cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
- _RECEIVE_MESSAGE_TAG)
- first_event = self._completion_queue.poll()
- if _is_cancellation_event(first_event):
- self._completion_queue.poll()
- else:
- with self._lock:
- operations = (
- cygrpc.operation_send_initial_metadata(
- _EMPTY_METADATA, _EMPTY_FLAGS),
- cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
- cygrpc.operation_send_status_from_server(
- _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
- _EMPTY_FLAGS),
- )
- self._call.start_server_batch(
- cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
- self._completion_queue.poll()
- self._completion_queue.poll()
+ def __init__(self, state, completion_queue, rpc_event):
+ self._state = state
+ self._lock = threading.Lock()
+ self._completion_queue = completion_queue
+ self._call = rpc_event.operation_call
+
+ def __call__(self):
+ with self._state.condition:
+ self._state.parked_handlers += 1
+ if self._state.parked_handlers == test_constants.THREAD_CONCURRENCY:
+ self._state.condition.notify_all()
+ while not self._state.handlers_released:
+ self._state.condition.wait()
+
+ with self._lock:
+ self._call.start_server_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),)),
+ _RECEIVE_CLOSE_ON_SERVER_TAG)
+ self._call.start_server_batch(
+ cygrpc.Operations(
+ (cygrpc.operation_receive_message(_EMPTY_FLAGS),)),
+ _RECEIVE_MESSAGE_TAG)
+ first_event = self._completion_queue.poll()
+ if _is_cancellation_event(first_event):
+ self._completion_queue.poll()
+ else:
+ with self._lock:
+ operations = (
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
+ _EMPTY_FLAGS),
+ cygrpc.operation_send_message(b'\x79\x57', _EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ _EMPTY_METADATA, cygrpc.StatusCode.ok, b'test details!',
+ _EMPTY_FLAGS),)
+ self._call.start_server_batch(
+ cygrpc.Operations(operations), _SERVER_COMPLETE_CALL_TAG)
+ self._completion_queue.poll()
+ self._completion_queue.poll()
def _serve(state, server, server_completion_queue, thread_pool):
- for _ in range(test_constants.RPC_CONCURRENCY):
- call_completion_queue = cygrpc.CompletionQueue()
- server.request_call(
- call_completion_queue, server_completion_queue, _REQUEST_CALL_TAG)
- rpc_event = server_completion_queue.poll()
- thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
- with state.condition:
- state.handled_rpcs += 1
- if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
- state.condition.notify_all()
- server_completion_queue.poll()
+ for _ in range(test_constants.RPC_CONCURRENCY):
+ call_completion_queue = cygrpc.CompletionQueue()
+ server.request_call(call_completion_queue, server_completion_queue,
+ _REQUEST_CALL_TAG)
+ rpc_event = server_completion_queue.poll()
+ thread_pool.submit(_Handler(state, call_completion_queue, rpc_event))
+ with state.condition:
+ state.handled_rpcs += 1
+ if test_constants.RPC_CONCURRENCY <= state.handled_rpcs:
+ state.condition.notify_all()
+ server_completion_queue.poll()
class _QueueDriver(object):
- def __init__(self, condition, completion_queue, due):
- self._condition = condition
- self._completion_queue = completion_queue
- self._due = due
- self._events = []
- self._returned = False
-
- def start(self):
- def in_thread():
- while True:
- event = self._completion_queue.poll()
+ def __init__(self, condition, completion_queue, due):
+ self._condition = condition
+ self._completion_queue = completion_queue
+ self._due = due
+ self._events = []
+ self._returned = False
+
+ def start(self):
+
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._due.remove(event.tag)
+ self._condition.notify_all()
+ if not self._due:
+ self._returned = True
+ return
+
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def events(self, at_least):
with self._condition:
- self._events.append(event)
- self._due.remove(event.tag)
- self._condition.notify_all()
- if not self._due:
- self._returned = True
- return
- thread = threading.Thread(target=in_thread)
- thread.start()
-
- def events(self, at_least):
- with self._condition:
- while len(self._events) < at_least:
- self._condition.wait()
- return tuple(self._events)
+ while len(self._events) < at_least:
+ self._condition.wait()
+ return tuple(self._events)
class CancelManyCallsTest(unittest.TestCase):
- def testCancelManyCalls(self):
- server_thread_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
-
- server_completion_queue = cygrpc.CompletionQueue()
- server = cygrpc.Server(cygrpc.ChannelArgs([]))
- server.register_completion_queue(server_completion_queue)
- port = server.add_http2_port(b'[::]:0')
- server.start()
- channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
- cygrpc.ChannelArgs([]))
-
- state = _State()
-
- server_thread_args = (
- state, server, server_completion_queue, server_thread_pool,)
- server_thread = threading.Thread(target=_serve, args=server_thread_args)
- server_thread.start()
-
- client_condition = threading.Condition()
- client_due = set()
- client_completion_queue = cygrpc.CompletionQueue()
- client_driver = _QueueDriver(
- client_condition, client_completion_queue, client_due)
- client_driver.start()
-
- with client_condition:
- client_calls = []
- for index in range(test_constants.RPC_CONCURRENCY):
- client_call = channel.create_call(
- None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
- _INFINITE_FUTURE)
- operations = (
- cygrpc.operation_send_initial_metadata(
- _EMPTY_METADATA, _EMPTY_FLAGS),
- cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
- )
- tag = 'client_complete_call_{0:04d}_tag'.format(index)
- client_call.start_client_batch(cygrpc.Operations(operations), tag)
- client_due.add(tag)
- client_calls.append(client_call)
-
- with state.condition:
- while True:
- if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
- state.condition.wait()
- elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
- state.condition.wait()
- else:
- state.handlers_released = True
- state.condition.notify_all()
- break
-
- client_driver.events(
- test_constants.RPC_CONCURRENCY * _SUCCESS_CALL_FRACTION)
- with client_condition:
- for client_call in client_calls:
- client_call.cancel()
-
- with state.condition:
- server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
+ def testCancelManyCalls(self):
+ server_thread_pool = logging_pool.pool(
+ test_constants.THREAD_CONCURRENCY)
+
+ server_completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ server.register_completion_queue(server_completion_queue)
+ port = server.add_http2_port(b'[::]:0')
+ server.start()
+ channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
+ cygrpc.ChannelArgs([]))
+
+ state = _State()
+
+ server_thread_args = (
+ state,
+ server,
+ server_completion_queue,
+ server_thread_pool,)
+ server_thread = threading.Thread(target=_serve, args=server_thread_args)
+ server_thread.start()
+
+ client_condition = threading.Condition()
+ client_due = set()
+ client_completion_queue = cygrpc.CompletionQueue()
+ client_driver = _QueueDriver(client_condition, client_completion_queue,
+ client_due)
+ client_driver.start()
+
+ with client_condition:
+ client_calls = []
+ for index in range(test_constants.RPC_CONCURRENCY):
+ client_call = channel.create_call(
+ None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies',
+ None, _INFINITE_FUTURE)
+ operations = (
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
+ _EMPTY_FLAGS),
+ cygrpc.operation_send_message(b'\x45\x56', _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),)
+ tag = 'client_complete_call_{0:04d}_tag'.format(index)
+ client_call.start_client_batch(
+ cygrpc.Operations(operations), tag)
+ client_due.add(tag)
+ client_calls.append(client_call)
+
+ with state.condition:
+ while True:
+ if state.parked_handlers < test_constants.THREAD_CONCURRENCY:
+ state.condition.wait()
+ elif state.handled_rpcs < test_constants.RPC_CONCURRENCY:
+ state.condition.wait()
+ else:
+ state.handlers_released = True
+ state.condition.notify_all()
+ break
+
+ client_driver.events(test_constants.RPC_CONCURRENCY *
+ _SUCCESS_CALL_FRACTION)
+ with client_condition:
+ for client_call in client_calls:
+ client_call.cancel()
+
+ with state.condition:
+ server.shutdown(server_completion_queue, _SERVER_SHUTDOWN_TAG)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
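The test above drives mass cancellation through the Cython layer directly; a rough public-API analogue, assuming a hypothetical unary method and a placeholder target with a running server, looks like this.

    import grpc

    channel = grpc.insecure_channel('localhost:50051')  # placeholder target
    multi_callable = channel.unary_unary('/twinkies')   # hypothetical method

    # Start a batch of calls asynchronously, then cancel most of them,
    # letting only a fraction run to completion (cf. _SUCCESS_CALL_FRACTION).
    calls = [multi_callable.future(b'\x45\x56') for _ in range(100)]
    for call in calls[len(calls) // 8:]:
        call.cancel()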
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
index f9c8a3ac62..0ca06868b2 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_channel_test.py
@@ -37,46 +37,49 @@ from tests.unit.framework.common import test_constants
def _channel_and_completion_queue():
- channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(()))
- completion_queue = cygrpc.CompletionQueue()
- return channel, completion_queue
+ channel = cygrpc.Channel(b'localhost:54321', cygrpc.ChannelArgs(()))
+ completion_queue = cygrpc.CompletionQueue()
+ return channel, completion_queue
def _connectivity_loop(channel, completion_queue):
- for _ in range(100):
- connectivity = channel.check_connectivity_state(True)
- channel.watch_connectivity_state(
- connectivity, cygrpc.Timespec(time.time() + 0.2), completion_queue,
- None)
- completion_queue.poll(deadline=cygrpc.Timespec(float('+inf')))
+ for _ in range(100):
+ connectivity = channel.check_connectivity_state(True)
+ channel.watch_connectivity_state(connectivity,
+ cygrpc.Timespec(time.time() + 0.2),
+ completion_queue, None)
+ completion_queue.poll(deadline=cygrpc.Timespec(float('+inf')))
def _create_loop_destroy():
- channel, completion_queue = _channel_and_completion_queue()
- _connectivity_loop(channel, completion_queue)
- completion_queue.shutdown()
+ channel, completion_queue = _channel_and_completion_queue()
+ _connectivity_loop(channel, completion_queue)
+ completion_queue.shutdown()
def _in_parallel(behavior, arguments):
- threads = tuple(
- threading.Thread(target=behavior, args=arguments)
- for _ in range(test_constants.THREAD_CONCURRENCY))
- for thread in threads:
- thread.start()
- for thread in threads:
- thread.join()
+ threads = tuple(
+ threading.Thread(
+ target=behavior, args=arguments)
+ for _ in range(test_constants.THREAD_CONCURRENCY))
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
class ChannelTest(unittest.TestCase):
- def test_single_channel_lonely_connectivity(self):
- channel, completion_queue = _channel_and_completion_queue()
- _in_parallel(_connectivity_loop, (channel, completion_queue,))
- completion_queue.shutdown()
+ def test_single_channel_lonely_connectivity(self):
+ channel, completion_queue = _channel_and_completion_queue()
+ _in_parallel(_connectivity_loop, (
+ channel,
+ completion_queue,))
+ completion_queue.shutdown()
- def test_multiple_channels_lonely_connectivity(self):
- _in_parallel(_create_loop_destroy, ())
+ def test_multiple_channels_lonely_connectivity(self):
+ _in_parallel(_create_loop_destroy, ())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
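The Cython-level connectivity loop above has a public-API counterpart in grpc.Channel.subscribe; a minimal sketch with a placeholder target.

    import grpc

    def on_connectivity_change(connectivity):
        # Receives a grpc.ChannelConnectivity value on each state change.
        print('channel state:', connectivity)

    channel = grpc.insecure_channel('localhost:54321')  # placeholder target
    # try_to_connect=True asks the channel to actively attempt a connection,
    # analogous to check_connectivity_state(True) in the Cython test.
    channel.subscribe(on_connectivity_change, try_to_connect=True)
    # ... later, stop receiving updates and release the callback.
    channel.unsubscribe(on_connectivity_change)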
diff --git a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
index 2ae5285232..9fbfcbb9c0 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/_read_some_but_not_all_responses_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test a corner-case at the level of the Cython API."""
import threading
@@ -41,212 +40,221 @@ _EMPTY_METADATA = cygrpc.Metadata(())
class _ServerDriver(object):
- def __init__(self, completion_queue, shutdown_tag):
- self._condition = threading.Condition()
- self._completion_queue = completion_queue
- self._shutdown_tag = shutdown_tag
- self._events = []
- self._saw_shutdown_tag = False
-
- def start(self):
- def in_thread():
- while True:
- event = self._completion_queue.poll()
+ def __init__(self, completion_queue, shutdown_tag):
+ self._condition = threading.Condition()
+ self._completion_queue = completion_queue
+ self._shutdown_tag = shutdown_tag
+ self._events = []
+ self._saw_shutdown_tag = False
+
+ def start(self):
+
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._condition.notify()
+ if event.tag is self._shutdown_tag:
+ self._saw_shutdown_tag = True
+ break
+
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._saw_shutdown_tag
+
+ def first_event(self):
+ with self._condition:
+ while not self._events:
+ self._condition.wait()
+ return self._events[0]
+
+ def events(self):
with self._condition:
- self._events.append(event)
- self._condition.notify()
- if event.tag is self._shutdown_tag:
- self._saw_shutdown_tag = True
- break
- thread = threading.Thread(target=in_thread)
- thread.start()
-
- def done(self):
- with self._condition:
- return self._saw_shutdown_tag
-
- def first_event(self):
- with self._condition:
- while not self._events:
- self._condition.wait()
- return self._events[0]
-
- def events(self):
- with self._condition:
- while not self._saw_shutdown_tag:
- self._condition.wait()
- return tuple(self._events)
+ while not self._saw_shutdown_tag:
+ self._condition.wait()
+ return tuple(self._events)
class _QueueDriver(object):
- def __init__(self, condition, completion_queue, due):
- self._condition = condition
- self._completion_queue = completion_queue
- self._due = due
- self._events = []
- self._returned = False
-
- def start(self):
- def in_thread():
- while True:
- event = self._completion_queue.poll()
+ def __init__(self, condition, completion_queue, due):
+ self._condition = condition
+ self._completion_queue = completion_queue
+ self._due = due
+ self._events = []
+ self._returned = False
+
+ def start(self):
+
+ def in_thread():
+ while True:
+ event = self._completion_queue.poll()
+ with self._condition:
+ self._events.append(event)
+ self._due.remove(event.tag)
+ self._condition.notify_all()
+ if not self._due:
+ self._returned = True
+ return
+
+ thread = threading.Thread(target=in_thread)
+ thread.start()
+
+ def done(self):
+ with self._condition:
+ return self._returned
+
+ def event_with_tag(self, tag):
+ with self._condition:
+ while True:
+ for event in self._events:
+ if event.tag is tag:
+ return event
+ self._condition.wait()
+
+ def events(self):
with self._condition:
- self._events.append(event)
- self._due.remove(event.tag)
- self._condition.notify_all()
- if not self._due:
- self._returned = True
- return
- thread = threading.Thread(target=in_thread)
- thread.start()
-
- def done(self):
- with self._condition:
- return self._returned
-
- def event_with_tag(self, tag):
- with self._condition:
- while True:
- for event in self._events:
- if event.tag is tag:
- return event
- self._condition.wait()
-
- def events(self):
- with self._condition:
- while not self._returned:
- self._condition.wait()
- return tuple(self._events)
+ while not self._returned:
+ self._condition.wait()
+ return tuple(self._events)
class ReadSomeButNotAllResponsesTest(unittest.TestCase):
- def testReadSomeButNotAllResponses(self):
- server_completion_queue = cygrpc.CompletionQueue()
- server = cygrpc.Server(cygrpc.ChannelArgs([]))
- server.register_completion_queue(server_completion_queue)
- port = server.add_http2_port(b'[::]:0')
- server.start()
- channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
- cygrpc.ChannelArgs([]))
-
- server_shutdown_tag = 'server_shutdown_tag'
- server_driver = _ServerDriver(server_completion_queue, server_shutdown_tag)
- server_driver.start()
-
- client_condition = threading.Condition()
- client_due = set()
- client_completion_queue = cygrpc.CompletionQueue()
- client_driver = _QueueDriver(
- client_condition, client_completion_queue, client_due)
- client_driver.start()
-
- server_call_condition = threading.Condition()
- server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
- server_send_first_message_tag = 'server_send_first_message_tag'
- server_send_second_message_tag = 'server_send_second_message_tag'
- server_complete_rpc_tag = 'server_complete_rpc_tag'
- server_call_due = set((
- server_send_initial_metadata_tag,
- server_send_first_message_tag,
- server_send_second_message_tag,
- server_complete_rpc_tag,
- ))
- server_call_completion_queue = cygrpc.CompletionQueue()
- server_call_driver = _QueueDriver(
- server_call_condition, server_call_completion_queue, server_call_due)
- server_call_driver.start()
-
- server_rpc_tag = 'server_rpc_tag'
- request_call_result = server.request_call(
- server_call_completion_queue, server_completion_queue, server_rpc_tag)
-
- client_call = channel.create_call(
- None, _EMPTY_FLAGS, client_completion_queue, b'/twinkies', None,
- _INFINITE_FUTURE)
- client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
- client_complete_rpc_tag = 'client_complete_rpc_tag'
- with client_condition:
- client_receive_initial_metadata_start_batch_result = (
- client_call.start_client_batch(cygrpc.Operations([
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- ]), client_receive_initial_metadata_tag))
- client_due.add(client_receive_initial_metadata_tag)
- client_complete_rpc_start_batch_result = (
- client_call.start_client_batch(cygrpc.Operations([
- cygrpc.operation_send_initial_metadata(
- _EMPTY_METADATA, _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
- ]), client_complete_rpc_tag))
- client_due.add(client_complete_rpc_tag)
-
- server_rpc_event = server_driver.first_event()
-
- with server_call_condition:
- server_send_initial_metadata_start_batch_result = (
- server_rpc_event.operation_call.start_server_batch([
- cygrpc.operation_send_initial_metadata(
- _EMPTY_METADATA, _EMPTY_FLAGS),
- ], server_send_initial_metadata_tag))
- server_send_first_message_start_batch_result = (
- server_rpc_event.operation_call.start_server_batch([
- cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
- ], server_send_first_message_tag))
- server_send_initial_metadata_event = server_call_driver.event_with_tag(
- server_send_initial_metadata_tag)
- server_send_first_message_event = server_call_driver.event_with_tag(
- server_send_first_message_tag)
- with server_call_condition:
- server_send_second_message_start_batch_result = (
- server_rpc_event.operation_call.start_server_batch([
- cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
- ], server_send_second_message_tag))
- server_complete_rpc_start_batch_result = (
- server_rpc_event.operation_call.start_server_batch([
- cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
- cygrpc.operation_send_status_from_server(
- cygrpc.Metadata(()), cygrpc.StatusCode.ok, b'test details',
- _EMPTY_FLAGS),
- ], server_complete_rpc_tag))
- server_send_second_message_event = server_call_driver.event_with_tag(
- server_send_second_message_tag)
- server_complete_rpc_event = server_call_driver.event_with_tag(
- server_complete_rpc_tag)
- server_call_driver.events()
-
- with client_condition:
- client_receive_first_message_tag = 'client_receive_first_message_tag'
- client_receive_first_message_start_batch_result = (
- client_call.start_client_batch(cygrpc.Operations([
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- ]), client_receive_first_message_tag))
- client_due.add(client_receive_first_message_tag)
- client_receive_first_message_event = client_driver.event_with_tag(
- client_receive_first_message_tag)
-
- client_call_cancel_result = client_call.cancel()
- client_driver.events()
-
- server.shutdown(server_completion_queue, server_shutdown_tag)
- server.cancel_all_calls()
- server_driver.events()
-
- self.assertEqual(cygrpc.CallError.ok, request_call_result)
- self.assertEqual(
- cygrpc.CallError.ok, server_send_initial_metadata_start_batch_result)
- self.assertEqual(
- cygrpc.CallError.ok, client_receive_initial_metadata_start_batch_result)
- self.assertEqual(
- cygrpc.CallError.ok, client_complete_rpc_start_batch_result)
- self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
- self.assertIs(server_rpc_tag, server_rpc_event.tag)
- self.assertEqual(
- cygrpc.CompletionType.operation_complete, server_rpc_event.type)
- self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call)
- self.assertEqual(0, len(server_rpc_event.batch_operations))
+ def testReadSomeButNotAllResponses(self):
+ server_completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ server.register_completion_queue(server_completion_queue)
+ port = server.add_http2_port(b'[::]:0')
+ server.start()
+ channel = cygrpc.Channel('localhost:{}'.format(port).encode(),
+ cygrpc.ChannelArgs([]))
+
+ server_shutdown_tag = 'server_shutdown_tag'
+ server_driver = _ServerDriver(server_completion_queue,
+ server_shutdown_tag)
+ server_driver.start()
+
+ client_condition = threading.Condition()
+ client_due = set()
+ client_completion_queue = cygrpc.CompletionQueue()
+ client_driver = _QueueDriver(client_condition, client_completion_queue,
+ client_due)
+ client_driver.start()
+
+ server_call_condition = threading.Condition()
+ server_send_initial_metadata_tag = 'server_send_initial_metadata_tag'
+ server_send_first_message_tag = 'server_send_first_message_tag'
+ server_send_second_message_tag = 'server_send_second_message_tag'
+ server_complete_rpc_tag = 'server_complete_rpc_tag'
+ server_call_due = set((
+ server_send_initial_metadata_tag,
+ server_send_first_message_tag,
+ server_send_second_message_tag,
+ server_complete_rpc_tag,))
+ server_call_completion_queue = cygrpc.CompletionQueue()
+ server_call_driver = _QueueDriver(server_call_condition,
+ server_call_completion_queue,
+ server_call_due)
+ server_call_driver.start()
+
+ server_rpc_tag = 'server_rpc_tag'
+ request_call_result = server.request_call(server_call_completion_queue,
+ server_completion_queue,
+ server_rpc_tag)
+
+ client_call = channel.create_call(None, _EMPTY_FLAGS,
+ client_completion_queue, b'/twinkies',
+ None, _INFINITE_FUTURE)
+ client_receive_initial_metadata_tag = 'client_receive_initial_metadata_tag'
+ client_complete_rpc_tag = 'client_complete_rpc_tag'
+ with client_condition:
+ client_receive_initial_metadata_start_batch_result = (
+ client_call.start_client_batch(
+ cygrpc.Operations([
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ ]), client_receive_initial_metadata_tag))
+ client_due.add(client_receive_initial_metadata_tag)
+ client_complete_rpc_start_batch_result = (
+ client_call.start_client_batch(
+ cygrpc.Operations([
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
+ _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS),
+ ]), client_complete_rpc_tag))
+ client_due.add(client_complete_rpc_tag)
+
+ server_rpc_event = server_driver.first_event()
+
+ with server_call_condition:
+ server_send_initial_metadata_start_batch_result = (
+ server_rpc_event.operation_call.start_server_batch([
+ cygrpc.operation_send_initial_metadata(_EMPTY_METADATA,
+ _EMPTY_FLAGS),
+ ], server_send_initial_metadata_tag))
+ server_send_first_message_start_batch_result = (
+ server_rpc_event.operation_call.start_server_batch([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ], server_send_first_message_tag))
+ server_send_initial_metadata_event = server_call_driver.event_with_tag(
+ server_send_initial_metadata_tag)
+ server_send_first_message_event = server_call_driver.event_with_tag(
+ server_send_first_message_tag)
+ with server_call_condition:
+ server_send_second_message_start_batch_result = (
+ server_rpc_event.operation_call.start_server_batch([
+ cygrpc.operation_send_message(b'\x07', _EMPTY_FLAGS),
+ ], server_send_second_message_tag))
+ server_complete_rpc_start_batch_result = (
+ server_rpc_event.operation_call.start_server_batch([
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ cygrpc.Metadata(()), cygrpc.StatusCode.ok,
+ b'test details', _EMPTY_FLAGS),
+ ], server_complete_rpc_tag))
+ server_send_second_message_event = server_call_driver.event_with_tag(
+ server_send_second_message_tag)
+ server_complete_rpc_event = server_call_driver.event_with_tag(
+ server_complete_rpc_tag)
+ server_call_driver.events()
+
+ with client_condition:
+ client_receive_first_message_tag = 'client_receive_first_message_tag'
+ client_receive_first_message_start_batch_result = (
+ client_call.start_client_batch(
+ cygrpc.Operations([
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ]), client_receive_first_message_tag))
+ client_due.add(client_receive_first_message_tag)
+ client_receive_first_message_event = client_driver.event_with_tag(
+ client_receive_first_message_tag)
+
+ client_call_cancel_result = client_call.cancel()
+ client_driver.events()
+
+ server.shutdown(server_completion_queue, server_shutdown_tag)
+ server.cancel_all_calls()
+ server_driver.events()
+
+ self.assertEqual(cygrpc.CallError.ok, request_call_result)
+ self.assertEqual(cygrpc.CallError.ok,
+ server_send_initial_metadata_start_batch_result)
+ self.assertEqual(cygrpc.CallError.ok,
+ client_receive_initial_metadata_start_batch_result)
+ self.assertEqual(cygrpc.CallError.ok,
+ client_complete_rpc_start_batch_result)
+ self.assertEqual(cygrpc.CallError.ok, client_call_cancel_result)
+ self.assertIs(server_rpc_tag, server_rpc_event.tag)
+ self.assertEqual(cygrpc.CompletionType.operation_complete,
+ server_rpc_event.type)
+ self.assertIsInstance(server_rpc_event.operation_call, cygrpc.Call)
+ self.assertEqual(0, len(server_rpc_event.batch_operations))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
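At the public-API level, reading some but not all responses corresponds to cancelling the response iterator before it is exhausted; a sketch assuming a hypothetical unary-stream method and a placeholder target with a running server.

    import grpc

    channel = grpc.insecure_channel('localhost:50051')  # placeholder target
    multi_callable = channel.unary_stream('/twinkies')  # hypothetical method

    # The returned object is both an iterator of responses and a grpc.Call.
    response_iterator = multi_callable(b'\x07')
    first_response = next(response_iterator)
    # Cancel the RPC instead of draining the remaining responses.
    response_iterator.cancel()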
diff --git a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
index 8dedebfabe..7aec316b95 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/cygrpc_test.py
@@ -37,399 +37,421 @@ from tests.unit._cython import test_utilities
from tests.unit import test_common
from tests.unit import resources
-
_SSL_HOST_OVERRIDE = b'foo.test.google.fr'
_CALL_CREDENTIALS_METADATA_KEY = 'call-creds-key'
_CALL_CREDENTIALS_METADATA_VALUE = 'call-creds-value'
_EMPTY_FLAGS = 0
+
def _metadata_plugin_callback(context, callback):
- callback(cygrpc.Metadata(
- [cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
- _CALL_CREDENTIALS_METADATA_VALUE)]),
- cygrpc.StatusCode.ok, b'')
+ callback(
+ cygrpc.Metadata([
+ cygrpc.Metadatum(_CALL_CREDENTIALS_METADATA_KEY,
+ _CALL_CREDENTIALS_METADATA_VALUE)
+ ]), cygrpc.StatusCode.ok, b'')
class TypeSmokeTest(unittest.TestCase):
- def testStringsInUtilitiesUpDown(self):
- self.assertEqual(0, cygrpc.StatusCode.ok)
- metadatum = cygrpc.Metadatum(b'a', b'b')
- self.assertEqual(b'a', metadatum.key)
- self.assertEqual(b'b', metadatum.value)
- metadata = cygrpc.Metadata([metadatum])
- self.assertEqual(1, len(metadata))
- self.assertEqual(metadatum.key, metadata[0].key)
-
- def testMetadataIteration(self):
- metadata = cygrpc.Metadata([
- cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
- iterator = iter(metadata)
- metadatum = next(iterator)
- self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, b'a')
- self.assertEqual(metadatum.value, b'b')
- metadatum = next(iterator)
- self.assertIsInstance(metadatum, cygrpc.Metadatum)
- self.assertEqual(metadatum.key, b'c')
- self.assertEqual(metadatum.value, b'd')
- with self.assertRaises(StopIteration):
- next(iterator)
-
- def testOperationsIteration(self):
- operations = cygrpc.Operations([
- cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
- iterator = iter(operations)
- operation = next(iterator)
- self.assertIsInstance(operation, cygrpc.Operation)
- # `Operation`s are write-only structures; can't directly debug anything out
- # of them. Just check that we stop iterating.
- with self.assertRaises(StopIteration):
- next(iterator)
-
- def testOperationFlags(self):
- operation = cygrpc.operation_send_message(b'asdf',
- cygrpc.WriteFlag.no_compress)
- self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags)
-
- def testTimespec(self):
- now = time.time()
- timespec = cygrpc.Timespec(now)
- self.assertAlmostEqual(now, float(timespec), places=8)
-
- def testCompletionQueueUpDown(self):
- completion_queue = cygrpc.CompletionQueue()
- del completion_queue
-
- def testServerUpDown(self):
- server = cygrpc.Server(cygrpc.ChannelArgs([]))
- del server
-
- def testChannelUpDown(self):
- channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([]))
- del channel
-
- def testCredentialsMetadataPluginUpDown(self):
- plugin = cygrpc.CredentialsMetadataPlugin(
- lambda ignored_a, ignored_b: None, b'')
- del plugin
-
- def testCallCredentialsFromPluginUpDown(self):
- plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback, b'')
- call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
- del plugin
- del call_credentials
-
- def testServerStartNoExplicitShutdown(self):
- server = cygrpc.Server(cygrpc.ChannelArgs([]))
- completion_queue = cygrpc.CompletionQueue()
- server.register_completion_queue(completion_queue)
- port = server.add_http2_port(b'[::]:0')
- self.assertIsInstance(port, int)
- server.start()
- del server
-
- def testServerStartShutdown(self):
- completion_queue = cygrpc.CompletionQueue()
- server = cygrpc.Server(cygrpc.ChannelArgs([]))
- server.add_http2_port(b'[::]:0')
- server.register_completion_queue(completion_queue)
- server.start()
- shutdown_tag = object()
- server.shutdown(completion_queue, shutdown_tag)
- event = completion_queue.poll()
- self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
- self.assertIs(shutdown_tag, event.tag)
- del server
- del completion_queue
+ def testStringsInUtilitiesUpDown(self):
+ self.assertEqual(0, cygrpc.StatusCode.ok)
+ metadatum = cygrpc.Metadatum(b'a', b'b')
+ self.assertEqual(b'a', metadatum.key)
+ self.assertEqual(b'b', metadatum.value)
+ metadata = cygrpc.Metadata([metadatum])
+ self.assertEqual(1, len(metadata))
+ self.assertEqual(metadatum.key, metadata[0].key)
+
+ def testMetadataIteration(self):
+ metadata = cygrpc.Metadata(
+ [cygrpc.Metadatum(b'a', b'b'), cygrpc.Metadatum(b'c', b'd')])
+ iterator = iter(metadata)
+ metadatum = next(iterator)
+ self.assertIsInstance(metadatum, cygrpc.Metadatum)
+ self.assertEqual(metadatum.key, b'a')
+ self.assertEqual(metadatum.value, b'b')
+ metadatum = next(iterator)
+ self.assertIsInstance(metadatum, cygrpc.Metadatum)
+ self.assertEqual(metadatum.key, b'c')
+ self.assertEqual(metadatum.value, b'd')
+ with self.assertRaises(StopIteration):
+ next(iterator)
+
+ def testOperationsIteration(self):
+ operations = cygrpc.Operations(
+ [cygrpc.operation_send_message(b'asdf', _EMPTY_FLAGS)])
+ iterator = iter(operations)
+ operation = next(iterator)
+ self.assertIsInstance(operation, cygrpc.Operation)
+ # `Operation`s are write-only structures; can't directly debug anything out
+ # of them. Just check that we stop iterating.
+ with self.assertRaises(StopIteration):
+ next(iterator)
+
+ def testOperationFlags(self):
+ operation = cygrpc.operation_send_message(b'asdf',
+ cygrpc.WriteFlag.no_compress)
+ self.assertEqual(cygrpc.WriteFlag.no_compress, operation.flags)
+
+ def testTimespec(self):
+ now = time.time()
+ timespec = cygrpc.Timespec(now)
+ self.assertAlmostEqual(now, float(timespec), places=8)
+
+ def testCompletionQueueUpDown(self):
+ completion_queue = cygrpc.CompletionQueue()
+ del completion_queue
+
+ def testServerUpDown(self):
+ server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ del server
+
+ def testChannelUpDown(self):
+ channel = cygrpc.Channel(b'[::]:0', cygrpc.ChannelArgs([]))
+ del channel
+
+ def testCredentialsMetadataPluginUpDown(self):
+ plugin = cygrpc.CredentialsMetadataPlugin(
+ lambda ignored_a, ignored_b: None, b'')
+ del plugin
+
+ def testCallCredentialsFromPluginUpDown(self):
+ plugin = cygrpc.CredentialsMetadataPlugin(_metadata_plugin_callback,
+ b'')
+ call_credentials = cygrpc.call_credentials_metadata_plugin(plugin)
+ del plugin
+ del call_credentials
+
+ def testServerStartNoExplicitShutdown(self):
+ server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ completion_queue = cygrpc.CompletionQueue()
+ server.register_completion_queue(completion_queue)
+ port = server.add_http2_port(b'[::]:0')
+ self.assertIsInstance(port, int)
+ server.start()
+ del server
+
+ def testServerStartShutdown(self):
+ completion_queue = cygrpc.CompletionQueue()
+ server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ server.add_http2_port(b'[::]:0')
+ server.register_completion_queue(completion_queue)
+ server.start()
+ shutdown_tag = object()
+ server.shutdown(completion_queue, shutdown_tag)
+ event = completion_queue.poll()
+ self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
+ self.assertIs(shutdown_tag, event.tag)
+ del server
+ del completion_queue
class ServerClientMixin(object):
- def setUpMixin(self, server_credentials, client_credentials, host_override):
- self.server_completion_queue = cygrpc.CompletionQueue()
- self.server = cygrpc.Server(cygrpc.ChannelArgs([]))
- self.server.register_completion_queue(self.server_completion_queue)
- if server_credentials:
- self.port = self.server.add_http2_port(b'[::]:0', server_credentials)
- else:
- self.port = self.server.add_http2_port(b'[::]:0')
- self.server.start()
- self.client_completion_queue = cygrpc.CompletionQueue()
- if client_credentials:
- client_channel_arguments = cygrpc.ChannelArgs([
- cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
- host_override)])
- self.client_channel = cygrpc.Channel(
- 'localhost:{}'.format(self.port).encode(), client_channel_arguments,
- client_credentials)
- else:
- self.client_channel = cygrpc.Channel(
- 'localhost:{}'.format(self.port).encode(), cygrpc.ChannelArgs([]))
- if host_override:
- self.host_argument = None # default host
- self.expected_host = host_override
- else:
- # arbitrary host name necessitating no further identification
- self.host_argument = b'hostess'
- self.expected_host = self.host_argument
-
- def tearDownMixin(self):
- del self.server
- del self.client_completion_queue
- del self.server_completion_queue
-
- def _perform_operations(self, operations, call, queue, deadline, description):
- """Perform the list of operations with given call, queue, and deadline.
+ def setUpMixin(self, server_credentials, client_credentials, host_override):
+ self.server_completion_queue = cygrpc.CompletionQueue()
+ self.server = cygrpc.Server(cygrpc.ChannelArgs([]))
+ self.server.register_completion_queue(self.server_completion_queue)
+ if server_credentials:
+ self.port = self.server.add_http2_port(b'[::]:0',
+ server_credentials)
+ else:
+ self.port = self.server.add_http2_port(b'[::]:0')
+ self.server.start()
+ self.client_completion_queue = cygrpc.CompletionQueue()
+ if client_credentials:
+ client_channel_arguments = cygrpc.ChannelArgs([
+ cygrpc.ChannelArg(cygrpc.ChannelArgKey.ssl_target_name_override,
+ host_override)
+ ])
+ self.client_channel = cygrpc.Channel(
+ 'localhost:{}'.format(self.port).encode(),
+ client_channel_arguments, client_credentials)
+ else:
+ self.client_channel = cygrpc.Channel(
+ 'localhost:{}'.format(self.port).encode(),
+ cygrpc.ChannelArgs([]))
+ if host_override:
+ self.host_argument = None # default host
+ self.expected_host = host_override
+ else:
+ # arbitrary host name necessitating no further identification
+ self.host_argument = b'hostess'
+ self.expected_host = self.host_argument
+
+ def tearDownMixin(self):
+ del self.server
+ del self.client_completion_queue
+ del self.server_completion_queue
+
+ def _perform_operations(self, operations, call, queue, deadline,
+ description):
+ """Perform the list of operations with given call, queue, and deadline.
Invocation errors are reported with as an exception with `description` in
the message. Performs the operations asynchronously, returning a future.
"""
- def performer():
- tag = object()
- try:
- call_result = call.start_client_batch(
- cygrpc.Operations(operations), tag)
- self.assertEqual(cygrpc.CallError.ok, call_result)
- event = queue.poll(deadline)
- self.assertEqual(cygrpc.CompletionType.operation_complete, event.type)
- self.assertTrue(event.success)
- self.assertIs(tag, event.tag)
- except Exception as error:
- raise Exception("Error in '{}': {}".format(description, error.message))
- return event
- return test_utilities.SimpleFuture(performer)
-
- def testEcho(self):
- DEADLINE = time.time()+5
- DEADLINE_TOLERANCE = 0.25
- CLIENT_METADATA_ASCII_KEY = b'key'
- CLIENT_METADATA_ASCII_VALUE = b'val'
- CLIENT_METADATA_BIN_KEY = b'key-bin'
- CLIENT_METADATA_BIN_VALUE = b'\0'*1000
- SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
- SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
- SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
- SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
- SERVER_STATUS_CODE = cygrpc.StatusCode.ok
- SERVER_STATUS_DETAILS = b'our work is never over'
- REQUEST = b'in death a member of project mayhem has a name'
- RESPONSE = b'his name is robert paulson'
- METHOD = b'twinkies'
-
- cygrpc_deadline = cygrpc.Timespec(DEADLINE)
-
- server_request_tag = object()
- request_call_result = self.server.request_call(
- self.server_completion_queue, self.server_completion_queue,
- server_request_tag)
-
- self.assertEqual(cygrpc.CallError.ok, request_call_result)
-
- client_call_tag = object()
- client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, self.host_argument,
- cygrpc_deadline)
- client_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
- CLIENT_METADATA_ASCII_VALUE),
- cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)])
- client_start_batch_result = client_call.start_client_batch([
- cygrpc.operation_send_initial_metadata(client_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
- cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
- cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
- ], client_call_tag)
- self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
- client_event_future = test_utilities.CompletionQueuePollFuture(
- self.client_completion_queue, cygrpc_deadline)
-
- request_event = self.server_completion_queue.poll(cygrpc_deadline)
- self.assertEqual(cygrpc.CompletionType.operation_complete,
- request_event.type)
- self.assertIsInstance(request_event.operation_call, cygrpc.Call)
- self.assertIs(server_request_tag, request_event.tag)
- self.assertEqual(0, len(request_event.batch_operations))
- self.assertTrue(
- test_common.metadata_transmitted(client_initial_metadata,
- request_event.request_metadata))
- self.assertEqual(METHOD, request_event.request_call_details.method)
- self.assertEqual(self.expected_host,
- request_event.request_call_details.host)
- self.assertLess(
- abs(DEADLINE - float(request_event.request_call_details.deadline)),
- DEADLINE_TOLERANCE)
-
- server_call_tag = object()
- server_call = request_event.operation_call
- server_initial_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
- SERVER_INITIAL_METADATA_VALUE)])
- server_trailing_metadata = cygrpc.Metadata([
- cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
- SERVER_TRAILING_METADATA_VALUE)])
- server_start_batch_result = server_call.start_server_batch([
- cygrpc.operation_send_initial_metadata(server_initial_metadata,
- _EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
- cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
- cygrpc.operation_send_status_from_server(
- server_trailing_metadata, SERVER_STATUS_CODE,
- SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
- ], server_call_tag)
- self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
-
- server_event = self.server_completion_queue.poll(cygrpc_deadline)
- client_event = client_event_future.result()
-
- self.assertEqual(6, len(client_event.batch_operations))
- found_client_op_types = set()
- for client_result in client_event.batch_operations:
- # we expect each op type to be unique
- self.assertNotIn(client_result.type, found_client_op_types)
- found_client_op_types.add(client_result.type)
- if client_result.type == cygrpc.OperationType.receive_initial_metadata:
- self.assertTrue(
- test_common.metadata_transmitted(server_initial_metadata,
- client_result.received_metadata))
- elif client_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(RESPONSE, client_result.received_message.bytes())
- elif client_result.type == cygrpc.OperationType.receive_status_on_client:
+
+ def performer():
+ tag = object()
+ try:
+ call_result = call.start_client_batch(
+ cygrpc.Operations(operations), tag)
+ self.assertEqual(cygrpc.CallError.ok, call_result)
+ event = queue.poll(deadline)
+ self.assertEqual(cygrpc.CompletionType.operation_complete,
+ event.type)
+ self.assertTrue(event.success)
+ self.assertIs(tag, event.tag)
+ except Exception as error:
+ raise Exception("Error in '{}': {}".format(description,
+ error.message))
+ return event
+
+ return test_utilities.SimpleFuture(performer)
+
+ def testEcho(self):
+ DEADLINE = time.time() + 5
+ DEADLINE_TOLERANCE = 0.25
+ CLIENT_METADATA_ASCII_KEY = b'key'
+ CLIENT_METADATA_ASCII_VALUE = b'val'
+ CLIENT_METADATA_BIN_KEY = b'key-bin'
+ CLIENT_METADATA_BIN_VALUE = b'\0' * 1000
+ SERVER_INITIAL_METADATA_KEY = b'init_me_me_me'
+ SERVER_INITIAL_METADATA_VALUE = b'whodawha?'
+ SERVER_TRAILING_METADATA_KEY = b'california_is_in_a_drought'
+ SERVER_TRAILING_METADATA_VALUE = b'zomg it is'
+ SERVER_STATUS_CODE = cygrpc.StatusCode.ok
+ SERVER_STATUS_DETAILS = b'our work is never over'
+ REQUEST = b'in death a member of project mayhem has a name'
+ RESPONSE = b'his name is robert paulson'
+ METHOD = b'twinkies'
+
+ cygrpc_deadline = cygrpc.Timespec(DEADLINE)
+
+ server_request_tag = object()
+ request_call_result = self.server.request_call(
+ self.server_completion_queue, self.server_completion_queue,
+ server_request_tag)
+
+ self.assertEqual(cygrpc.CallError.ok, request_call_result)
+
+ client_call_tag = object()
+ client_call = self.client_channel.create_call(
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
+ client_initial_metadata = cygrpc.Metadata([
+ cygrpc.Metadatum(CLIENT_METADATA_ASCII_KEY,
+ CLIENT_METADATA_ASCII_VALUE),
+ cygrpc.Metadatum(CLIENT_METADATA_BIN_KEY, CLIENT_METADATA_BIN_VALUE)
+ ])
+ client_start_batch_result = client_call.start_client_batch([
+ cygrpc.operation_send_initial_metadata(client_initial_metadata,
+ _EMPTY_FLAGS),
+ cygrpc.operation_send_message(REQUEST, _EMPTY_FLAGS),
+ cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
+ cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
+ ], client_call_tag)
+ self.assertEqual(cygrpc.CallError.ok, client_start_batch_result)
+ client_event_future = test_utilities.CompletionQueuePollFuture(
+ self.client_completion_queue, cygrpc_deadline)
+
+ request_event = self.server_completion_queue.poll(cygrpc_deadline)
+ self.assertEqual(cygrpc.CompletionType.operation_complete,
+ request_event.type)
+ self.assertIsInstance(request_event.operation_call, cygrpc.Call)
+ self.assertIs(server_request_tag, request_event.tag)
+ self.assertEqual(0, len(request_event.batch_operations))
self.assertTrue(
- test_common.metadata_transmitted(server_trailing_metadata,
- client_result.received_metadata))
- self.assertEqual(SERVER_STATUS_DETAILS,
- client_result.received_status_details)
- self.assertEqual(SERVER_STATUS_CODE, client_result.received_status_code)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.send_close_from_client,
- cygrpc.OperationType.receive_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.receive_status_on_client
- ]), found_client_op_types)
-
- self.assertEqual(5, len(server_event.batch_operations))
- found_server_op_types = set()
- for server_result in server_event.batch_operations:
- self.assertNotIn(client_result.type, found_server_op_types)
- found_server_op_types.add(server_result.type)
- if server_result.type == cygrpc.OperationType.receive_message:
- self.assertEqual(REQUEST, server_result.received_message.bytes())
- elif server_result.type == cygrpc.OperationType.receive_close_on_server:
- self.assertFalse(server_result.received_cancelled)
- self.assertEqual(set([
- cygrpc.OperationType.send_initial_metadata,
- cygrpc.OperationType.receive_message,
- cygrpc.OperationType.send_message,
- cygrpc.OperationType.receive_close_on_server,
- cygrpc.OperationType.send_status_from_server
- ]), found_server_op_types)
-
- del client_call
- del server_call
-
- def test6522(self):
- DEADLINE = time.time()+5
- DEADLINE_TOLERANCE = 0.25
- METHOD = b'twinkies'
-
- cygrpc_deadline = cygrpc.Timespec(DEADLINE)
- empty_metadata = cygrpc.Metadata([])
-
- server_request_tag = object()
- self.server.request_call(
- self.server_completion_queue, self.server_completion_queue,
- server_request_tag)
- client_call = self.client_channel.create_call(
- None, 0, self.client_completion_queue, METHOD, self.host_argument,
- cygrpc_deadline)
-
- # Prologue
- def perform_client_operations(operations, description):
- return self._perform_operations(
- operations, client_call,
- self.client_completion_queue, cygrpc_deadline, description)
-
- client_event_future = perform_client_operations([
+ test_common.metadata_transmitted(client_initial_metadata,
+ request_event.request_metadata))
+ self.assertEqual(METHOD, request_event.request_call_details.method)
+ self.assertEqual(self.expected_host,
+ request_event.request_call_details.host)
+ self.assertLess(
+ abs(DEADLINE - float(request_event.request_call_details.deadline)),
+ DEADLINE_TOLERANCE)
+
+ server_call_tag = object()
+ server_call = request_event.operation_call
+ server_initial_metadata = cygrpc.Metadata([
+ cygrpc.Metadatum(SERVER_INITIAL_METADATA_KEY,
+ SERVER_INITIAL_METADATA_VALUE)
+ ])
+ server_trailing_metadata = cygrpc.Metadata([
+ cygrpc.Metadatum(SERVER_TRAILING_METADATA_KEY,
+ SERVER_TRAILING_METADATA_VALUE)
+ ])
+ server_start_batch_result = server_call.start_server_batch([
+ cygrpc.operation_send_initial_metadata(
+ server_initial_metadata,
+ _EMPTY_FLAGS), cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ cygrpc.operation_send_message(RESPONSE, _EMPTY_FLAGS),
+ cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
+ cygrpc.operation_send_status_from_server(
+ server_trailing_metadata, SERVER_STATUS_CODE,
+ SERVER_STATUS_DETAILS, _EMPTY_FLAGS)
+ ], server_call_tag)
+ self.assertEqual(cygrpc.CallError.ok, server_start_batch_result)
+
+ server_event = self.server_completion_queue.poll(cygrpc_deadline)
+ client_event = client_event_future.result()
+
+ self.assertEqual(6, len(client_event.batch_operations))
+ found_client_op_types = set()
+ for client_result in client_event.batch_operations:
+ # we expect each op type to be unique
+ self.assertNotIn(client_result.type, found_client_op_types)
+ found_client_op_types.add(client_result.type)
+ if client_result.type == cygrpc.OperationType.receive_initial_metadata:
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ server_initial_metadata,
+ client_result.received_metadata))
+ elif client_result.type == cygrpc.OperationType.receive_message:
+ self.assertEqual(RESPONSE,
+ client_result.received_message.bytes())
+ elif client_result.type == cygrpc.OperationType.receive_status_on_client:
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ server_trailing_metadata,
+ client_result.received_metadata))
+ self.assertEqual(SERVER_STATUS_DETAILS,
+ client_result.received_status_details)
+ self.assertEqual(SERVER_STATUS_CODE,
+ client_result.received_status_code)
+ self.assertEqual(
+ set([
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.send_message,
+ cygrpc.OperationType.send_close_from_client,
+ cygrpc.OperationType.receive_initial_metadata,
+ cygrpc.OperationType.receive_message,
+ cygrpc.OperationType.receive_status_on_client
+ ]), found_client_op_types)
+
+ self.assertEqual(5, len(server_event.batch_operations))
+ found_server_op_types = set()
+ for server_result in server_event.batch_operations:
+ self.assertNotIn(client_result.type, found_server_op_types)
+ found_server_op_types.add(server_result.type)
+ if server_result.type == cygrpc.OperationType.receive_message:
+ self.assertEqual(REQUEST,
+ server_result.received_message.bytes())
+ elif server_result.type == cygrpc.OperationType.receive_close_on_server:
+ self.assertFalse(server_result.received_cancelled)
+ self.assertEqual(
+ set([
+ cygrpc.OperationType.send_initial_metadata,
+ cygrpc.OperationType.receive_message,
+ cygrpc.OperationType.send_message,
+ cygrpc.OperationType.receive_close_on_server,
+ cygrpc.OperationType.send_status_from_server
+ ]), found_server_op_types)
+
+ del client_call
+ del server_call
+
+ def test6522(self):
+ DEADLINE = time.time() + 5
+ DEADLINE_TOLERANCE = 0.25
+ METHOD = b'twinkies'
+
+ cygrpc_deadline = cygrpc.Timespec(DEADLINE)
+ empty_metadata = cygrpc.Metadata([])
+
+ server_request_tag = object()
+ self.server.request_call(self.server_completion_queue,
+ self.server_completion_queue,
+ server_request_tag)
+ client_call = self.client_channel.create_call(
+ None, 0, self.client_completion_queue, METHOD, self.host_argument,
+ cygrpc_deadline)
+
+ # Prologue
+ def perform_client_operations(operations, description):
+ return self._perform_operations(operations, client_call,
+ self.client_completion_queue,
+ cygrpc_deadline, description)
+
+ client_event_future = perform_client_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
_EMPTY_FLAGS),
cygrpc.operation_receive_initial_metadata(_EMPTY_FLAGS),
], "Client prologue")
- request_event = self.server_completion_queue.poll(cygrpc_deadline)
- server_call = request_event.operation_call
+ request_event = self.server_completion_queue.poll(cygrpc_deadline)
+ server_call = request_event.operation_call
- def perform_server_operations(operations, description):
- return self._perform_operations(
- operations, server_call,
- self.server_completion_queue, cygrpc_deadline, description)
+ def perform_server_operations(operations, description):
+ return self._perform_operations(operations, server_call,
+ self.server_completion_queue,
+ cygrpc_deadline, description)
- server_event_future = perform_server_operations([
+ server_event_future = perform_server_operations([
cygrpc.operation_send_initial_metadata(empty_metadata,
_EMPTY_FLAGS),
], "Server prologue")
- client_event_future.result() # force completion
- server_event_future.result()
-
- # Messaging
- for _ in range(10):
- client_event_future = perform_client_operations([
- cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- ], "Client message")
- server_event_future = perform_server_operations([
- cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
- cygrpc.operation_receive_message(_EMPTY_FLAGS),
- ], "Server receive")
-
- client_event_future.result() # force completion
- server_event_future.result()
-
- # Epilogue
- client_event_future = perform_client_operations([
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+ # Messaging
+ for _ in range(10):
+ client_event_future = perform_client_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Client message")
+ server_event_future = perform_server_operations([
+ cygrpc.operation_send_message(b'', _EMPTY_FLAGS),
+ cygrpc.operation_receive_message(_EMPTY_FLAGS),
+ ], "Server receive")
+
+ client_event_future.result() # force completion
+ server_event_future.result()
+
+ # Epilogue
+ client_event_future = perform_client_operations([
cygrpc.operation_send_close_from_client(_EMPTY_FLAGS),
cygrpc.operation_receive_status_on_client(_EMPTY_FLAGS)
], "Client epilogue")
- server_event_future = perform_server_operations([
+ server_event_future = perform_server_operations([
cygrpc.operation_receive_close_on_server(_EMPTY_FLAGS),
cygrpc.operation_send_status_from_server(
empty_metadata, cygrpc.StatusCode.ok, b'', _EMPTY_FLAGS)
], "Server epilogue")
- client_event_future.result() # force completion
- server_event_future.result()
+ client_event_future.result() # force completion
+ server_event_future.result()
class InsecureServerInsecureClient(unittest.TestCase, ServerClientMixin):
- def setUp(self):
- self.setUpMixin(None, None, None)
+ def setUp(self):
+ self.setUpMixin(None, None, None)
- def tearDown(self):
- self.tearDownMixin()
+ def tearDown(self):
+ self.tearDownMixin()
class SecureServerSecureClient(unittest.TestCase, ServerClientMixin):
- def setUp(self):
- server_credentials = cygrpc.server_credentials_ssl(
- None, [cygrpc.SslPemKeyCertPair(resources.private_key(),
- resources.certificate_chain())], False)
- client_credentials = cygrpc.channel_credentials_ssl(
- resources.test_root_certificates(), None)
- self.setUpMixin(server_credentials, client_credentials, _SSL_HOST_OVERRIDE)
+ def setUp(self):
+ server_credentials = cygrpc.server_credentials_ssl(None, [
+ cygrpc.SslPemKeyCertPair(resources.private_key(),
+ resources.certificate_chain())
+ ], False)
+ client_credentials = cygrpc.channel_credentials_ssl(
+ resources.test_root_certificates(), None)
+ self.setUpMixin(server_credentials, client_credentials,
+ _SSL_HOST_OVERRIDE)
- def tearDown(self):
- self.tearDownMixin()
+ def tearDown(self):
+ self.tearDownMixin()
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py b/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py
index 6280ce74c4..dffb3733b6 100644
--- a/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py
+++ b/src/python/grpcio_tests/tests/unit/_cython/test_utilities.py
@@ -33,34 +33,35 @@ from grpc._cython import cygrpc
class SimpleFuture(object):
- """A simple future mechanism."""
+ """A simple future mechanism."""
- def __init__(self, function, *args, **kwargs):
- def wrapped_function():
- try:
- self._result = function(*args, **kwargs)
- except Exception as error:
- self._error = error
- self._result = None
- self._error = None
- self._thread = threading.Thread(target=wrapped_function)
- self._thread.start()
+ def __init__(self, function, *args, **kwargs):
- def result(self):
- """The resulting value of this future.
+ def wrapped_function():
+ try:
+ self._result = function(*args, **kwargs)
+ except Exception as error:
+ self._error = error
+
+ self._result = None
+ self._error = None
+ self._thread = threading.Thread(target=wrapped_function)
+ self._thread.start()
+
+ def result(self):
+ """The resulting value of this future.
Re-raises any exceptions.
"""
- self._thread.join()
- if self._error:
- # TODO(atash): re-raise exceptions in a way that preserves tracebacks
- raise self._error
- return self._result
+ self._thread.join()
+ if self._error:
+ # TODO(atash): re-raise exceptions in a way that preserves tracebacks
+ raise self._error
+ return self._result
class CompletionQueuePollFuture(SimpleFuture):
- def __init__(self, completion_queue, deadline):
- super(CompletionQueuePollFuture, self).__init__(
- lambda: completion_queue.poll(deadline))
-
+ def __init__(self, completion_queue, deadline):
+ super(CompletionQueuePollFuture,
+ self).__init__(lambda: completion_queue.poll(deadline))
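
The reformatted test_utilities.py above provides SimpleFuture, a small thread-backed future whose result() joins the worker and re-raises any captured exception. A minimal usage sketch, assuming the grpcio_tests package is on sys.path (slow_add is an illustrative function, not part of the test suite):

import time

from tests.unit._cython.test_utilities import SimpleFuture


def slow_add(a, b):
    # Stands in for a blocking call such as completion_queue.poll(deadline).
    time.sleep(0.1)
    return a + b


future = SimpleFuture(slow_add, 2, 3)
print(future.result())  # joins the worker thread and prints 5
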
diff --git a/src/python/grpcio_tests/tests/unit/_empty_message_test.py b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
index 69f4689279..4588688ea6 100644
--- a/src/python/grpcio_tests/tests/unit/_empty_message_test.py
+++ b/src/python/grpcio_tests/tests/unit/_empty_message_test.py
@@ -44,95 +44,94 @@ _STREAM_STREAM = '/test/StreamStream'
def handle_unary_unary(request, servicer_context):
- return _RESPONSE
+ return _RESPONSE
def handle_unary_stream(request, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH):
- yield _RESPONSE
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
def handle_stream_unary(request_iterator, servicer_context):
- for request in request_iterator:
- pass
- return _RESPONSE
+ for request in request_iterator:
+ pass
+ return _RESPONSE
def handle_stream_stream(request_iterator, servicer_context):
- for request in request_iterator:
- yield _RESPONSE
+ for request in request_iterator:
+ yield _RESPONSE
class _MethodHandler(grpc.RpcMethodHandler):
- def __init__(self, request_streaming, response_streaming):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = None
- self.response_serializer = None
- self.unary_unary = None
- self.unary_stream = None
- self.stream_unary = None
- self.stream_stream = None
- if self.request_streaming and self.response_streaming:
- self.stream_stream = handle_stream_stream
- elif self.request_streaming:
- self.stream_unary = handle_stream_unary
- elif self.response_streaming:
- self.unary_stream = handle_unary_stream
- else:
- self.unary_unary = handle_unary_unary
+ def __init__(self, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = handle_stream_stream
+ elif self.request_streaming:
+ self.stream_unary = handle_stream_unary
+ elif self.response_streaming:
+ self.unary_stream = handle_unary_stream
+ else:
+ self.unary_unary = handle_unary_unary
class _GenericHandler(grpc.GenericRpcHandler):
- def service(self, handler_call_details):
- if handler_call_details.method == _UNARY_UNARY:
- return _MethodHandler(False, False)
- elif handler_call_details.method == _UNARY_STREAM:
- return _MethodHandler(False, True)
- elif handler_call_details.method == _STREAM_UNARY:
- return _MethodHandler(True, False)
- elif handler_call_details.method == _STREAM_STREAM:
- return _MethodHandler(True, True)
- else:
- return None
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True)
+ else:
+ return None
class EmptyMessageTest(unittest.TestCase):
- def setUp(self):
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(
- self._server_pool, handlers=(_GenericHandler(),))
- port = self._server.add_insecure_port('[::]:0')
- self._server.start()
- self._channel = grpc.insecure_channel('localhost:%d' % port)
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ self._server_pool, handlers=(_GenericHandler(),))
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
- def tearDown(self):
- self._server.stop(0)
+ def tearDown(self):
+ self._server.stop(0)
- def testUnaryUnary(self):
- response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
- self.assertEqual(_RESPONSE, response)
+ def testUnaryUnary(self):
+ response = self._channel.unary_unary(_UNARY_UNARY)(_REQUEST)
+ self.assertEqual(_RESPONSE, response)
- def testUnaryStream(self):
- response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
- self.assertSequenceEqual(
- [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+ def testUnaryStream(self):
+ response_iterator = self._channel.unary_stream(_UNARY_STREAM)(_REQUEST)
+ self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
+ list(response_iterator))
- def testStreamUnary(self):
- response = self._channel.stream_unary(_STREAM_UNARY)(
- iter([_REQUEST] * test_constants.STREAM_LENGTH))
- self.assertEqual(_RESPONSE, response)
+ def testStreamUnary(self):
+ response = self._channel.stream_unary(_STREAM_UNARY)(iter(
+ [_REQUEST] * test_constants.STREAM_LENGTH))
+ self.assertEqual(_RESPONSE, response)
- def testStreamStream(self):
- response_iterator = self._channel.stream_stream(_STREAM_STREAM)(
- iter([_REQUEST] * test_constants.STREAM_LENGTH))
- self.assertSequenceEqual(
- [_RESPONSE] * test_constants.STREAM_LENGTH, list(response_iterator))
+ def testStreamStream(self):
+ response_iterator = self._channel.stream_stream(_STREAM_STREAM)(iter(
+ [_REQUEST] * test_constants.STREAM_LENGTH))
+ self.assertSequenceEqual([_RESPONSE] * test_constants.STREAM_LENGTH,
+ list(response_iterator))
if __name__ == '__main__':
- unittest.main(verbosity=2)
-
+ unittest.main(verbosity=2)
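
The _empty_message_test.py diff above wires four method paths through a single generic handler and asserts that empty byte strings survive a full round trip. A condensed sketch of the same unary-unary round trip against the public grpc API (the _PING path and pool size are illustrative choices, not taken from the test):

from concurrent import futures

import grpc

_PING = '/test/UnaryUnary'


class _Handler(grpc.GenericRpcHandler):

    def service(self, handler_call_details):
        if handler_call_details.method == _PING:
            # Reply with an empty byte string, as the test handlers do.
            return grpc.unary_unary_rpc_method_handler(
                lambda request, context: b'')
        return None


server = grpc.server(futures.ThreadPoolExecutor(max_workers=2))
server.add_generic_rpc_handlers((_Handler(),))
port = server.add_insecure_port('[::]:0')
server.start()

channel = grpc.insecure_channel('localhost:%d' % port)
assert channel.unary_unary(_PING)(b'') == b''
server.stop(0)
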
diff --git a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
index 777527137f..22a6643848 100644
--- a/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
+++ b/src/python/grpcio_tests/tests/unit/_exit_scenarios.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Defines a number of module-scope gRPC scenarios to test clean exit."""
import argparse
@@ -73,88 +72,88 @@ TEST_TO_METHOD = {
def hang_unary_unary(request, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_unary_stream(request, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_unary_stream(request, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- yield request
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield request
+ time.sleep(WAIT_TIME)
def hang_stream_unary(request_iterator, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_stream_unary(request_iterator, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- next(request_iterator)
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(request_iterator)
+ time.sleep(WAIT_TIME)
def hang_stream_stream(request_iterator, servicer_context):
- time.sleep(WAIT_TIME)
+ time.sleep(WAIT_TIME)
def hang_partial_stream_stream(request_iterator, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH // 2):
- yield next(request_iterator)
- time.sleep(WAIT_TIME)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield next(request_iterator)
+ time.sleep(WAIT_TIME)
class MethodHandler(grpc.RpcMethodHandler):
- def __init__(self, request_streaming, response_streaming, partial_hang):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = None
- self.response_serializer = None
- self.unary_unary = None
- self.unary_stream = None
- self.stream_unary = None
- self.stream_stream = None
- if self.request_streaming and self.response_streaming:
- if partial_hang:
- self.stream_stream = hang_partial_stream_stream
- else:
- self.stream_stream = hang_stream_stream
- elif self.request_streaming:
- if partial_hang:
- self.stream_unary = hang_partial_stream_unary
- else:
- self.stream_unary = hang_stream_unary
- elif self.response_streaming:
- if partial_hang:
- self.unary_stream = hang_partial_unary_stream
- else:
- self.unary_stream = hang_unary_stream
- else:
- self.unary_unary = hang_unary_unary
+ def __init__(self, request_streaming, response_streaming, partial_hang):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ if partial_hang:
+ self.stream_stream = hang_partial_stream_stream
+ else:
+ self.stream_stream = hang_stream_stream
+ elif self.request_streaming:
+ if partial_hang:
+ self.stream_unary = hang_partial_stream_unary
+ else:
+ self.stream_unary = hang_stream_unary
+ elif self.response_streaming:
+ if partial_hang:
+ self.unary_stream = hang_partial_unary_stream
+ else:
+ self.unary_stream = hang_unary_stream
+ else:
+ self.unary_unary = hang_unary_unary
class GenericHandler(grpc.GenericRpcHandler):
- def service(self, handler_call_details):
- if handler_call_details.method == UNARY_UNARY:
- return MethodHandler(False, False, False)
- elif handler_call_details.method == UNARY_STREAM:
- return MethodHandler(False, True, False)
- elif handler_call_details.method == STREAM_UNARY:
- return MethodHandler(True, False, False)
- elif handler_call_details.method == STREAM_STREAM:
- return MethodHandler(True, True, False)
- elif handler_call_details.method == PARTIAL_UNARY_STREAM:
- return MethodHandler(False, True, True)
- elif handler_call_details.method == PARTIAL_STREAM_UNARY:
- return MethodHandler(True, False, True)
- elif handler_call_details.method == PARTIAL_STREAM_STREAM:
- return MethodHandler(True, True, True)
- else:
- return None
+ def service(self, handler_call_details):
+ if handler_call_details.method == UNARY_UNARY:
+ return MethodHandler(False, False, False)
+ elif handler_call_details.method == UNARY_STREAM:
+ return MethodHandler(False, True, False)
+ elif handler_call_details.method == STREAM_UNARY:
+ return MethodHandler(True, False, False)
+ elif handler_call_details.method == STREAM_STREAM:
+ return MethodHandler(True, True, False)
+ elif handler_call_details.method == PARTIAL_UNARY_STREAM:
+ return MethodHandler(False, True, True)
+ elif handler_call_details.method == PARTIAL_STREAM_UNARY:
+ return MethodHandler(True, False, True)
+ elif handler_call_details.method == PARTIAL_STREAM_STREAM:
+ return MethodHandler(True, True, True)
+ else:
+ return None
# Traditional executors will not exit until all their
@@ -162,88 +161,88 @@ class GenericHandler(grpc.GenericRpcHandler):
# never finish, we don't want to block exit on these jobs.
class DaemonPool(object):
- def submit(self, fn, *args, **kwargs):
- thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
- thread.daemon = True
- thread.start()
+ def submit(self, fn, *args, **kwargs):
+ thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
+ thread.daemon = True
+ thread.start()
- def shutdown(self, wait=True):
- pass
+ def shutdown(self, wait=True):
+ pass
def infinite_request_iterator():
- while True:
- yield REQUEST
+ while True:
+ yield REQUEST
if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument('scenario', type=str)
- parser.add_argument(
- '--wait_for_interrupt', dest='wait_for_interrupt', action='store_true')
- args = parser.parse_args()
-
- if args.scenario == UNSTARTED_SERVER:
- server = grpc.server(DaemonPool())
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == RUNNING_SERVER:
- server = grpc.server(DaemonPool())
- port = server.add_insecure_port('[::]:0')
- server.start()
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
- channel = grpc.insecure_channel('localhost:12345')
-
- def connectivity_callback(connectivity):
- pass
-
- channel.subscribe(connectivity_callback, try_to_connect=True)
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
- elif args.scenario == POLL_CONNECTIVITY:
- server = grpc.server(DaemonPool())
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
-
- def connectivity_callback(connectivity):
- pass
-
- channel.subscribe(connectivity_callback, try_to_connect=True)
- if args.wait_for_interrupt:
- time.sleep(WAIT_TIME)
-
- else:
- handler = GenericHandler()
- server = grpc.server(DaemonPool())
- port = server.add_insecure_port('[::]:0')
- server.add_generic_rpc_handlers((handler,))
- server.start()
- channel = grpc.insecure_channel('localhost:%d' % port)
-
- method = TEST_TO_METHOD[args.scenario]
-
- if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
- multi_callable = channel.unary_unary(method)
- future = multi_callable.future(REQUEST)
- result, call = multi_callable.with_call(REQUEST)
- elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or
- args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL):
- multi_callable = channel.unary_stream(method)
- response_iterator = multi_callable(REQUEST)
- for response in response_iterator:
- pass
- elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or
- args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL):
- multi_callable = channel.stream_unary(method)
- future = multi_callable.future(infinite_request_iterator())
- result, call = multi_callable.with_call(
- iter([REQUEST] * test_constants.STREAM_LENGTH))
- elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or
- args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL):
- multi_callable = channel.stream_stream(method)
- response_iterator = multi_callable(infinite_request_iterator())
- for response in response_iterator:
- pass
+ parser = argparse.ArgumentParser()
+ parser.add_argument('scenario', type=str)
+ parser.add_argument(
+ '--wait_for_interrupt', dest='wait_for_interrupt', action='store_true')
+ args = parser.parse_args()
+
+ if args.scenario == UNSTARTED_SERVER:
+ server = grpc.server(DaemonPool())
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == RUNNING_SERVER:
+ server = grpc.server(DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY_NO_SERVER:
+ channel = grpc.insecure_channel('localhost:12345')
+
+ def connectivity_callback(connectivity):
+ pass
+
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+ elif args.scenario == POLL_CONNECTIVITY:
+ server = grpc.server(DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def connectivity_callback(connectivity):
+ pass
+
+ channel.subscribe(connectivity_callback, try_to_connect=True)
+ if args.wait_for_interrupt:
+ time.sleep(WAIT_TIME)
+
+ else:
+ handler = GenericHandler()
+ server = grpc.server(DaemonPool())
+ port = server.add_insecure_port('[::]:0')
+ server.add_generic_rpc_handlers((handler,))
+ server.start()
+ channel = grpc.insecure_channel('localhost:%d' % port)
+
+ method = TEST_TO_METHOD[args.scenario]
+
+ if args.scenario == IN_FLIGHT_UNARY_UNARY_CALL:
+ multi_callable = channel.unary_unary(method)
+ future = multi_callable.future(REQUEST)
+ result, call = multi_callable.with_call(REQUEST)
+ elif (args.scenario == IN_FLIGHT_UNARY_STREAM_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL):
+ multi_callable = channel.unary_stream(method)
+ response_iterator = multi_callable(REQUEST)
+ for response in response_iterator:
+ pass
+ elif (args.scenario == IN_FLIGHT_STREAM_UNARY_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL):
+ multi_callable = channel.stream_unary(method)
+ future = multi_callable.future(infinite_request_iterator())
+ result, call = multi_callable.with_call(
+ iter([REQUEST] * test_constants.STREAM_LENGTH))
+ elif (args.scenario == IN_FLIGHT_STREAM_STREAM_CALL or
+ args.scenario == IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL):
+ multi_callable = channel.stream_stream(method)
+ response_iterator = multi_callable(infinite_request_iterator())
+ for response in response_iterator:
+ pass
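
A key piece of _exit_scenarios.py above is DaemonPool: its submit() runs work on daemon threads, so a handler that sleeps forever cannot keep the interpreter alive when the scenario process is told to exit. A standalone sketch of that idea using only the standard library (TinyDaemonPool is an illustrative stand-in, not the class from the module):

import threading
import time


class TinyDaemonPool(object):

    def submit(self, fn, *args, **kwargs):
        # Daemon threads are abandoned at interpreter exit instead of joined.
        thread = threading.Thread(target=fn, args=args, kwargs=kwargs)
        thread.daemon = True
        thread.start()

    def shutdown(self, wait=True):
        pass


if __name__ == '__main__':
    TinyDaemonPool().submit(time.sleep, 3600)
    print('main thread exits immediately; the sleeping worker is abandoned')
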
diff --git a/src/python/grpcio_tests/tests/unit/_exit_test.py b/src/python/grpcio_tests/tests/unit/_exit_test.py
index 5a4a32887c..b99605dcb8 100644
--- a/src/python/grpcio_tests/tests/unit/_exit_test.py
+++ b/src/python/grpcio_tests/tests/unit/_exit_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests clean exit of server/client on Python Interpreter exit/sigint.
The tests in this module spawn a subprocess for each test case, the
@@ -45,15 +44,15 @@ import unittest
from tests.unit import _exit_scenarios
-SCENARIO_FILE = os.path.abspath(os.path.join(
- os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py'))
+SCENARIO_FILE = os.path.abspath(
+ os.path.join(
+ os.path.dirname(os.path.realpath(__file__)), '_exit_scenarios.py'))
INTERPRETER = sys.executable
BASE_COMMAND = [INTERPRETER, SCENARIO_FILE]
BASE_SIGTERM_COMMAND = BASE_COMMAND + ['--wait_for_interrupt']
INIT_TIME = 1.0
-
processes = []
process_lock = threading.Lock()
@@ -61,126 +60,146 @@ process_lock = threading.Lock()
# Make sure we attempt to clean up any
# processes we may have left running
def cleanup_processes():
- with process_lock:
- for process in processes:
- try:
- process.kill()
- except Exception:
- pass
+ with process_lock:
+ for process in processes:
+ try:
+ process.kill()
+ except Exception:
+ pass
+
+
atexit.register(cleanup_processes)
def interrupt_and_wait(process):
- with process_lock:
- processes.append(process)
- time.sleep(INIT_TIME)
- os.kill(process.pid, signal.SIGINT)
- process.wait()
+ with process_lock:
+ processes.append(process)
+ time.sleep(INIT_TIME)
+ os.kill(process.pid, signal.SIGINT)
+ process.wait()
def wait(process):
- with process_lock:
- processes.append(process)
- process.wait()
+ with process_lock:
+ processes.append(process)
+ process.wait()
@unittest.skip('https://github.com/grpc/grpc/issues/7311')
class ExitTest(unittest.TestCase):
- def test_unstarted_server(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
- stdout=sys.stdout, stderr=sys.stderr)
- wait(process)
-
- def test_unstarted_server_terminate(self):
- process = subprocess.Popen(
- BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
- stdout=sys.stdout)
- interrupt_and_wait(process)
-
- def test_running_server(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER],
- stdout=sys.stdout, stderr=sys.stderr)
- wait(process)
-
- def test_running_server_terminate(self):
- process = subprocess.Popen(
- BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- def test_poll_connectivity_no_server(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
- stdout=sys.stdout, stderr=sys.stderr)
- wait(process)
-
- def test_poll_connectivity_no_server_terminate(self):
- process = subprocess.Popen(
- BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- def test_poll_connectivity(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
- stdout=sys.stdout, stderr=sys.stderr)
- wait(process)
-
- def test_poll_connectivity_terminate(self):
- process = subprocess.Popen(
- BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- def test_in_flight_unary_unary_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
- def test_in_flight_unary_stream_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- def test_in_flight_stream_unary_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
- def test_in_flight_stream_stream_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
- def test_in_flight_partial_unary_stream_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- def test_in_flight_partial_stream_unary_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
-
- @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
- def test_in_flight_partial_stream_stream_call(self):
- process = subprocess.Popen(
- BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
- stdout=sys.stdout, stderr=sys.stderr)
- interrupt_and_wait(process)
+ def test_unstarted_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ def test_unstarted_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.UNSTARTED_SERVER],
+ stdout=sys.stdout)
+ interrupt_and_wait(process)
+
+ def test_running_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.RUNNING_SERVER],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ def test_running_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.RUNNING_SERVER],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_poll_connectivity_no_server(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ def test_poll_connectivity_no_server_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND +
+ [_exit_scenarios.POLL_CONNECTIVITY_NO_SERVER],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_poll_connectivity(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ wait(process)
+
+ def test_poll_connectivity_terminate(self):
+ process = subprocess.Popen(
+ BASE_SIGTERM_COMMAND + [_exit_scenarios.POLL_CONNECTIVITY],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_unary_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_UNARY_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_unary_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_UNARY_STREAM_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_stream_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_UNARY_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_stream_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND + [_exit_scenarios.IN_FLIGHT_STREAM_STREAM_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_partial_unary_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND +
+ [_exit_scenarios.IN_FLIGHT_PARTIAL_UNARY_STREAM_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ def test_in_flight_partial_stream_unary_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND +
+ [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_UNARY_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
+
+ @unittest.skipIf(six.PY2, 'https://github.com/grpc/grpc/issues/6999')
+ def test_in_flight_partial_stream_stream_call(self):
+ process = subprocess.Popen(
+ BASE_COMMAND +
+ [_exit_scenarios.IN_FLIGHT_PARTIAL_STREAM_STREAM_CALL],
+ stdout=sys.stdout,
+ stderr=sys.stderr)
+ interrupt_and_wait(process)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
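
Each case in _exit_test.py above launches a scenario as a subprocess and either waits for it or delivers SIGINT after INIT_TIME seconds. A sketch of that interrupt-and-wait pattern in isolation, assuming a POSIX system as the tests do (the child command here is an illustrative sleeper, not the real scenario script):

import os
import signal
import subprocess
import sys
import time

INIT_TIME = 1.0

process = subprocess.Popen(
    [sys.executable, '-c', 'import time; time.sleep(3600)'],
    stdout=sys.stdout,
    stderr=sys.stderr)
time.sleep(INIT_TIME)  # give the child time to start
os.kill(process.pid, signal.SIGINT)  # emulate Ctrl-C, as interrupt_and_wait does
process.wait()
print('child exited with code', process.returncode)
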
diff --git a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
index 2dc225de29..1b1b1bd598 100644
--- a/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invalid_metadata_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test of RPCs made against gRPC Python's application-layer API."""
import unittest
@@ -47,129 +46,131 @@ _STREAM_STREAM = '/test/StreamStream'
def _unary_unary_multi_callable(channel):
- return channel.unary_unary(_UNARY_UNARY)
+ return channel.unary_unary(_UNARY_UNARY)
def _unary_stream_multi_callable(channel):
- return channel.unary_stream(
- _UNARY_STREAM,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.unary_stream(
+ _UNARY_STREAM,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
- return channel.stream_unary(
- _STREAM_UNARY,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.stream_unary(
+ _STREAM_UNARY,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_stream_multi_callable(channel):
- return channel.stream_stream(_STREAM_STREAM)
+ return channel.stream_stream(_STREAM_STREAM)
class InvalidMetadataTest(unittest.TestCase):
- def setUp(self):
- self._channel = grpc.insecure_channel('localhost:8080')
- self._unary_unary = _unary_unary_multi_callable(self._channel)
- self._unary_stream = _unary_stream_multi_callable(self._channel)
- self._stream_unary = _stream_unary_multi_callable(self._channel)
- self._stream_stream = _stream_stream_multi_callable(self._channel)
-
- def testUnaryRequestBlockingUnaryResponse(self):
- request = b'\x07\x08'
- metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- with self.assertRaises(ValueError) as exception_context:
- self._unary_unary(request, metadata=metadata)
- self.assertIn(expected_error_details, str(exception_context.exception))
-
- def testUnaryRequestBlockingUnaryResponseWithCall(self):
- request = b'\x07\x08'
- metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- with self.assertRaises(ValueError) as exception_context:
- self._unary_unary.with_call(request, metadata=metadata)
- self.assertIn(expected_error_details, str(exception_context.exception))
-
- def testUnaryRequestFutureUnaryResponse(self):
- request = b'\x07\x08'
- metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- response_future = self._unary_unary.future(request, metadata=metadata)
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertEqual(
- exception_context.exception.details(), expected_error_details)
- self.assertEqual(
- exception_context.exception.code(), grpc.StatusCode.INTERNAL)
- self.assertEqual(response_future.details(), expected_error_details)
- self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
-
- def testUnaryRequestStreamResponse(self):
- request = b'\x37\x58'
- metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- response_iterator = self._unary_stream(request, metadata=metadata)
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(response_iterator)
- self.assertEqual(
- exception_context.exception.details(), expected_error_details)
- self.assertEqual(
- exception_context.exception.code(), grpc.StatusCode.INTERNAL)
- self.assertEqual(response_iterator.details(), expected_error_details)
- self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
-
- def testStreamRequestBlockingUnaryResponse(self):
- request_iterator = (b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- with self.assertRaises(ValueError) as exception_context:
- self._stream_unary(request_iterator, metadata=metadata)
- self.assertIn(expected_error_details, str(exception_context.exception))
-
- def testStreamRequestBlockingUnaryResponseWithCall(self):
- request_iterator = (
- b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self.assertRaises(ValueError) as exception_context:
- multi_callable.with_call(request_iterator, metadata=metadata)
- self.assertIn(expected_error_details, str(exception_context.exception))
-
- def testStreamRequestFutureUnaryResponse(self):
- request_iterator = (
- b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- response_future = self._stream_unary.future(
- request_iterator, metadata=metadata)
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertEqual(
- exception_context.exception.details(), expected_error_details)
- self.assertEqual(
- exception_context.exception.code(), grpc.StatusCode.INTERNAL)
- self.assertEqual(response_future.details(), expected_error_details)
- self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
-
- def testStreamRequestStreamResponse(self):
- request_iterator = (
- b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
- expected_error_details = "metadata was invalid: %s" % metadata
- response_iterator = self._stream_stream(request_iterator, metadata=metadata)
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(response_iterator)
- self.assertEqual(
- exception_context.exception.details(), expected_error_details)
- self.assertEqual(
- exception_context.exception.code(), grpc.StatusCode.INTERNAL)
- self.assertEqual(response_iterator.details(), expected_error_details)
- self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
+ def setUp(self):
+ self._channel = grpc.insecure_channel('localhost:8080')
+ self._unary_unary = _unary_unary_multi_callable(self._channel)
+ self._unary_stream = _unary_stream_multi_callable(self._channel)
+ self._stream_unary = _stream_unary_multi_callable(self._channel)
+ self._stream_stream = _stream_stream_multi_callable(self._channel)
+
+ def testUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x08'
+ metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ with self.assertRaises(ValueError) as exception_context:
+ self._unary_unary(request, metadata=metadata)
+ self.assertIn(expected_error_details, str(exception_context.exception))
+
+ def testUnaryRequestBlockingUnaryResponseWithCall(self):
+ request = b'\x07\x08'
+ metadata = (('InVaLiD', 'UnaryRequestBlockingUnaryResponseWithCall'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ with self.assertRaises(ValueError) as exception_context:
+ self._unary_unary.with_call(request, metadata=metadata)
+ self.assertIn(expected_error_details, str(exception_context.exception))
+
+ def testUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x08'
+ metadata = (('InVaLiD', 'UnaryRequestFutureUnaryResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ response_future = self._unary_unary.future(request, metadata=metadata)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertEqual(exception_context.exception.details(),
+ expected_error_details)
+ self.assertEqual(exception_context.exception.code(),
+ grpc.StatusCode.INTERNAL)
+ self.assertEqual(response_future.details(), expected_error_details)
+ self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
+
+ def testUnaryRequestStreamResponse(self):
+ request = b'\x37\x58'
+ metadata = (('InVaLiD', 'UnaryRequestStreamResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ response_iterator = self._unary_stream(request, metadata=metadata)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(response_iterator)
+ self.assertEqual(exception_context.exception.details(),
+ expected_error_details)
+ self.assertEqual(exception_context.exception.code(),
+ grpc.StatusCode.INTERNAL)
+ self.assertEqual(response_iterator.details(), expected_error_details)
+ self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
+
+ def testStreamRequestBlockingUnaryResponse(self):
+ request_iterator = (b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ with self.assertRaises(ValueError) as exception_context:
+ self._stream_unary(request_iterator, metadata=metadata)
+ self.assertIn(expected_error_details, str(exception_context.exception))
+
+ def testStreamRequestBlockingUnaryResponseWithCall(self):
+ request_iterator = (b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ metadata = (('InVaLiD', 'StreamRequestBlockingUnaryResponseWithCall'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self.assertRaises(ValueError) as exception_context:
+ multi_callable.with_call(request_iterator, metadata=metadata)
+ self.assertIn(expected_error_details, str(exception_context.exception))
+
+ def testStreamRequestFutureUnaryResponse(self):
+ request_iterator = (b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ metadata = (('InVaLiD', 'StreamRequestFutureUnaryResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ response_future = self._stream_unary.future(
+ request_iterator, metadata=metadata)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertEqual(exception_context.exception.details(),
+ expected_error_details)
+ self.assertEqual(exception_context.exception.code(),
+ grpc.StatusCode.INTERNAL)
+ self.assertEqual(response_future.details(), expected_error_details)
+ self.assertEqual(response_future.code(), grpc.StatusCode.INTERNAL)
+
+ def testStreamRequestStreamResponse(self):
+ request_iterator = (b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ metadata = (('InVaLiD', 'StreamRequestStreamResponse'),)
+ expected_error_details = "metadata was invalid: %s" % metadata
+ response_iterator = self._stream_stream(
+ request_iterator, metadata=metadata)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(response_iterator)
+ self.assertEqual(exception_context.exception.details(),
+ expected_error_details)
+ self.assertEqual(exception_context.exception.code(),
+ grpc.StatusCode.INTERNAL)
+ self.assertEqual(response_iterator.details(), expected_error_details)
+ self.assertEqual(response_iterator.code(), grpc.StatusCode.INTERNAL)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
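
The _invalid_metadata_test.py diff above pins down client-side validation: an uppercase metadata key is rejected before anything reaches the wire, with blocking calls raising ValueError and future or streaming calls surfacing an RpcError carrying StatusCode.INTERNAL and a "metadata was invalid: ..." detail. A sketch of the blocking case (the target address is a dummy and is never contacted):

import grpc

channel = grpc.insecure_channel('localhost:8080')  # dummy target, never reached
ping = channel.unary_unary('/test/UnaryUnary')
metadata = (('InVaLiD', 'keys must be lowercase'),)

try:
    ping(b'\x07\x08', metadata=metadata)
except ValueError as error:
    print('rejected client-side:', error)
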
diff --git a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
index 4312679bb9..efeb237874 100644
--- a/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
+++ b/src/python/grpcio_tests/tests/unit/_invocation_defects_test.py
@@ -50,106 +50,117 @@ _STREAM_STREAM = '/test/StreamStream'
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._value = None
- self._called = False
- def __call__(self, value):
- with self._condition:
- self._value = value
- self._called = True
- self._condition.notify_all()
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
+ self._called = False
- def value(self):
- with self._condition:
- while not self._called:
- self._condition.wait()
- return self._value
+ def __call__(self, value):
+ with self._condition:
+ self._value = value
+ self._called = True
+ self._condition.notify_all()
+
+ def value(self):
+ with self._condition:
+ while not self._called:
+ self._condition.wait()
+ return self._value
class _Handler(object):
- def __init__(self, control):
- self._control = control
-
- def handle_unary_unary(self, request, servicer_context):
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- return request
-
- def handle_unary_stream(self, request, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH):
- self._control.control()
- yield request
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
-
- def handle_stream_unary(self, request_iterator, servicer_context):
- if servicer_context is not None:
- servicer_context.invocation_metadata()
- self._control.control()
- response_elements = []
- for request in request_iterator:
- self._control.control()
- response_elements.append(request)
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- return b''.join(response_elements)
-
- def handle_stream_stream(self, request_iterator, servicer_context):
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- for request in request_iterator:
- self._control.control()
- yield request
- self._control.control()
+
+ def __init__(self, control):
+ self._control = control
+
+ def handle_unary_unary(self, request, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ return request
+
+ def handle_unary_stream(self, request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ self._control.control()
+ yield request
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+
+ def handle_stream_unary(self, request_iterator, servicer_context):
+ if servicer_context is not None:
+ servicer_context.invocation_metadata()
+ self._control.control()
+ response_elements = []
+ for request in request_iterator:
+ self._control.control()
+ response_elements.append(request)
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ return b''.join(response_elements)
+
+ def handle_stream_stream(self, request_iterator, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ for request in request_iterator:
+ self._control.control()
+ yield request
+ self._control.control()
class _MethodHandler(grpc.RpcMethodHandler):
- def __init__(
- self, request_streaming, response_streaming, request_deserializer,
- response_serializer, unary_unary, unary_stream, stream_unary,
- stream_stream):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = request_deserializer
- self.response_serializer = response_serializer
- self.unary_unary = unary_unary
- self.unary_stream = unary_stream
- self.stream_unary = stream_unary
- self.stream_stream = stream_stream
+
+ def __init__(self, request_streaming, response_streaming,
+ request_deserializer, response_serializer, unary_unary,
+ unary_stream, stream_unary, stream_stream):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = request_deserializer
+ self.response_serializer = response_serializer
+ self.unary_unary = unary_unary
+ self.unary_stream = unary_stream
+ self.stream_unary = stream_unary
+ self.stream_stream = stream_stream
class _GenericHandler(grpc.GenericRpcHandler):
- def __init__(self, handler):
- self._handler = handler
-
- def service(self, handler_call_details):
- if handler_call_details.method == _UNARY_UNARY:
- return _MethodHandler(
- False, False, None, None, self._handler.handle_unary_unary, None,
- None, None)
- elif handler_call_details.method == _UNARY_STREAM:
- return _MethodHandler(
- False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
- self._handler.handle_unary_stream, None, None)
- elif handler_call_details.method == _STREAM_UNARY:
- return _MethodHandler(
- True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None,
- self._handler.handle_stream_unary, None)
- elif handler_call_details.method == _STREAM_STREAM:
- return _MethodHandler(
- True, True, None, None, None, None, None,
- self._handler.handle_stream_stream)
- else:
- return None
+
+ def __init__(self, handler):
+ self._handler = handler
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False, None, None,
+ self._handler.handle_unary_unary, None, None,
+ None)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None,
+ self._handler.handle_unary_stream, None, None)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None, None,
+ self._handler.handle_stream_unary, None)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True, None, None, None, None, None,
+ self._handler.handle_stream_stream)
+ else:
+ return None
class FailAfterFewIterationsCounter(object):
+
def __init__(self, high, bytestring):
self._current = 0
self._high = high
@@ -167,81 +178,82 @@ class FailAfterFewIterationsCounter(object):
def _unary_unary_multi_callable(channel):
- return channel.unary_unary(_UNARY_UNARY)
+ return channel.unary_unary(_UNARY_UNARY)
def _unary_stream_multi_callable(channel):
- return channel.unary_stream(
- _UNARY_STREAM,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.unary_stream(
+ _UNARY_STREAM,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
- return channel.stream_unary(
- _STREAM_UNARY,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.stream_unary(
+ _STREAM_UNARY,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_stream_multi_callable(channel):
- return channel.stream_stream(_STREAM_STREAM)
+ return channel.stream_stream(_STREAM_STREAM)
class InvocationDefectsTest(unittest.TestCase):
- def setUp(self):
- self._control = test_control.PauseFailControl()
- self._handler = _Handler(self._control)
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
-
- self._server = grpc.server(self._server_pool)
- port = self._server.add_insecure_port('[::]:0')
- self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
- self._server.start()
-
- self._channel = grpc.insecure_channel('localhost:%d' % port)
-
- def tearDown(self):
- self._server.stop(0)
-
- def testIterableStreamRequestBlockingUnaryResponse(self):
- requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
- multi_callable = _stream_unary_multi_callable(self._channel)
-
- with self.assertRaises(grpc.RpcError):
- response = multi_callable(
- requests,
- metadata=(('test', 'IterableStreamRequestBlockingUnaryResponse'),))
-
- def testIterableStreamRequestFutureUnaryResponse(self):
- requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
- multi_callable = _stream_unary_multi_callable(self._channel)
- response_future = multi_callable.future(
- requests,
- metadata=(
- ('test', 'IterableStreamRequestFutureUnaryResponse'),))
-
- with self.assertRaises(grpc.RpcError):
- response = response_future.result()
-
- def testIterableStreamRequestStreamResponse(self):
- requests = [b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)]
- multi_callable = _stream_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- requests,
- metadata=(('test', 'IterableStreamRequestStreamResponse'),))
-
- with self.assertRaises(grpc.RpcError):
- next(response_iterator)
-
- def testIteratorStreamRequestStreamResponse(self):
- requests_iterator = FailAfterFewIterationsCounter(
- test_constants.STREAM_LENGTH // 2, b'\x07\x08')
- multi_callable = _stream_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- requests_iterator,
- metadata=(('test', 'IteratorStreamRequestStreamResponse'),))
-
- with self.assertRaises(grpc.RpcError):
- for _ in range(test_constants.STREAM_LENGTH // 2 + 1):
- next(response_iterator)
+
+ def setUp(self):
+ self._control = test_control.PauseFailControl()
+ self._handler = _Handler(self._control)
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+
+ self._server = grpc.server(self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
+ self._server.start()
+
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testIterableStreamRequestBlockingUnaryResponse(self):
+ requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
+ multi_callable = _stream_unary_multi_callable(self._channel)
+
+ with self.assertRaises(grpc.RpcError):
+ response = multi_callable(
+ requests,
+ metadata=(
+ ('test', 'IterableStreamRequestBlockingUnaryResponse'),))
+
+ def testIterableStreamRequestFutureUnaryResponse(self):
+ requests = [b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH)]
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response_future = multi_callable.future(
+ requests,
+ metadata=(('test', 'IterableStreamRequestFutureUnaryResponse'),))
+
+ with self.assertRaises(grpc.RpcError):
+ response = response_future.result()
+
+ def testIterableStreamRequestStreamResponse(self):
+ requests = [b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH)]
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ requests,
+ metadata=(('test', 'IterableStreamRequestStreamResponse'),))
+
+ with self.assertRaises(grpc.RpcError):
+ next(response_iterator)
+
+ def testIteratorStreamRequestStreamResponse(self):
+ requests_iterator = FailAfterFewIterationsCounter(
+ test_constants.STREAM_LENGTH // 2, b'\x07\x08')
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ requests_iterator,
+ metadata=(('test', 'IteratorStreamRequestStreamResponse'),))
+
+ with self.assertRaises(grpc.RpcError):
+ for _ in range(test_constants.STREAM_LENGTH // 2 + 1):
+ next(response_iterator)
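Note: the InvocationDefectsTest hunk above exercises request iterables and iterators that fail partway through a streaming call, asserting that the defect surfaces client-side as grpc.RpcError. Below is a minimal sketch of that pattern outside the test harness; FailAfter and the method path '/test/StreamStream' are illustrative stand-ins (the test's own FailAfterFewIterationsCounter is defined elsewhere in the module), and a compatible stream-stream server is assumed to be reachable at target.

    # Hedged sketch: a request iterator that raises mid-stream, and the
    # resulting grpc.RpcError observed while consuming the responses.
    import grpc

    class FailAfter(object):
        """Yields `payload` a fixed number of times, then raises."""

        def __init__(self, limit, payload):
            self._remaining = limit
            self._payload = payload

        def __iter__(self):
            return self

        def __next__(self):
            if self._remaining <= 0:
                raise Exception('injected request-side failure')
            self._remaining -= 1
            return self._payload

        next = __next__  # Python 2 compatibility, matching the era of this diff

    def call_stream_stream(target):
        channel = grpc.insecure_channel(target)
        multi_callable = channel.stream_stream('/test/StreamStream')
        response_iterator = multi_callable(FailAfter(3, b'\x07\x08'))
        try:
            for _ in response_iterator:
                pass
        except grpc.RpcError as rpc_error:
            # The client-side defect is reported to the caller as an RpcError.
            print('RPC failed with code', rpc_error.code())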
diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/_junkdrawer/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py b/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py
index eef18f82d6..70f437bc83 100644
--- a/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py
+++ b/src/python/grpcio_tests/tests/unit/_junkdrawer/stock_pb2.py
@@ -35,7 +35,7 @@
# source: stock.proto
import sys
-_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
+_b = sys.version_info[0] < 3 and (lambda x: x) or (lambda x: x.encode('latin1'))
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
@@ -45,108 +45,135 @@ from google.protobuf import descriptor_pb2
_sym_db = _symbol_database.Default()
-
-
-
DESCRIPTOR = _descriptor.FileDescriptor(
- name='stock.proto',
- package='stock',
- serialized_pb=_b('\n\x0bstock.proto\x12\x05stock\">\n\x0cStockRequest\x12\x0e\n\x06symbol\x18\x01 \x01(\t\x12\x1e\n\x13num_trades_to_watch\x18\x02 \x01(\x05:\x01\x30\"+\n\nStockReply\x12\r\n\x05price\x18\x01 \x01(\x02\x12\x0e\n\x06symbol\x18\x02 \x01(\t2\x96\x02\n\x05Stock\x12=\n\x11GetLastTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x12I\n\x19GetLastTradePriceMultiple\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01\x30\x01\x12?\n\x11WatchFutureTrades\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x30\x01\x12\x42\n\x14GetHighestTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01')
-)
+ name='stock.proto',
+ package='stock',
+ serialized_pb=_b(
+ '\n\x0bstock.proto\x12\x05stock\">\n\x0cStockRequest\x12\x0e\n\x06symbol\x18\x01 \x01(\t\x12\x1e\n\x13num_trades_to_watch\x18\x02 \x01(\x05:\x01\x30\"+\n\nStockReply\x12\r\n\x05price\x18\x01 \x01(\x02\x12\x0e\n\x06symbol\x18\x02 \x01(\t2\x96\x02\n\x05Stock\x12=\n\x11GetLastTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x12I\n\x19GetLastTradePriceMultiple\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01\x30\x01\x12?\n\x11WatchFutureTrades\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00\x30\x01\x12\x42\n\x14GetHighestTradePrice\x12\x13.stock.StockRequest\x1a\x11.stock.StockReply\"\x00(\x01'
+ ))
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-
-
_STOCKREQUEST = _descriptor.Descriptor(
- name='StockRequest',
- full_name='stock.StockRequest',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='symbol', full_name='stock.StockRequest.symbol', index=0,
- number=1, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- _descriptor.FieldDescriptor(
- name='num_trades_to_watch', full_name='stock.StockRequest.num_trades_to_watch', index=1,
- number=2, type=5, cpp_type=1, label=1,
- has_default_value=True, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- options=None,
- is_extendable=False,
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=22,
- serialized_end=84,
-)
-
+ name='StockRequest',
+ full_name='stock.StockRequest',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='symbol',
+ full_name='stock.StockRequest.symbol',
+ index=0,
+ number=1,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode('utf-8'),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None),
+ _descriptor.FieldDescriptor(
+ name='num_trades_to_watch',
+ full_name='stock.StockRequest.num_trades_to_watch',
+ index=1,
+ number=2,
+ type=5,
+ cpp_type=1,
+ label=1,
+ has_default_value=True,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ options=None,
+ is_extendable=False,
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=22,
+ serialized_end=84,)
_STOCKREPLY = _descriptor.Descriptor(
- name='StockReply',
- full_name='stock.StockReply',
- filename=None,
- file=DESCRIPTOR,
- containing_type=None,
- fields=[
- _descriptor.FieldDescriptor(
- name='price', full_name='stock.StockReply.price', index=0,
- number=1, type=2, cpp_type=6, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- _descriptor.FieldDescriptor(
- name='symbol', full_name='stock.StockReply.symbol', index=1,
- number=2, type=9, cpp_type=9, label=1,
- has_default_value=False, default_value=_b("").decode('utf-8'),
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None,
- options=None),
- ],
- extensions=[
- ],
- nested_types=[],
- enum_types=[
- ],
- options=None,
- is_extendable=False,
- extension_ranges=[],
- oneofs=[
- ],
- serialized_start=86,
- serialized_end=129,
-)
+ name='StockReply',
+ full_name='stock.StockReply',
+ filename=None,
+ file=DESCRIPTOR,
+ containing_type=None,
+ fields=[
+ _descriptor.FieldDescriptor(
+ name='price',
+ full_name='stock.StockReply.price',
+ index=0,
+ number=1,
+ type=2,
+ cpp_type=6,
+ label=1,
+ has_default_value=False,
+ default_value=0,
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None),
+ _descriptor.FieldDescriptor(
+ name='symbol',
+ full_name='stock.StockReply.symbol',
+ index=1,
+ number=2,
+ type=9,
+ cpp_type=9,
+ label=1,
+ has_default_value=False,
+ default_value=_b("").decode('utf-8'),
+ message_type=None,
+ enum_type=None,
+ containing_type=None,
+ is_extension=False,
+ extension_scope=None,
+ options=None),
+ ],
+ extensions=[],
+ nested_types=[],
+ enum_types=[],
+ options=None,
+ is_extendable=False,
+ extension_ranges=[],
+ oneofs=[],
+ serialized_start=86,
+ serialized_end=129,)
DESCRIPTOR.message_types_by_name['StockRequest'] = _STOCKREQUEST
DESCRIPTOR.message_types_by_name['StockReply'] = _STOCKREPLY
-StockRequest = _reflection.GeneratedProtocolMessageType('StockRequest', (_message.Message,), dict(
- DESCRIPTOR = _STOCKREQUEST,
- __module__ = 'stock_pb2'
- # @@protoc_insertion_point(class_scope:stock.StockRequest)
- ))
+StockRequest = _reflection.GeneratedProtocolMessageType(
+ 'StockRequest',
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_STOCKREQUEST,
+ __module__='stock_pb2'
+ # @@protoc_insertion_point(class_scope:stock.StockRequest)
+ ))
_sym_db.RegisterMessage(StockRequest)
-StockReply = _reflection.GeneratedProtocolMessageType('StockReply', (_message.Message,), dict(
- DESCRIPTOR = _STOCKREPLY,
- __module__ = 'stock_pb2'
- # @@protoc_insertion_point(class_scope:stock.StockReply)
- ))
+StockReply = _reflection.GeneratedProtocolMessageType(
+ 'StockReply',
+ (_message.Message,),
+ dict(
+ DESCRIPTOR=_STOCKREPLY,
+ __module__='stock_pb2'
+ # @@protoc_insertion_point(class_scope:stock.StockReply)
+ ))
_sym_db.RegisterMessage(StockReply)
-
# @@protoc_insertion_point(module_scope)
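Note: stock_pb2.py above is protoc-generated descriptor code that the formatter reflows; its behavior is unchanged. As a hedged sketch of how such a generated module is typically consumed, the round trip below uses the field names visible in the descriptor (symbol, num_trades_to_watch); the import path assumes the grpcio_tests package layout and the values are illustrative only.

    # Hedged sketch: round-tripping a message through the generated module.
    from tests.unit._junkdrawer import stock_pb2

    request = stock_pb2.StockRequest(symbol='GOOG', num_trades_to_watch=3)
    wire_bytes = request.SerializeToString()                   # protobuf wire format
    decoded = stock_pb2.StockRequest.FromString(wire_bytes)    # parse it back
    assert decoded.symbol == 'GOOG'
    assert decoded.num_trades_to_watch == 3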
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
index fb3e547781..af2ce64dce 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests application-provided metadata, status code, and details."""
import threading
@@ -53,20 +52,16 @@ _UNARY_STREAM = 'UnaryStream'
_STREAM_UNARY = 'StreamUnary'
_STREAM_STREAM = 'StreamStream'
-_CLIENT_METADATA = (
- ('client-md-key', 'client-md-key'),
- ('client-md-key-bin', b'\x00\x01')
-)
+_CLIENT_METADATA = (('client-md-key', 'client-md-key'),
+ ('client-md-key-bin', b'\x00\x01'))
_SERVER_INITIAL_METADATA = (
('server-initial-md-key', 'server-initial-md-value'),
- ('server-initial-md-key-bin', b'\x00\x02')
-)
+ ('server-initial-md-key-bin', b'\x00\x02'))
_SERVER_TRAILING_METADATA = (
('server-trailing-md-key', 'server-trailing-md-value'),
- ('server-trailing-md-key-bin', b'\x00\x03')
-)
+ ('server-trailing-md-key-bin', b'\x00\x03'))
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = 'Test details!'
@@ -74,450 +69,464 @@ _DETAILS = 'Test details!'
class _Servicer(object):
- def __init__(self):
- self._lock = threading.Lock()
- self._code = None
- self._details = None
- self._exception = False
- self._return_none = False
- self._received_client_metadata = None
-
- def unary_unary(self, request, context):
- with self._lock:
- self._received_client_metadata = context.invocation_metadata()
- context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
- if self._exception:
- raise test_control.Defect()
- else:
- return None if self._return_none else object()
-
- def unary_stream(self, request, context):
- with self._lock:
- self._received_client_metadata = context.invocation_metadata()
- context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
- for _ in range(test_constants.STREAM_LENGTH // 2):
- yield _SERIALIZED_RESPONSE
- if self._exception:
- raise test_control.Defect()
-
- def stream_unary(self, request_iterator, context):
- with self._lock:
- self._received_client_metadata = context.invocation_metadata()
- context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
- # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
- # request iterator.
- for ignored_request in request_iterator:
- pass
- if self._exception:
- raise test_control.Defect()
- else:
- return None if self._return_none else _SERIALIZED_RESPONSE
-
- def stream_stream(self, request_iterator, context):
- with self._lock:
- self._received_client_metadata = context.invocation_metadata()
- context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- if self._code is not None:
- context.set_code(self._code)
- if self._details is not None:
- context.set_details(self._details)
- # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
- # request iterator.
- for ignored_request in request_iterator:
- pass
- for _ in range(test_constants.STREAM_LENGTH // 3):
- yield object()
- if self._exception:
- raise test_control.Defect()
-
- def set_code(self, code):
- with self._lock:
- self._code = code
-
- def set_details(self, details):
- with self._lock:
- self._details = details
-
- def set_exception(self):
- with self._lock:
- self._exception = True
-
- def set_return_none(self):
- with self._lock:
- self._return_none = True
-
- def received_client_metadata(self):
- with self._lock:
- return self._received_client_metadata
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._code = None
+ self._details = None
+ self._exception = False
+ self._return_none = False
+ self._received_client_metadata = None
+
+ def unary_unary(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else object()
+
+ def unary_stream(self, request, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ yield _SERIALIZED_RESPONSE
+ if self._exception:
+ raise test_control.Defect()
+
+ def stream_unary(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ if self._exception:
+ raise test_control.Defect()
+ else:
+ return None if self._return_none else _SERIALIZED_RESPONSE
+
+ def stream_stream(self, request_iterator, context):
+ with self._lock:
+ self._received_client_metadata = context.invocation_metadata()
+ context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ if self._code is not None:
+ context.set_code(self._code)
+ if self._details is not None:
+ context.set_details(self._details)
+ # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
+ # request iterator.
+ for ignored_request in request_iterator:
+ pass
+ for _ in range(test_constants.STREAM_LENGTH // 3):
+ yield object()
+ if self._exception:
+ raise test_control.Defect()
+
+ def set_code(self, code):
+ with self._lock:
+ self._code = code
+
+ def set_details(self, details):
+ with self._lock:
+ self._details = details
+
+ def set_exception(self):
+ with self._lock:
+ self._exception = True
+
+ def set_return_none(self):
+ with self._lock:
+ self._return_none = True
+
+ def received_client_metadata(self):
+ with self._lock:
+ return self._received_client_metadata
def _generic_handler(servicer):
- method_handlers = {
- _UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
- servicer.unary_unary, request_deserializer=_REQUEST_DESERIALIZER,
- response_serializer=_RESPONSE_SERIALIZER),
- _UNARY_STREAM: grpc.unary_stream_rpc_method_handler(
- servicer.unary_stream),
- _STREAM_UNARY: grpc.stream_unary_rpc_method_handler(
- servicer.stream_unary),
- _STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
- servicer.stream_stream, request_deserializer=_REQUEST_DESERIALIZER,
- response_serializer=_RESPONSE_SERIALIZER),
- }
- return grpc.method_handlers_generic_handler(_SERVICE, method_handlers)
+ method_handlers = {
+ _UNARY_UNARY: grpc.unary_unary_rpc_method_handler(
+ servicer.unary_unary,
+ request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ _UNARY_STREAM:
+ grpc.unary_stream_rpc_method_handler(servicer.unary_stream),
+ _STREAM_UNARY:
+ grpc.stream_unary_rpc_method_handler(servicer.stream_unary),
+ _STREAM_STREAM: grpc.stream_stream_rpc_method_handler(
+ servicer.stream_stream,
+ request_deserializer=_REQUEST_DESERIALIZER,
+ response_serializer=_RESPONSE_SERIALIZER),
+ }
+ return grpc.method_handlers_generic_handler(_SERVICE, method_handlers)
class MetadataCodeDetailsTest(unittest.TestCase):
- def setUp(self):
- self._servicer = _Servicer()
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(
- self._server_pool, handlers=(_generic_handler(self._servicer),))
- port = self._server.add_insecure_port('[::]:0')
- self._server.start()
-
- channel = grpc.insecure_channel('localhost:{}'.format(port))
- self._unary_unary = channel.unary_unary(
- '/'.join(('', _SERVICE, _UNARY_UNARY,)),
- request_serializer=_REQUEST_SERIALIZER,
- response_deserializer=_RESPONSE_DESERIALIZER,)
- self._unary_stream = channel.unary_stream(
- '/'.join(('', _SERVICE, _UNARY_STREAM,)),)
- self._stream_unary = channel.stream_unary(
- '/'.join(('', _SERVICE, _STREAM_UNARY,)),)
- self._stream_stream = channel.stream_stream(
- '/'.join(('', _SERVICE, _STREAM_STREAM,)),
- request_serializer=_REQUEST_SERIALIZER,
- response_deserializer=_RESPONSE_DESERIALIZER,)
-
-
- def testSuccessfulUnaryUnary(self):
- self._servicer.set_details(_DETAILS)
-
- unused_response, call = self._unary_unary.with_call(
- object(), metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testSuccessfulUnaryStream(self):
- self._servicer.set_details(_DETAILS)
-
- call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testSuccessfulStreamUnary(self):
- self._servicer.set_details(_DETAILS)
-
- unused_response, call = self._stream_unary.with_call(
- iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testSuccessfulStreamStream(self):
- self._servicer.set_details(_DETAILS)
-
- call = self._stream_stream(
- iter([object()] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(grpc.StatusCode.OK, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testCustomCodeUnaryUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeUnaryStream(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
-
- call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testCustomCodeStreamUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._stream_unary.with_call(
- iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeStreamStream(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
-
- call = self._stream_stream(
- iter([object()] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- with self.assertRaises(grpc.RpcError) as exception_context:
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeExceptionUnaryUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_exception()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeExceptionUnaryStream(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_exception()
-
- call = self._unary_stream(_SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testCustomCodeExceptionStreamUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_exception()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._stream_unary.with_call(
- iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeExceptionStreamStream(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_exception()
-
- call = self._stream_stream(
- iter([object()] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
- received_initial_metadata = call.initial_metadata()
- with self.assertRaises(grpc.RpcError):
- for _ in call:
- pass
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, received_initial_metadata))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, call.code())
- self.assertEqual(_DETAILS, call.details())
-
- def testCustomCodeReturnNoneUnaryUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_return_none()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
-
- def testCustomCodeReturnNoneStreamUnary(self):
- self._servicer.set_code(_NON_OK_CODE)
- self._servicer.set_details(_DETAILS)
- self._servicer.set_return_none()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._stream_unary.with_call(
- iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
-
- self.assertTrue(
- test_common.metadata_transmitted(
- _CLIENT_METADATA, self._servicer.received_client_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA,
- exception_context.exception.initial_metadata()))
- self.assertTrue(
- test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA,
- exception_context.exception.trailing_metadata()))
- self.assertIs(_NON_OK_CODE, exception_context.exception.code())
- self.assertEqual(_DETAILS, exception_context.exception.details())
+ def setUp(self):
+ self._servicer = _Servicer()
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ self._server_pool, handlers=(_generic_handler(self._servicer),))
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+
+ channel = grpc.insecure_channel('localhost:{}'.format(port))
+ self._unary_unary = channel.unary_unary(
+ '/'.join((
+ '',
+ _SERVICE,
+ _UNARY_UNARY,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+ self._unary_stream = channel.unary_stream('/'.join((
+ '',
+ _SERVICE,
+ _UNARY_STREAM,)),)
+ self._stream_unary = channel.stream_unary('/'.join((
+ '',
+ _SERVICE,
+ _STREAM_UNARY,)),)
+ self._stream_stream = channel.stream_stream(
+ '/'.join((
+ '',
+ _SERVICE,
+ _STREAM_STREAM,)),
+ request_serializer=_REQUEST_SERIALIZER,
+ response_deserializer=_RESPONSE_DESERIALIZER,)
+
+ def testSuccessfulUnaryUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._unary_unary.with_call(
+ object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulUnaryStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(
+ _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamUnary(self):
+ self._servicer.set_details(_DETAILS)
+
+ unused_response, call = self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testSuccessfulStreamStream(self):
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(grpc.StatusCode.OK, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._unary_stream(
+ _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionUnaryStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._unary_stream(
+ _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeExceptionStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeExceptionStreamStream(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_exception()
+
+ call = self._stream_stream(
+ iter([object()] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ received_initial_metadata = call.initial_metadata()
+ with self.assertRaises(grpc.RpcError):
+ for _ in call:
+ pass
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ received_initial_metadata))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, call.code())
+ self.assertEqual(_DETAILS, call.details())
+
+ def testCustomCodeReturnNoneUnaryUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
+
+ def testCustomCodeReturnNoneStreamUnary(self):
+ self._servicer.set_code(_NON_OK_CODE)
+ self._servicer.set_details(_DETAILS)
+ self._servicer.set_return_none()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._stream_unary.with_call(
+ iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, self._servicer.received_client_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_INITIAL_METADATA,
+ exception_context.exception.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(
+ _SERVER_TRAILING_METADATA,
+ exception_context.exception.trailing_metadata()))
+ self.assertIs(_NON_OK_CODE, exception_context.exception.code())
+ self.assertEqual(_DETAILS, exception_context.exception.details())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
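Note: the _metadata_code_details_test.py hunks above drive a servicer that attaches initial/trailing metadata and overrides the status code and details, then assert that the client observes all of it. A minimal standalone sketch of that server/client round trip follows; the service name 'example.Service' and the metadata values are illustrative, not taken from the test module.

    # Hedged sketch: server sets metadata plus a non-OK code and details;
    # the client sees them on the raised RpcError.
    from concurrent import futures
    import grpc

    _INITIAL_MD = (('server-initial-md-key', 'server-initial-md-value'),)
    _TRAILING_MD = (('server-trailing-md-key', 'server-trailing-md-value'),)

    def _unary_unary(request, context):
        # Attach application metadata, then override the status code and details.
        context.send_initial_metadata(_INITIAL_MD)
        context.set_trailing_metadata(_TRAILING_MD)
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details('Test details!')
        return b''  # a payload is returned, but the client observes the non-OK status

    def main():
        handler = grpc.method_handlers_generic_handler(
            'example.Service',
            {'UnaryUnary': grpc.unary_unary_rpc_method_handler(_unary_unary)})
        server = grpc.server(
            futures.ThreadPoolExecutor(max_workers=4), handlers=(handler,))
        port = server.add_insecure_port('[::]:0')
        server.start()

        channel = grpc.insecure_channel('localhost:%d' % port)
        multi_callable = channel.unary_unary('/example.Service/UnaryUnary')
        try:
            multi_callable.with_call(
                b'', metadata=(('client-md-key', 'client-md-value'),))
        except grpc.RpcError as error:
            # The raised error carries the code, details, and metadata set above.
            print(error.code(), error.details())
            print(error.initial_metadata(), error.trailing_metadata())
        server.stop(0)

    if __name__ == '__main__':
        main()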
diff --git a/src/python/grpcio_tests/tests/unit/_metadata_test.py b/src/python/grpcio_tests/tests/unit/_metadata_test.py
index caba53ffcc..53fe7ba8aa 100644
--- a/src/python/grpcio_tests/tests/unit/_metadata_test.py
+++ b/src/python/grpcio_tests/tests/unit/_metadata_test.py
@@ -51,166 +51,174 @@ _STREAM_STREAM = '/test/StreamStream'
_USER_AGENT = 'Python-gRPC-{}'.format(_grpcio_metadata.__version__)
-_CLIENT_METADATA = (
- ('client-md-key', 'client-md-key'),
- ('client-md-key-bin', b'\x00\x01')
-)
+_CLIENT_METADATA = (('client-md-key', 'client-md-key'),
+ ('client-md-key-bin', b'\x00\x01'))
_SERVER_INITIAL_METADATA = (
('server-initial-md-key', 'server-initial-md-value'),
- ('server-initial-md-key-bin', b'\x00\x02')
-)
+ ('server-initial-md-key-bin', b'\x00\x02'))
_SERVER_TRAILING_METADATA = (
('server-trailing-md-key', 'server-trailing-md-value'),
- ('server-trailing-md-key-bin', b'\x00\x03')
-)
+ ('server-trailing-md-key-bin', b'\x00\x03'))
def user_agent(metadata):
- for key, val in metadata:
- if key == 'user-agent':
- return val
- raise KeyError('No user agent!')
+ for key, val in metadata:
+ if key == 'user-agent':
+ return val
+ raise KeyError('No user agent!')
def validate_client_metadata(test, servicer_context):
- test.assertTrue(test_common.metadata_transmitted(
- _CLIENT_METADATA, servicer_context.invocation_metadata()))
- test.assertTrue(user_agent(servicer_context.invocation_metadata())
- .startswith('primary-agent ' + _USER_AGENT))
- test.assertTrue(user_agent(servicer_context.invocation_metadata())
- .endswith('secondary-agent'))
+ test.assertTrue(
+ test_common.metadata_transmitted(
+ _CLIENT_METADATA, servicer_context.invocation_metadata()))
+ test.assertTrue(
+ user_agent(servicer_context.invocation_metadata())
+ .startswith('primary-agent ' + _USER_AGENT))
+ test.assertTrue(
+ user_agent(servicer_context.invocation_metadata())
+ .endswith('secondary-agent'))
def handle_unary_unary(test, request, servicer_context):
- validate_client_metadata(test, servicer_context)
- servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- return _RESPONSE
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ return _RESPONSE
def handle_unary_stream(test, request, servicer_context):
- validate_client_metadata(test, servicer_context)
- servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- for _ in range(test_constants.STREAM_LENGTH):
- yield _RESPONSE
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ for _ in range(test_constants.STREAM_LENGTH):
+ yield _RESPONSE
def handle_stream_unary(test, request_iterator, servicer_context):
- validate_client_metadata(test, servicer_context)
- servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- # TODO(issue:#6891) We should be able to remove this loop
- for request in request_iterator:
- pass
- return _RESPONSE
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop
+ for request in request_iterator:
+ pass
+ return _RESPONSE
def handle_stream_stream(test, request_iterator, servicer_context):
- validate_client_metadata(test, servicer_context)
- servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
- servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
- # TODO(issue:#6891) We should be able to remove this loop,
- # and replace with return; yield
- for request in request_iterator:
- yield _RESPONSE
+ validate_client_metadata(test, servicer_context)
+ servicer_context.send_initial_metadata(_SERVER_INITIAL_METADATA)
+ servicer_context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
+ # TODO(issue:#6891) We should be able to remove this loop,
+ # and replace with return; yield
+ for request in request_iterator:
+ yield _RESPONSE
class _MethodHandler(grpc.RpcMethodHandler):
- def __init__(self, test, request_streaming, response_streaming):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = None
- self.response_serializer = None
- self.unary_unary = None
- self.unary_stream = None
- self.stream_unary = None
- self.stream_stream = None
- if self.request_streaming and self.response_streaming:
- self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
- elif self.request_streaming:
- self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
- elif self.response_streaming:
- self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
- else:
- self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)
+ def __init__(self, test, request_streaming, response_streaming):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = None
+ self.response_serializer = None
+ self.unary_unary = None
+ self.unary_stream = None
+ self.stream_unary = None
+ self.stream_stream = None
+ if self.request_streaming and self.response_streaming:
+ self.stream_stream = lambda x, y: handle_stream_stream(test, x, y)
+ elif self.request_streaming:
+ self.stream_unary = lambda x, y: handle_stream_unary(test, x, y)
+ elif self.response_streaming:
+ self.unary_stream = lambda x, y: handle_unary_stream(test, x, y)
+ else:
+ self.unary_unary = lambda x, y: handle_unary_unary(test, x, y)
class _GenericHandler(grpc.GenericRpcHandler):
- def __init__(self, test):
- self._test = test
+ def __init__(self, test):
+ self._test = test
- def service(self, handler_call_details):
- if handler_call_details.method == _UNARY_UNARY:
- return _MethodHandler(self._test, False, False)
- elif handler_call_details.method == _UNARY_STREAM:
- return _MethodHandler(self._test, False, True)
- elif handler_call_details.method == _STREAM_UNARY:
- return _MethodHandler(self._test, True, False)
- elif handler_call_details.method == _STREAM_STREAM:
- return _MethodHandler(self._test, True, True)
- else:
- return None
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(self._test, False, False)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(self._test, False, True)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(self._test, True, False)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(self._test, True, True)
+ else:
+ return None
class MetadataTest(unittest.TestCase):
- def setUp(self):
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(
- self._server_pool, handlers=(_GenericHandler(weakref.proxy(self)),))
- port = self._server.add_insecure_port('[::]:0')
- self._server.start()
- self._channel = grpc.insecure_channel('localhost:%d' % port,
- options=_CHANNEL_ARGS)
-
- def tearDown(self):
- self._server.stop(0)
-
- def testUnaryUnary(self):
- multi_callable = self._channel.unary_unary(_UNARY_UNARY)
- unused_response, call = multi_callable.with_call(
- _REQUEST, metadata=_CLIENT_METADATA)
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
-
- def testUnaryStream(self):
- multi_callable = self._channel.unary_stream(_UNARY_STREAM)
- call = multi_callable(_REQUEST, metadata=_CLIENT_METADATA)
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- for _ in call:
- pass
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
-
- def testStreamUnary(self):
- multi_callable = self._channel.stream_unary(_STREAM_UNARY)
- unused_response, call = multi_callable.with_call(
- iter([_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
-
- def testStreamStream(self):
- multi_callable = self._channel.stream_stream(_STREAM_STREAM)
- call = multi_callable(iter([_REQUEST] * test_constants.STREAM_LENGTH),
- metadata=_CLIENT_METADATA)
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_INITIAL_METADATA, call.initial_metadata()))
- for _ in call:
- pass
- self.assertTrue(test_common.metadata_transmitted(
- _SERVER_TRAILING_METADATA, call.trailing_metadata()))
+ def setUp(self):
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ self._server = grpc.server(
+ self._server_pool, handlers=(_GenericHandler(weakref.proxy(self)),))
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.start()
+ self._channel = grpc.insecure_channel(
+ 'localhost:%d' % port, options=_CHANNEL_ARGS)
+
+ def tearDown(self):
+ self._server.stop(0)
+
+ def testUnaryUnary(self):
+ multi_callable = self._channel.unary_unary(_UNARY_UNARY)
+ unused_response, call = multi_callable.with_call(
+ _REQUEST, metadata=_CLIENT_METADATA)
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+
+ def testUnaryStream(self):
+ multi_callable = self._channel.unary_stream(_UNARY_STREAM)
+ call = multi_callable(_REQUEST, metadata=_CLIENT_METADATA)
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+
+ def testStreamUnary(self):
+ multi_callable = self._channel.stream_unary(_STREAM_UNARY)
+ unused_response, call = multi_callable.with_call(
+ iter([_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
+
+ def testStreamStream(self):
+ multi_callable = self._channel.stream_stream(_STREAM_STREAM)
+ call = multi_callable(
+ iter([_REQUEST] * test_constants.STREAM_LENGTH),
+ metadata=_CLIENT_METADATA)
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_INITIAL_METADATA,
+ call.initial_metadata()))
+ for _ in call:
+ pass
+ self.assertTrue(
+ test_common.metadata_transmitted(_SERVER_TRAILING_METADATA,
+ call.trailing_metadata()))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
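Note: validate_client_metadata() in the _metadata_test.py hunks above checks that the transmitted user-agent starts with 'primary-agent ' + the Python gRPC agent string and ends with 'secondary-agent'. The sketch below shows the channel arguments that produce such a composite user-agent; the option keys are standard gRPC channel args, the target address is illustrative, and no RPC is issued (the test's own _CHANNEL_ARGS constant is defined earlier in that file).

    # Hedged sketch: channel args that stamp the primary/secondary user-agent.
    import grpc

    _CHANNEL_ARGS = (('grpc.primary_user_agent', 'primary-agent'),
                     ('grpc.secondary_user_agent', 'secondary-agent'),)

    channel = grpc.insecure_channel('localhost:50051', options=_CHANNEL_ARGS)
    # On the server, context.invocation_metadata() then carries a 'user-agent'
    # entry that starts with 'primary-agent Python-gRPC-<version>' and ends
    # with 'secondary-agent', which is exactly what the test asserts.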
diff --git a/src/python/grpcio_tests/tests/unit/_rpc_test.py b/src/python/grpcio_tests/tests/unit/_rpc_test.py
index eb00156da5..2cf6dfea62 100644
--- a/src/python/grpcio_tests/tests/unit/_rpc_test.py
+++ b/src/python/grpcio_tests/tests/unit/_rpc_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test of RPCs made against gRPC Python's application-layer API."""
import itertools
@@ -53,742 +52,797 @@ _STREAM_STREAM = '/test/StreamStream'
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._value = None
- self._called = False
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
+ self._called = False
- def __call__(self, value):
- with self._condition:
- self._value = value
- self._called = True
- self._condition.notify_all()
+ def __call__(self, value):
+ with self._condition:
+ self._value = value
+ self._called = True
+ self._condition.notify_all()
- def value(self):
- with self._condition:
- while not self._called:
- self._condition.wait()
- return self._value
+ def value(self):
+ with self._condition:
+ while not self._called:
+ self._condition.wait()
+ return self._value
class _Handler(object):
- def __init__(self, control):
- self._control = control
-
- def handle_unary_unary(self, request, servicer_context):
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- return request
-
- def handle_unary_stream(self, request, servicer_context):
- for _ in range(test_constants.STREAM_LENGTH):
- self._control.control()
- yield request
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
-
- def handle_stream_unary(self, request_iterator, servicer_context):
- if servicer_context is not None:
- servicer_context.invocation_metadata()
- self._control.control()
- response_elements = []
- for request in request_iterator:
- self._control.control()
- response_elements.append(request)
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- return b''.join(response_elements)
-
- def handle_stream_stream(self, request_iterator, servicer_context):
- self._control.control()
- if servicer_context is not None:
- servicer_context.set_trailing_metadata((('testkey', 'testvalue',),))
- for request in request_iterator:
- self._control.control()
- yield request
- self._control.control()
+ def __init__(self, control):
+ self._control = control
+
+ def handle_unary_unary(self, request, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ return request
+
+ def handle_unary_stream(self, request, servicer_context):
+ for _ in range(test_constants.STREAM_LENGTH):
+ self._control.control()
+ yield request
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+
+ def handle_stream_unary(self, request_iterator, servicer_context):
+ if servicer_context is not None:
+ servicer_context.invocation_metadata()
+ self._control.control()
+ response_elements = []
+ for request in request_iterator:
+ self._control.control()
+ response_elements.append(request)
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ return b''.join(response_elements)
+
+ def handle_stream_stream(self, request_iterator, servicer_context):
+ self._control.control()
+ if servicer_context is not None:
+ servicer_context.set_trailing_metadata(((
+ 'testkey',
+ 'testvalue',),))
+ for request in request_iterator:
+ self._control.control()
+ yield request
+ self._control.control()
class _MethodHandler(grpc.RpcMethodHandler):
- def __init__(
- self, request_streaming, response_streaming, request_deserializer,
- response_serializer, unary_unary, unary_stream, stream_unary,
- stream_stream):
- self.request_streaming = request_streaming
- self.response_streaming = response_streaming
- self.request_deserializer = request_deserializer
- self.response_serializer = response_serializer
- self.unary_unary = unary_unary
- self.unary_stream = unary_stream
- self.stream_unary = stream_unary
- self.stream_stream = stream_stream
+ def __init__(self, request_streaming, response_streaming,
+ request_deserializer, response_serializer, unary_unary,
+ unary_stream, stream_unary, stream_stream):
+ self.request_streaming = request_streaming
+ self.response_streaming = response_streaming
+ self.request_deserializer = request_deserializer
+ self.response_serializer = response_serializer
+ self.unary_unary = unary_unary
+ self.unary_stream = unary_stream
+ self.stream_unary = stream_unary
+ self.stream_stream = stream_stream
class _GenericHandler(grpc.GenericRpcHandler):
- def __init__(self, handler):
- self._handler = handler
-
- def service(self, handler_call_details):
- if handler_call_details.method == _UNARY_UNARY:
- return _MethodHandler(
- False, False, None, None, self._handler.handle_unary_unary, None,
- None, None)
- elif handler_call_details.method == _UNARY_STREAM:
- return _MethodHandler(
- False, True, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None,
- self._handler.handle_unary_stream, None, None)
- elif handler_call_details.method == _STREAM_UNARY:
- return _MethodHandler(
- True, False, _DESERIALIZE_REQUEST, _SERIALIZE_RESPONSE, None, None,
- self._handler.handle_stream_unary, None)
- elif handler_call_details.method == _STREAM_STREAM:
- return _MethodHandler(
- True, True, None, None, None, None, None,
- self._handler.handle_stream_stream)
- else:
- return None
+ def __init__(self, handler):
+ self._handler = handler
+
+ def service(self, handler_call_details):
+ if handler_call_details.method == _UNARY_UNARY:
+ return _MethodHandler(False, False, None, None,
+ self._handler.handle_unary_unary, None, None,
+ None)
+ elif handler_call_details.method == _UNARY_STREAM:
+ return _MethodHandler(False, True, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None,
+ self._handler.handle_unary_stream, None, None)
+ elif handler_call_details.method == _STREAM_UNARY:
+ return _MethodHandler(True, False, _DESERIALIZE_REQUEST,
+ _SERIALIZE_RESPONSE, None, None,
+ self._handler.handle_stream_unary, None)
+ elif handler_call_details.method == _STREAM_STREAM:
+ return _MethodHandler(True, True, None, None, None, None, None,
+ self._handler.handle_stream_stream)
+ else:
+ return None
def _unary_unary_multi_callable(channel):
- return channel.unary_unary(_UNARY_UNARY)
+ return channel.unary_unary(_UNARY_UNARY)
def _unary_stream_multi_callable(channel):
- return channel.unary_stream(
- _UNARY_STREAM,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.unary_stream(
+ _UNARY_STREAM,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_unary_multi_callable(channel):
- return channel.stream_unary(
- _STREAM_UNARY,
- request_serializer=_SERIALIZE_REQUEST,
- response_deserializer=_DESERIALIZE_RESPONSE)
+ return channel.stream_unary(
+ _STREAM_UNARY,
+ request_serializer=_SERIALIZE_REQUEST,
+ response_deserializer=_DESERIALIZE_RESPONSE)
def _stream_stream_multi_callable(channel):
- return channel.stream_stream(_STREAM_STREAM)
+ return channel.stream_stream(_STREAM_STREAM)
class RPCTest(unittest.TestCase):
- def setUp(self):
- self._control = test_control.PauseFailControl()
- self._handler = _Handler(self._control)
- self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ def setUp(self):
+ self._control = test_control.PauseFailControl()
+ self._handler = _Handler(self._control)
+ self._server_pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- self._server = grpc.server(self._server_pool)
- port = self._server.add_insecure_port('[::]:0')
- self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
- self._server.start()
+ self._server = grpc.server(self._server_pool)
+ port = self._server.add_insecure_port('[::]:0')
+ self._server.add_generic_rpc_handlers((_GenericHandler(self._handler),))
+ self._server.start()
- self._channel = grpc.insecure_channel('localhost:%d' % port)
-
- def tearDown(self):
- self._server.stop(None)
- self._server_pool.shutdown(wait=True)
-
- def testUnrecognizedMethod(self):
- request = b'abc'
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- self._channel.unary_unary('NoSuchMethod')(request)
-
- self.assertEqual(
- grpc.StatusCode.UNIMPLEMENTED, exception_context.exception.code())
-
- def testSuccessfulUnaryRequestBlockingUnaryResponse(self):
- request = b'\x07\x08'
- expected_response = self._handler.handle_unary_unary(request, None)
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- response = multi_callable(
- request, metadata=(
- ('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),))
-
- self.assertEqual(expected_response, response)
-
- def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self):
- request = b'\x07\x08'
- expected_response = self._handler.handle_unary_unary(request, None)
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- response, call = multi_callable.with_call(
- request, metadata=(
- ('test', 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),))
-
- self.assertEqual(expected_response, response)
- self.assertIs(grpc.StatusCode.OK, call.code())
-
- def testSuccessfulUnaryRequestFutureUnaryResponse(self):
- request = b'\x07\x08'
- expected_response = self._handler.handle_unary_unary(request, None)
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- response_future = multi_callable.future(
- request, metadata=(
- ('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),))
- response = response_future.result()
-
- self.assertIsInstance(response_future, grpc.Future)
- self.assertIsInstance(response_future, grpc.Call)
- self.assertEqual(expected_response, response)
- self.assertIsNone(response_future.exception())
- self.assertIsNone(response_future.traceback())
-
- def testSuccessfulUnaryRequestStreamResponse(self):
- request = b'\x37\x58'
- expected_responses = tuple(self._handler.handle_unary_stream(request, None))
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request,
- metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),))
- responses = tuple(response_iterator)
-
- self.assertSequenceEqual(expected_responses, responses)
-
- def testSuccessfulStreamRequestBlockingUnaryResponse(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- expected_response = self._handler.handle_stream_unary(iter(requests), None)
- request_iterator = iter(requests)
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- response = multi_callable(
- request_iterator,
- metadata=(('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),))
-
- self.assertEqual(expected_response, response)
-
- def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- expected_response = self._handler.handle_stream_unary(iter(requests), None)
- request_iterator = iter(requests)
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- response, call = multi_callable.with_call(
- request_iterator,
- metadata=(
- ('test', 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),
- ))
-
- self.assertEqual(expected_response, response)
- self.assertIs(grpc.StatusCode.OK, call.code())
-
- def testSuccessfulStreamRequestFutureUnaryResponse(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- expected_response = self._handler.handle_stream_unary(iter(requests), None)
- request_iterator = iter(requests)
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- response_future = multi_callable.future(
- request_iterator,
- metadata=(
- ('test', 'SuccessfulStreamRequestFutureUnaryResponse'),))
- response = response_future.result()
-
- self.assertEqual(expected_response, response)
- self.assertIsNone(response_future.exception())
- self.assertIsNone(response_future.traceback())
-
- def testSuccessfulStreamRequestStreamResponse(self):
- requests = tuple(b'\x77\x58' for _ in range(test_constants.STREAM_LENGTH))
- expected_responses = tuple(
- self._handler.handle_stream_stream(iter(requests), None))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request_iterator,
- metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),))
- responses = tuple(response_iterator)
-
- self.assertSequenceEqual(expected_responses, responses)
-
- def testSequentialInvocations(self):
- first_request = b'\x07\x08'
- second_request = b'\x0809'
- expected_first_response = self._handler.handle_unary_unary(
- first_request, None)
- expected_second_response = self._handler.handle_unary_unary(
- second_request, None)
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- first_response = multi_callable(
- first_request, metadata=(('test', 'SequentialInvocations'),))
- second_response = multi_callable(
- second_request, metadata=(('test', 'SequentialInvocations'),))
-
- self.assertEqual(expected_first_response, first_response)
- self.assertEqual(expected_second_response, second_response)
-
- def testConcurrentBlockingInvocations(self):
- pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- expected_response = self._handler.handle_stream_unary(iter(requests), None)
- expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
- response_futures = [None] * test_constants.THREAD_CONCURRENCY
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- for index in range(test_constants.THREAD_CONCURRENCY):
- request_iterator = iter(requests)
- response_future = pool.submit(
- multi_callable, request_iterator,
- metadata=(('test', 'ConcurrentBlockingInvocations'),))
- response_futures[index] = response_future
- responses = tuple(
- response_future.result() for response_future in response_futures)
-
- pool.shutdown(wait=True)
- self.assertSequenceEqual(expected_responses, responses)
-
- def testConcurrentFutureInvocations(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- expected_response = self._handler.handle_stream_unary(iter(requests), None)
- expected_responses = [expected_response] * test_constants.THREAD_CONCURRENCY
- response_futures = [None] * test_constants.THREAD_CONCURRENCY
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- for index in range(test_constants.THREAD_CONCURRENCY):
- request_iterator = iter(requests)
- response_future = multi_callable.future(
- request_iterator,
- metadata=(('test', 'ConcurrentFutureInvocations'),))
- response_futures[index] = response_future
- responses = tuple(
- response_future.result() for response_future in response_futures)
-
- self.assertSequenceEqual(expected_responses, responses)
-
- def testWaitingForSomeButNotAllConcurrentFutureInvocations(self):
- pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- request = b'\x67\x68'
- expected_response = self._handler.handle_unary_unary(request, None)
- response_futures = [None] * test_constants.THREAD_CONCURRENCY
- lock = threading.Lock()
- test_is_running_cell = [True]
- def wrap_future(future):
- def wrap():
- try:
- return future.result()
- except grpc.RpcError:
- with lock:
- if test_is_running_cell[0]:
- raise
- return None
- return wrap
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- for index in range(test_constants.THREAD_CONCURRENCY):
- inner_response_future = multi_callable.future(
- request,
- metadata=(
- ('test',
- 'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
- outer_response_future = pool.submit(wrap_future(inner_response_future))
- response_futures[index] = outer_response_future
-
- some_completed_response_futures_iterator = itertools.islice(
- futures.as_completed(response_futures),
- test_constants.THREAD_CONCURRENCY // 2)
- for response_future in some_completed_response_futures_iterator:
- self.assertEqual(expected_response, response_future.result())
- with lock:
- test_is_running_cell[0] = False
-
- def testConsumingOneStreamResponseUnaryRequest(self):
- request = b'\x57\x38'
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request,
- metadata=(
- ('test', 'ConsumingOneStreamResponseUnaryRequest'),))
- next(response_iterator)
-
- def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
- request = b'\x57\x38'
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request,
- metadata=(
- ('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
- for _ in range(test_constants.STREAM_LENGTH // 2):
- next(response_iterator)
-
- def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
- requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request_iterator,
- metadata=(
- ('test', 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
- for _ in range(test_constants.STREAM_LENGTH // 2):
- next(response_iterator)
-
- def testConsumingTooManyStreamResponsesStreamRequest(self):
- requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- response_iterator = multi_callable(
- request_iterator,
- metadata=(
- ('test', 'ConsumingTooManyStreamResponsesStreamRequest'),))
- for _ in range(test_constants.STREAM_LENGTH):
- next(response_iterator)
- for _ in range(test_constants.STREAM_LENGTH):
- with self.assertRaises(StopIteration):
- next(response_iterator)
+ self._channel = grpc.insecure_channel('localhost:%d' % port)
- self.assertIsNotNone(response_iterator.initial_metadata())
- self.assertIs(grpc.StatusCode.OK, response_iterator.code())
- self.assertIsNotNone(response_iterator.details())
- self.assertIsNotNone(response_iterator.trailing_metadata())
-
- def testCancelledUnaryRequestUnaryResponse(self):
- request = b'\x07\x17'
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- with self._control.pause():
- response_future = multi_callable.future(
- request,
- metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),))
- response_future.cancel()
-
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.result()
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.exception()
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.traceback()
- self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
-
- def testCancelledUnaryRequestStreamResponse(self):
- request = b'\x07\x19'
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- with self._control.pause():
- response_iterator = multi_callable(
- request,
- metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
- self._control.block_until_paused()
- response_iterator.cancel()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- next(response_iterator)
- self.assertIs(grpc.StatusCode.CANCELLED, exception_context.exception.code())
- self.assertIsNotNone(response_iterator.initial_metadata())
- self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
- self.assertIsNotNone(response_iterator.details())
- self.assertIsNotNone(response_iterator.trailing_metadata())
-
- def testCancelledStreamRequestUnaryResponse(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self._control.pause():
- response_future = multi_callable.future(
- request_iterator,
- metadata=(('test', 'CancelledStreamRequestUnaryResponse'),))
- self._control.block_until_paused()
- response_future.cancel()
-
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.result()
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.exception()
- with self.assertRaises(grpc.FutureCancelledError):
- response_future.traceback()
- self.assertIsNotNone(response_future.initial_metadata())
- self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
- self.assertIsNotNone(response_future.details())
- self.assertIsNotNone(response_future.trailing_metadata())
-
- def testCancelledStreamRequestStreamResponse(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- with self._control.pause():
- response_iterator = multi_callable(
- request_iterator,
- metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
- response_iterator.cancel()
-
- with self.assertRaises(grpc.RpcError):
- next(response_iterator)
- self.assertIsNotNone(response_iterator.initial_metadata())
- self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
- self.assertIsNotNone(response_iterator.details())
- self.assertIsNotNone(response_iterator.trailing_metadata())
-
- def testExpiredUnaryRequestBlockingUnaryResponse(self):
- request = b'\x07\x17'
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- with self._control.pause():
- with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable.with_call(
- request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),))
-
- self.assertIsInstance(exception_context.exception, grpc.Call)
- self.assertIsNotNone(exception_context.exception.initial_metadata())
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIsNotNone(exception_context.exception.details())
- self.assertIsNotNone(exception_context.exception.trailing_metadata())
-
- def testExpiredUnaryRequestFutureUnaryResponse(self):
- request = b'\x07\x17'
- callback = _Callback()
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- with self._control.pause():
- response_future = multi_callable.future(
- request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),))
- response_future.add_done_callback(callback)
- value_passed_to_callback = callback.value()
-
- self.assertIs(response_future, value_passed_to_callback)
- self.assertIsNotNone(response_future.initial_metadata())
- self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
- self.assertIsNotNone(response_future.details())
- self.assertIsNotNone(response_future.trailing_metadata())
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIsInstance(response_future.exception(), grpc.RpcError)
- self.assertIsNotNone(response_future.traceback())
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, response_future.exception().code())
-
- def testExpiredUnaryRequestStreamResponse(self):
- request = b'\x07\x19'
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- with self._control.pause():
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_iterator = multi_callable(
- request, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
- next(response_iterator)
+ def tearDown(self):
+ self._server.stop(None)
+ self._server_pool.shutdown(wait=True)
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
+ def testUnrecognizedMethod(self):
+ request = b'abc'
- def testExpiredStreamRequestBlockingUnaryResponse(self):
- requests = tuple(b'\x07\x08' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ self._channel.unary_unary('NoSuchMethod')(request)
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self._control.pause():
- with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable(
- request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredStreamRequestBlockingUnaryResponse'),))
-
- self.assertIsInstance(exception_context.exception, grpc.RpcError)
- self.assertIsInstance(exception_context.exception, grpc.Call)
- self.assertIsNotNone(exception_context.exception.initial_metadata())
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIsNotNone(exception_context.exception.details())
- self.assertIsNotNone(exception_context.exception.trailing_metadata())
-
- def testExpiredStreamRequestFutureUnaryResponse(self):
- requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
- callback = _Callback()
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self._control.pause():
- response_future = multi_callable.future(
- request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),))
- with self.assertRaises(grpc.FutureTimeoutError):
- response_future.result(timeout=test_constants.SHORT_TIMEOUT / 2.0)
- response_future.add_done_callback(callback)
- value_passed_to_callback = callback.value()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIsInstance(response_future.exception(), grpc.RpcError)
- self.assertIsNotNone(response_future.traceback())
- self.assertIs(response_future, value_passed_to_callback)
- self.assertIsNotNone(response_future.initial_metadata())
- self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
- self.assertIsNotNone(response_future.details())
- self.assertIsNotNone(response_future.trailing_metadata())
-
- def testExpiredStreamRequestStreamResponse(self):
- requests = tuple(b'\x67\x18' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- with self._control.pause():
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_iterator = multi_callable(
- request_iterator, timeout=test_constants.SHORT_TIMEOUT,
- metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
- next(response_iterator)
+ self.assertEqual(grpc.StatusCode.UNIMPLEMENTED,
+ exception_context.exception.code())
+
+ def testSuccessfulUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response = multi_callable(
+ request,
+ metadata=(('test', 'SuccessfulUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertEqual(expected_response, response)
- self.assertIs(
- grpc.StatusCode.DEADLINE_EXCEEDED, exception_context.exception.code())
- self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_iterator.code())
+ def testSuccessfulUnaryRequestBlockingUnaryResponseWithCall(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
- def testFailedUnaryRequestBlockingUnaryResponse(self):
- request = b'\x37\x17'
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response, call = multi_callable.with_call(
+ request,
+ metadata=(('test',
+ 'SuccessfulUnaryRequestBlockingUnaryResponseWithCall'),))
+
+ self.assertEqual(expected_response, response)
+ self.assertIs(grpc.StatusCode.OK, call.code())
+
+ def testSuccessfulUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x08'
+ expected_response = self._handler.handle_unary_unary(request, None)
- multi_callable = _unary_unary_multi_callable(self._channel)
- with self._control.fail():
- with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable.with_call(
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ response_future = multi_callable.future(
request,
- metadata=(('test', 'FailedUnaryRequestBlockingUnaryResponse'),))
-
- self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
-
- def testFailedUnaryRequestFutureUnaryResponse(self):
- request = b'\x37\x17'
- callback = _Callback()
-
- multi_callable = _unary_unary_multi_callable(self._channel)
- with self._control.fail():
- response_future = multi_callable.future(
- request,
- metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),))
- response_future.add_done_callback(callback)
- value_passed_to_callback = callback.value()
-
- self.assertIsInstance(response_future, grpc.Future)
- self.assertIsInstance(response_future, grpc.Call)
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIs(
- grpc.StatusCode.UNKNOWN, exception_context.exception.code())
- self.assertIsInstance(response_future.exception(), grpc.RpcError)
- self.assertIsNotNone(response_future.traceback())
- self.assertIs(grpc.StatusCode.UNKNOWN, response_future.exception().code())
- self.assertIs(response_future, value_passed_to_callback)
-
- def testFailedUnaryRequestStreamResponse(self):
- request = b'\x37\x17'
-
- multi_callable = _unary_stream_multi_callable(self._channel)
- with self.assertRaises(grpc.RpcError) as exception_context:
- with self._control.fail():
+ metadata=(('test', 'SuccessfulUnaryRequestFutureUnaryResponse'),))
+ response = response_future.result()
+
+ self.assertIsInstance(response_future, grpc.Future)
+ self.assertIsInstance(response_future, grpc.Call)
+ self.assertEqual(expected_response, response)
+ self.assertIsNone(response_future.exception())
+ self.assertIsNone(response_future.traceback())
+
+ def testSuccessfulUnaryRequestStreamResponse(self):
+ request = b'\x37\x58'
+ expected_responses = tuple(
+ self._handler.handle_unary_stream(request, None))
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request,
- metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
- next(response_iterator)
+ metadata=(('test', 'SuccessfulUnaryRequestStreamResponse'),))
+ responses = tuple(response_iterator)
- self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
+ self.assertSequenceEqual(expected_responses, responses)
- def testFailedStreamRequestBlockingUnaryResponse(self):
- requests = tuple(b'\x47\x58' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
+ def testSuccessfulStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(
+ iter(requests), None)
+ request_iterator = iter(requests)
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self._control.fail():
- with self.assertRaises(grpc.RpcError) as exception_context:
- multi_callable(
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response = multi_callable(
+ request_iterator,
+ metadata=(
+ ('test', 'SuccessfulStreamRequestBlockingUnaryResponse'),))
+
+ self.assertEqual(expected_response, response)
+
+ def testSuccessfulStreamRequestBlockingUnaryResponseWithCall(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(
+ iter(requests), None)
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response, call = multi_callable.with_call(
+ request_iterator,
+ metadata=(
+ ('test',
+ 'SuccessfulStreamRequestBlockingUnaryResponseWithCall'),))
+
+ self.assertEqual(expected_response, response)
+ self.assertIs(grpc.StatusCode.OK, call.code())
+
+ def testSuccessfulStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(
+ iter(requests), None)
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ response_future = multi_callable.future(
request_iterator,
- metadata=(('test', 'FailedStreamRequestBlockingUnaryResponse'),))
-
- self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
-
- def testFailedStreamRequestFutureUnaryResponse(self):
- requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
- callback = _Callback()
-
- multi_callable = _stream_unary_multi_callable(self._channel)
- with self._control.fail():
- response_future = multi_callable.future(
- request_iterator,
- metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),))
- response_future.add_done_callback(callback)
- value_passed_to_callback = callback.value()
-
- with self.assertRaises(grpc.RpcError) as exception_context:
- response_future.result()
- self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code())
- self.assertIs(
- grpc.StatusCode.UNKNOWN, exception_context.exception.code())
- self.assertIsInstance(response_future.exception(), grpc.RpcError)
- self.assertIsNotNone(response_future.traceback())
- self.assertIs(response_future, value_passed_to_callback)
-
- def testFailedStreamRequestStreamResponse(self):
- requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
-
- multi_callable = _stream_stream_multi_callable(self._channel)
- with self._control.fail():
- with self.assertRaises(grpc.RpcError) as exception_context:
+ metadata=(('test', 'SuccessfulStreamRequestFutureUnaryResponse'),))
+ response = response_future.result()
+
+ self.assertEqual(expected_response, response)
+ self.assertIsNone(response_future.exception())
+ self.assertIsNone(response_future.traceback())
+
+ def testSuccessfulStreamRequestStreamResponse(self):
+ requests = tuple(b'\x77\x58'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_responses = tuple(
+ self._handler.handle_stream_stream(iter(requests), None))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
response_iterator = multi_callable(
request_iterator,
- metadata=(('test', 'FailedStreamRequestStreamResponse'),))
- tuple(response_iterator)
+ metadata=(('test', 'SuccessfulStreamRequestStreamResponse'),))
+ responses = tuple(response_iterator)
+
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testSequentialInvocations(self):
+ first_request = b'\x07\x08'
+ second_request = b'\x0809'
+ expected_first_response = self._handler.handle_unary_unary(
+ first_request, None)
+ expected_second_response = self._handler.handle_unary_unary(
+ second_request, None)
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ first_response = multi_callable(
+ first_request, metadata=(('test', 'SequentialInvocations'),))
+ second_response = multi_callable(
+ second_request, metadata=(('test', 'SequentialInvocations'),))
+
+ self.assertEqual(expected_first_response, first_response)
+ self.assertEqual(expected_second_response, second_response)
+
+ def testConcurrentBlockingInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(
+ iter(requests), None)
+ expected_responses = [expected_response
+ ] * test_constants.THREAD_CONCURRENCY
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request_iterator = iter(requests)
+ response_future = pool.submit(
+ multi_callable,
+ request_iterator,
+ metadata=(('test', 'ConcurrentBlockingInvocations'),))
+ response_futures[index] = response_future
+ responses = tuple(response_future.result()
+ for response_future in response_futures)
+
+ pool.shutdown(wait=True)
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testConcurrentFutureInvocations(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ expected_response = self._handler.handle_stream_unary(
+ iter(requests), None)
+ expected_responses = [expected_response
+ ] * test_constants.THREAD_CONCURRENCY
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request_iterator = iter(requests)
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=(('test', 'ConcurrentFutureInvocations'),))
+ response_futures[index] = response_future
+ responses = tuple(response_future.result()
+ for response_future in response_futures)
+
+ self.assertSequenceEqual(expected_responses, responses)
+
+ def testWaitingForSomeButNotAllConcurrentFutureInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ request = b'\x67\x68'
+ expected_response = self._handler.handle_unary_unary(request, None)
+ response_futures = [None] * test_constants.THREAD_CONCURRENCY
+ lock = threading.Lock()
+ test_is_running_cell = [True]
+
+ def wrap_future(future):
+
+ def wrap():
+ try:
+ return future.result()
+ except grpc.RpcError:
+ with lock:
+ if test_is_running_cell[0]:
+ raise
+ return None
+
+ return wrap
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ inner_response_future = multi_callable.future(
+ request,
+ metadata=(
+ ('test',
+ 'WaitingForSomeButNotAllConcurrentFutureInvocations'),))
+ outer_response_future = pool.submit(
+ wrap_future(inner_response_future))
+ response_futures[index] = outer_response_future
+
+ some_completed_response_futures_iterator = itertools.islice(
+ futures.as_completed(response_futures),
+ test_constants.THREAD_CONCURRENCY // 2)
+ for response_future in some_completed_response_futures_iterator:
+ self.assertEqual(expected_response, response_future.result())
+ with lock:
+ test_is_running_cell[0] = False
+
+ def testConsumingOneStreamResponseUnaryRequest(self):
+ request = b'\x57\x38'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=(('test', 'ConsumingOneStreamResponseUnaryRequest'),))
+ next(response_iterator)
+
+ def testConsumingSomeButNotAllStreamResponsesUnaryRequest(self):
+ request = b'\x57\x38'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request,
+ metadata=(
+ ('test', 'ConsumingSomeButNotAllStreamResponsesUnaryRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(response_iterator)
- self.assertIs(grpc.StatusCode.UNKNOWN, exception_context.exception.code())
- self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
+ def testConsumingSomeButNotAllStreamResponsesStreamRequest(self):
+ requests = tuple(b'\x67\x88'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
- def testIgnoredUnaryRequestFutureUnaryResponse(self):
- request = b'\x37\x17'
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(('test',
+ 'ConsumingSomeButNotAllStreamResponsesStreamRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH // 2):
+ next(response_iterator)
- multi_callable = _unary_unary_multi_callable(self._channel)
- multi_callable.future(
- request,
- metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),))
+ def testConsumingTooManyStreamResponsesStreamRequest(self):
+ requests = tuple(b'\x67\x88'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
- def testIgnoredUnaryRequestStreamResponse(self):
- request = b'\x37\x17'
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(
+ ('test', 'ConsumingTooManyStreamResponsesStreamRequest'),))
+ for _ in range(test_constants.STREAM_LENGTH):
+ next(response_iterator)
+ for _ in range(test_constants.STREAM_LENGTH):
+ with self.assertRaises(StopIteration):
+ next(response_iterator)
+
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.OK, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testCancelledUnaryRequestUnaryResponse(self):
+ request = b'\x07\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request,
+ metadata=(('test', 'CancelledUnaryRequestUnaryResponse'),))
+ response_future.cancel()
+
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.exception()
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.traceback()
+ self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
+
+ def testCancelledUnaryRequestStreamResponse(self):
+ request = b'\x07\x19'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self._control.pause():
+ response_iterator = multi_callable(
+ request,
+ metadata=(('test', 'CancelledUnaryRequestStreamResponse'),))
+ self._control.block_until_paused()
+ response_iterator.cancel()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ next(response_iterator)
+ self.assertIs(grpc.StatusCode.CANCELLED,
+ exception_context.exception.code())
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testCancelledStreamRequestUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=(('test', 'CancelledStreamRequestUnaryResponse'),))
+ self._control.block_until_paused()
+ response_future.cancel()
+
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.result()
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.exception()
+ with self.assertRaises(grpc.FutureCancelledError):
+ response_future.traceback()
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+
+ def testCancelledStreamRequestStreamResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.pause():
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(('test', 'CancelledStreamRequestStreamResponse'),))
+ response_iterator.cancel()
+
+ with self.assertRaises(grpc.RpcError):
+ next(response_iterator)
+ self.assertIsNotNone(response_iterator.initial_metadata())
+ self.assertIs(grpc.StatusCode.CANCELLED, response_iterator.code())
+ self.assertIsNotNone(response_iterator.details())
+ self.assertIsNotNone(response_iterator.trailing_metadata())
+
+ def testExpiredUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x07\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable.with_call(
+ request,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(
+ ('test', 'ExpiredUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertIsInstance(exception_context.exception, grpc.Call)
+ self.assertIsNotNone(exception_context.exception.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIsNotNone(exception_context.exception.details())
+ self.assertIsNotNone(exception_context.exception.trailing_metadata())
+
+ def testExpiredUnaryRequestFutureUnaryResponse(self):
+ request = b'\x07\x17'
+ callback = _Callback()
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(('test', 'ExpiredUnaryRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ self.assertIs(response_future, value_passed_to_callback)
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ response_future.exception().code())
+
+ def testExpiredUnaryRequestStreamResponse(self):
+ request = b'\x07\x19'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(('test', 'ExpiredUnaryRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ response_iterator.code())
+
+ def testExpiredStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x07\x08'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request_iterator,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(
+ ('test', 'ExpiredStreamRequestBlockingUnaryResponse'),))
+
+ self.assertIsInstance(exception_context.exception, grpc.RpcError)
+ self.assertIsInstance(exception_context.exception, grpc.Call)
+ self.assertIsNotNone(exception_context.exception.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIsNotNone(exception_context.exception.details())
+ self.assertIsNotNone(exception_context.exception.trailing_metadata())
+
+ def testExpiredStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+ callback = _Callback()
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.pause():
+ response_future = multi_callable.future(
+ request_iterator,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(('test', 'ExpiredStreamRequestFutureUnaryResponse'),))
+ with self.assertRaises(grpc.FutureTimeoutError):
+ response_future.result(timeout=test_constants.SHORT_TIMEOUT /
+ 2.0)
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIs(response_future, value_passed_to_callback)
+ self.assertIsNotNone(response_future.initial_metadata())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED, response_future.code())
+ self.assertIsNotNone(response_future.details())
+ self.assertIsNotNone(response_future.trailing_metadata())
+
+ def testExpiredStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x18'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.pause():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request_iterator,
+ timeout=test_constants.SHORT_TIMEOUT,
+ metadata=(('test', 'ExpiredStreamRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.DEADLINE_EXCEEDED,
+ response_iterator.code())
+
+ def testFailedUnaryRequestBlockingUnaryResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable.with_call(
+ request,
+ metadata=(
+ ('test', 'FailedUnaryRequestBlockingUnaryResponse'),))
+
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+
+ def testFailedUnaryRequestFutureUnaryResponse(self):
+ request = b'\x37\x17'
+ callback = _Callback()
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ with self._control.fail():
+ response_future = multi_callable.future(
+ request,
+ metadata=(('test', 'FailedUnaryRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ self.assertIsInstance(response_future, grpc.Future)
+ self.assertIsInstance(response_future, grpc.Call)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ response_future.exception().code())
+ self.assertIs(response_future, value_passed_to_callback)
+
+ def testFailedUnaryRequestStreamResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ with self._control.fail():
+ response_iterator = multi_callable(
+ request,
+ metadata=(('test', 'FailedUnaryRequestStreamResponse'),))
+ next(response_iterator)
+
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+
+ def testFailedStreamRequestBlockingUnaryResponse(self):
+ requests = tuple(b'\x47\x58'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ multi_callable(
+ request_iterator,
+ metadata=(
+ ('test', 'FailedStreamRequestBlockingUnaryResponse'),))
+
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+
+ def testFailedStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+ callback = _Callback()
+
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ with self._control.fail():
+ response_future = multi_callable.future(
+ request_iterator,
+ metadata=(('test', 'FailedStreamRequestFutureUnaryResponse'),))
+ response_future.add_done_callback(callback)
+ value_passed_to_callback = callback.value()
+
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_future.result()
+ self.assertIs(grpc.StatusCode.UNKNOWN, response_future.code())
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+ self.assertIsInstance(response_future.exception(), grpc.RpcError)
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIs(response_future, value_passed_to_callback)
+
+ def testFailedStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x88'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
+
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ with self._control.fail():
+ with self.assertRaises(grpc.RpcError) as exception_context:
+ response_iterator = multi_callable(
+ request_iterator,
+ metadata=(('test', 'FailedStreamRequestStreamResponse'),))
+ tuple(response_iterator)
+
+ self.assertIs(grpc.StatusCode.UNKNOWN,
+ exception_context.exception.code())
+ self.assertIs(grpc.StatusCode.UNKNOWN, response_iterator.code())
+
+ def testIgnoredUnaryRequestFutureUnaryResponse(self):
+ request = b'\x37\x17'
+
+ multi_callable = _unary_unary_multi_callable(self._channel)
+ multi_callable.future(
+ request,
+ metadata=(('test', 'IgnoredUnaryRequestFutureUnaryResponse'),))
- multi_callable = _unary_stream_multi_callable(self._channel)
- multi_callable(
- request,
- metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
+ def testIgnoredUnaryRequestStreamResponse(self):
+ request = b'\x37\x17'
- def testIgnoredStreamRequestFutureUnaryResponse(self):
- requests = tuple(b'\x07\x18' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
+ multi_callable = _unary_stream_multi_callable(self._channel)
+ multi_callable(
+ request, metadata=(('test', 'IgnoredUnaryRequestStreamResponse'),))
+
+ def testIgnoredStreamRequestFutureUnaryResponse(self):
+ requests = tuple(b'\x07\x18'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
- multi_callable = _stream_unary_multi_callable(self._channel)
- multi_callable.future(
- request_iterator,
- metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),))
+ multi_callable = _stream_unary_multi_callable(self._channel)
+ multi_callable.future(
+ request_iterator,
+ metadata=(('test', 'IgnoredStreamRequestFutureUnaryResponse'),))
- def testIgnoredStreamRequestStreamResponse(self):
- requests = tuple(b'\x67\x88' for _ in range(test_constants.STREAM_LENGTH))
- request_iterator = iter(requests)
+ def testIgnoredStreamRequestStreamResponse(self):
+ requests = tuple(b'\x67\x88'
+ for _ in range(test_constants.STREAM_LENGTH))
+ request_iterator = iter(requests)
- multi_callable = _stream_stream_multi_callable(self._channel)
- multi_callable(
- request_iterator,
- metadata=(('test', 'IgnoredStreamRequestStreamResponse'),))
+ multi_callable = _stream_stream_multi_callable(self._channel)
+ multi_callable(
+ request_iterator,
+ metadata=(('test', 'IgnoredStreamRequestStreamResponse'),))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
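
The churn above is whitespace-only: two-space indentation becomes four-space and long argument lists are re-wrapped, which is the signature of an automated formatter pass over the whole file rather than hand edits. A minimal sketch of reproducing such a pass over this package follows; it assumes the yapf formatter is installed and says nothing about the exact tool, version, or style configuration actually used for this commit.

# Hypothetical reproduction of a package-wide formatting pass; the real
# commit's tooling and style configuration are not recorded in this diff.
import subprocess
import sys


def reformat(package_root):
    """Rewrite every .py file under package_root in place with yapf."""
    return subprocess.call([
        sys.executable, '-m', 'yapf', '--in-place', '--recursive', package_root
    ])


if __name__ == '__main__':
    sys.exit(reformat('src/python/grpcio_tests'))
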
diff --git a/src/python/grpcio_tests/tests/unit/_sanity/__init__.py b/src/python/grpcio_tests/tests/unit/_sanity/__init__.py
index 2f88fa0412..100a624dc9 100644
--- a/src/python/grpcio_tests/tests/unit/_sanity/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/_sanity/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/_sanity/_sanity_test.py b/src/python/grpcio_tests/tests/unit/_sanity/_sanity_test.py
index e9fdf217ae..0fbe6a2b5d 100644
--- a/src/python/grpcio_tests/tests/unit/_sanity/_sanity_test.py
+++ b/src/python/grpcio_tests/tests/unit/_sanity/_sanity_test.py
@@ -38,21 +38,23 @@ import tests
class Sanity(unittest.TestCase):
- def testTestsJsonUpToDate(self):
- """Autodiscovers all test suites and checks that tests.json is up to date"""
- loader = tests.Loader()
- loader.loadTestsFromNames(['tests'])
- test_suite_names = [
- test_case_class.id().rsplit('.', 1)[0]
- for test_case_class in tests._loader.iterate_suite_cases(loader.suite)]
- test_suite_names = sorted(set(test_suite_names))
-
- tests_json_string = pkg_resources.resource_string('tests', 'tests.json')
- if six.PY3:
- tests_json_string = tests_json_string.decode()
- tests_json = json.loads(tests_json_string)
- self.assertListEqual(test_suite_names, tests_json)
+ def testTestsJsonUpToDate(self):
+ """Autodiscovers all test suites and checks that tests.json is up to date"""
+ loader = tests.Loader()
+ loader.loadTestsFromNames(['tests'])
+ test_suite_names = [
+ test_case_class.id().rsplit('.', 1)[0]
+ for test_case_class in tests._loader.iterate_suite_cases(
+ loader.suite)
+ ]
+ test_suite_names = sorted(set(test_suite_names))
+
+ tests_json_string = pkg_resources.resource_string('tests', 'tests.json')
+ if six.PY3:
+ tests_json_string = tests_json_string.decode()
+ tests_json = json.loads(tests_json_string)
+ self.assertListEqual(test_suite_names, tests_json)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py b/src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py
index 3e4f317edc..be3522f46f 100644
--- a/src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py
+++ b/src/python/grpcio_tests/tests/unit/_thread_cleanup_test.py
@@ -40,78 +40,89 @@ _EPSILON = 0.1
def cleanup(timeout):
- if timeout is not None:
- time.sleep(timeout)
- else:
- time.sleep(_LONG_TIME)
+ if timeout is not None:
+ time.sleep(timeout)
+ else:
+ time.sleep(_LONG_TIME)
def slow_cleanup(timeout):
- # Don't respect timeout
- time.sleep(_LONG_TIME)
+ # Don't respect timeout
+ time.sleep(_LONG_TIME)
class CleanupThreadTest(unittest.TestCase):
- def testTargetInvocation(self):
- event = threading.Event()
- def target(arg1, arg2, arg3=None):
- self.assertEqual('arg1', arg1)
- self.assertEqual('arg2', arg2)
- self.assertEqual('arg3', arg3)
- event.set()
-
- cleanup_thread = _common.CleanupThread(behavior=lambda x: None,
- target=target, name='test-name',
- args=('arg1', 'arg2'), kwargs={'arg3': 'arg3'})
- cleanup_thread.start()
- cleanup_thread.join()
- self.assertEqual(cleanup_thread.name, 'test-name')
- self.assertTrue(event.is_set())
-
- def testJoinNoTimeout(self):
- cleanup_thread = _common.CleanupThread(behavior=cleanup)
- cleanup_thread.start()
- start_time = time.time()
- cleanup_thread.join()
- end_time = time.time()
- self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
-
- def testJoinTimeout(self):
- cleanup_thread = _common.CleanupThread(behavior=cleanup)
- cleanup_thread.start()
- start_time = time.time()
- cleanup_thread.join(_SHORT_TIME)
- end_time = time.time()
- self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
-
- def testJoinTimeoutSlowBehavior(self):
- cleanup_thread = _common.CleanupThread(behavior=slow_cleanup)
- cleanup_thread.start()
- start_time = time.time()
- cleanup_thread.join(_SHORT_TIME)
- end_time = time.time()
- self.assertAlmostEqual(_LONG_TIME, end_time - start_time, delta=_EPSILON)
-
- def testJoinTimeoutSlowTarget(self):
- event = threading.Event()
- def target():
- event.wait(_LONG_TIME)
- cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target)
- cleanup_thread.start()
- start_time = time.time()
- cleanup_thread.join(_SHORT_TIME)
- end_time = time.time()
- self.assertAlmostEqual(_SHORT_TIME, end_time - start_time, delta=_EPSILON)
- event.set()
-
- def testJoinZeroTimeout(self):
- cleanup_thread = _common.CleanupThread(behavior=cleanup)
- cleanup_thread.start()
- start_time = time.time()
- cleanup_thread.join(0)
- end_time = time.time()
- self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON)
+ def testTargetInvocation(self):
+ event = threading.Event()
+
+ def target(arg1, arg2, arg3=None):
+ self.assertEqual('arg1', arg1)
+ self.assertEqual('arg2', arg2)
+ self.assertEqual('arg3', arg3)
+ event.set()
+
+ cleanup_thread = _common.CleanupThread(
+ behavior=lambda x: None,
+ target=target,
+ name='test-name',
+ args=('arg1', 'arg2'),
+ kwargs={'arg3': 'arg3'})
+ cleanup_thread.start()
+ cleanup_thread.join()
+ self.assertEqual(cleanup_thread.name, 'test-name')
+ self.assertTrue(event.is_set())
+
+ def testJoinNoTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join()
+ end_time = time.time()
+ self.assertAlmostEqual(
+ _LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(
+ _SHORT_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowBehavior(self):
+ cleanup_thread = _common.CleanupThread(behavior=slow_cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(
+ _LONG_TIME, end_time - start_time, delta=_EPSILON)
+
+ def testJoinTimeoutSlowTarget(self):
+ event = threading.Event()
+
+ def target():
+ event.wait(_LONG_TIME)
+
+ cleanup_thread = _common.CleanupThread(behavior=cleanup, target=target)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(_SHORT_TIME)
+ end_time = time.time()
+ self.assertAlmostEqual(
+ _SHORT_TIME, end_time - start_time, delta=_EPSILON)
+ event.set()
+
+ def testJoinZeroTimeout(self):
+ cleanup_thread = _common.CleanupThread(behavior=cleanup)
+ cleanup_thread.start()
+ start_time = time.time()
+ cleanup_thread.join(0)
+ end_time = time.time()
+ self.assertAlmostEqual(0, end_time - start_time, delta=_EPSILON)
+
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
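Read together, these tests pin down the CleanupThread contract: join() returns only after both the target and the cleanup behavior have run, and join(timeout) budgets that single timeout across the two. A minimal usage sketch under that contract (the `from grpc import _common` import and the exact argument passed to behavior are assumptions, not shown in this hunk):

    import threading
    import time

    from grpc import _common  # internal helper exercised by the tests above

    cleaned_up = threading.Event()

    # behavior appears to receive the join() timeout (or None) as its argument.
    thread = _common.CleanupThread(
        behavior=lambda timeout: cleaned_up.set(),
        target=time.sleep,
        args=(0.1,))
    thread.start()
    thread.join(1.0)  # bounded wait covering both the target and the cleanup
    assert cleaned_up.is_set()
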
diff --git a/src/python/grpcio_tests/tests/unit/_thread_pool.py b/src/python/grpcio_tests/tests/unit/_thread_pool.py
index f13cc2f86f..fad2e1c8f6 100644
--- a/src/python/grpcio_tests/tests/unit/_thread_pool.py
+++ b/src/python/grpcio_tests/tests/unit/_thread_pool.py
@@ -32,17 +32,18 @@ from concurrent import futures
class RecordingThreadPool(futures.Executor):
- """A thread pool that records if used."""
- def __init__(self, max_workers):
- self._tp_executor = futures.ThreadPoolExecutor(max_workers=max_workers)
- self._lock = threading.Lock()
- self._was_used = False
+ """A thread pool that records if used."""
- def submit(self, fn, *args, **kwargs):
- with self._lock:
- self._was_used = True
- self._tp_executor.submit(fn, *args, **kwargs)
+ def __init__(self, max_workers):
+ self._tp_executor = futures.ThreadPoolExecutor(max_workers=max_workers)
+ self._lock = threading.Lock()
+ self._was_used = False
- def was_used(self):
- with self._lock:
- return self._was_used
+ def submit(self, fn, *args, **kwargs):
+ with self._lock:
+ self._was_used = True
+ self._tp_executor.submit(fn, *args, **kwargs)
+
+ def was_used(self):
+ with self._lock:
+ return self._was_used
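RecordingThreadPool is a futures.Executor subclass that a test can hand to an API and later query to learn whether the API actually submitted work to it. A minimal sketch, assuming the class is importable under the file path above (tests.unit._thread_pool):

    from tests.unit import _thread_pool

    pool = _thread_pool.RecordingThreadPool(max_workers=1)
    assert not pool.was_used()

    pool.submit(lambda: None)  # any submission flips the recorded flag
    assert pool.was_used()
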
diff --git a/src/python/grpcio_tests/tests/unit/beta/__init__.py b/src/python/grpcio_tests/tests/unit/beta/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/beta/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/beta/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py b/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py
index 3a9701b8eb..b5fdac26c1 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_beta_features_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests Face interface compliance of the gRPC Python Beta API."""
import threading
@@ -57,290 +56,303 @@ _RESPONSE = b'123'
class _Servicer(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._peer = None
- self._serviced = False
-
- def unary_unary(self, request, context):
- with self._condition:
- self._request = request
- self._peer = context.protocol_context().peer()
- self._invocation_metadata = context.invocation_metadata()
- context.protocol_context().disable_next_response_compression()
- self._serviced = True
- self._condition.notify_all()
- return _RESPONSE
-
- def unary_stream(self, request, context):
- with self._condition:
- self._request = request
- self._peer = context.protocol_context().peer()
- self._invocation_metadata = context.invocation_metadata()
- context.protocol_context().disable_next_response_compression()
- self._serviced = True
- self._condition.notify_all()
- return
- yield
-
- def stream_unary(self, request_iterator, context):
- for request in request_iterator:
- self._request = request
- with self._condition:
- self._peer = context.protocol_context().peer()
- self._invocation_metadata = context.invocation_metadata()
- context.protocol_context().disable_next_response_compression()
- self._serviced = True
- self._condition.notify_all()
- return _RESPONSE
-
- def stream_stream(self, request_iterator, context):
- for request in request_iterator:
- with self._condition:
- self._peer = context.protocol_context().peer()
- context.protocol_context().disable_next_response_compression()
- yield _RESPONSE
- with self._condition:
- self._invocation_metadata = context.invocation_metadata()
- self._serviced = True
- self._condition.notify_all()
-
- def peer(self):
- with self._condition:
- return self._peer
-
- def block_until_serviced(self):
- with self._condition:
- while not self._serviced:
- self._condition.wait()
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._peer = None
+ self._serviced = False
+
+ def unary_unary(self, request, context):
+ with self._condition:
+ self._request = request
+ self._peer = context.protocol_context().peer()
+ self._invocation_metadata = context.invocation_metadata()
+ context.protocol_context().disable_next_response_compression()
+ self._serviced = True
+ self._condition.notify_all()
+ return _RESPONSE
+
+ def unary_stream(self, request, context):
+ with self._condition:
+ self._request = request
+ self._peer = context.protocol_context().peer()
+ self._invocation_metadata = context.invocation_metadata()
+ context.protocol_context().disable_next_response_compression()
+ self._serviced = True
+ self._condition.notify_all()
+ return
+ yield
+
+ def stream_unary(self, request_iterator, context):
+ for request in request_iterator:
+ self._request = request
+ with self._condition:
+ self._peer = context.protocol_context().peer()
+ self._invocation_metadata = context.invocation_metadata()
+ context.protocol_context().disable_next_response_compression()
+ self._serviced = True
+ self._condition.notify_all()
+ return _RESPONSE
+
+ def stream_stream(self, request_iterator, context):
+ for request in request_iterator:
+ with self._condition:
+ self._peer = context.protocol_context().peer()
+ context.protocol_context().disable_next_response_compression()
+ yield _RESPONSE
+ with self._condition:
+ self._invocation_metadata = context.invocation_metadata()
+ self._serviced = True
+ self._condition.notify_all()
+
+ def peer(self):
+ with self._condition:
+ return self._peer
+
+ def block_until_serviced(self):
+ with self._condition:
+ while not self._serviced:
+ self._condition.wait()
class _BlockingIterator(object):
- def __init__(self, upstream):
- self._condition = threading.Condition()
- self._upstream = upstream
- self._allowed = []
+ def __init__(self, upstream):
+ self._condition = threading.Condition()
+ self._upstream = upstream
+ self._allowed = []
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- return self.next()
+ def __next__(self):
+ return self.next()
- def next(self):
- with self._condition:
- while True:
- if self._allowed is None:
- raise StopIteration()
- elif self._allowed:
- return self._allowed.pop(0)
- else:
- self._condition.wait()
+ def next(self):
+ with self._condition:
+ while True:
+ if self._allowed is None:
+ raise StopIteration()
+ elif self._allowed:
+ return self._allowed.pop(0)
+ else:
+ self._condition.wait()
- def allow(self):
- with self._condition:
- try:
- self._allowed.append(next(self._upstream))
- except StopIteration:
- self._allowed = None
- self._condition.notify_all()
+ def allow(self):
+ with self._condition:
+ try:
+ self._allowed.append(next(self._upstream))
+ except StopIteration:
+ self._allowed = None
+ self._condition.notify_all()
def _metadata_plugin(context, callback):
- callback([(_PER_RPC_CREDENTIALS_METADATA_KEY,
- _PER_RPC_CREDENTIALS_METADATA_VALUE)], None)
+ callback([(_PER_RPC_CREDENTIALS_METADATA_KEY,
+ _PER_RPC_CREDENTIALS_METADATA_VALUE)], None)
class BetaFeaturesTest(unittest.TestCase):
- def setUp(self):
- self._servicer = _Servicer()
- method_implementations = {
- (_GROUP, _UNARY_UNARY):
+ def setUp(self):
+ self._servicer = _Servicer()
+ method_implementations = {
+ (_GROUP, _UNARY_UNARY):
utilities.unary_unary_inline(self._servicer.unary_unary),
- (_GROUP, _UNARY_STREAM):
+ (_GROUP, _UNARY_STREAM):
utilities.unary_stream_inline(self._servicer.unary_stream),
- (_GROUP, _STREAM_UNARY):
+ (_GROUP, _STREAM_UNARY):
utilities.stream_unary_inline(self._servicer.stream_unary),
- (_GROUP, _STREAM_STREAM):
+ (_GROUP, _STREAM_STREAM):
utilities.stream_stream_inline(self._servicer.stream_stream),
- }
-
- cardinalities = {
- _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
- _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
- _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
- _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
- }
-
- server_options = implementations.server_options(
- thread_pool_size=test_constants.POOL_SIZE)
- self._server = implementations.server(
- method_implementations, options=server_options)
- server_credentials = implementations.ssl_server_credentials(
- [(resources.private_key(), resources.certificate_chain(),),])
- port = self._server.add_secure_port('[::]:0', server_credentials)
- self._server.start()
- self._channel_credentials = implementations.ssl_channel_credentials(
- resources.test_root_certificates())
- self._call_credentials = implementations.metadata_call_credentials(
- _metadata_plugin)
- channel = test_utilities.not_really_secure_channel(
- 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
- stub_options = implementations.stub_options(
- thread_pool_size=test_constants.POOL_SIZE)
- self._dynamic_stub = implementations.dynamic_stub(
- channel, _GROUP, cardinalities, options=stub_options)
-
- def tearDown(self):
- self._dynamic_stub = None
- self._server.stop(test_constants.SHORT_TIMEOUT).wait()
-
- def test_unary_unary(self):
- call_options = interfaces.grpc_call_options(
- disable_compression=True, credentials=self._call_credentials)
- response = getattr(self._dynamic_stub, _UNARY_UNARY)(
- _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
- self.assertEqual(_RESPONSE, response)
- self.assertIsNotNone(self._servicer.peer())
- invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
- self._servicer._invocation_metadata]
- self.assertIn(
- (_PER_RPC_CREDENTIALS_METADATA_KEY,
- _PER_RPC_CREDENTIALS_METADATA_VALUE),
- invocation_metadata)
-
- def test_unary_stream(self):
- call_options = interfaces.grpc_call_options(
- disable_compression=True, credentials=self._call_credentials)
- response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
- _REQUEST, test_constants.LONG_TIMEOUT, protocol_options=call_options)
- self._servicer.block_until_serviced()
- self.assertIsNotNone(self._servicer.peer())
- invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
- self._servicer._invocation_metadata]
- self.assertIn(
- (_PER_RPC_CREDENTIALS_METADATA_KEY,
- _PER_RPC_CREDENTIALS_METADATA_VALUE),
- invocation_metadata)
-
- def test_stream_unary(self):
- call_options = interfaces.grpc_call_options(
- credentials=self._call_credentials)
- request_iterator = _BlockingIterator(iter((_REQUEST,)))
- response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
- request_iterator, test_constants.LONG_TIMEOUT,
- protocol_options=call_options)
- response_future.protocol_context().disable_next_request_compression()
- request_iterator.allow()
- response_future.protocol_context().disable_next_request_compression()
- request_iterator.allow()
- self._servicer.block_until_serviced()
- self.assertIsNotNone(self._servicer.peer())
- self.assertEqual(_RESPONSE, response_future.result())
- invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
- self._servicer._invocation_metadata]
- self.assertIn(
- (_PER_RPC_CREDENTIALS_METADATA_KEY,
- _PER_RPC_CREDENTIALS_METADATA_VALUE),
- invocation_metadata)
-
- def test_stream_stream(self):
- call_options = interfaces.grpc_call_options(
- credentials=self._call_credentials)
- request_iterator = _BlockingIterator(iter((_REQUEST,)))
- response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
- request_iterator, test_constants.SHORT_TIMEOUT,
- protocol_options=call_options)
- response_iterator.protocol_context().disable_next_request_compression()
- request_iterator.allow()
- response = next(response_iterator)
- response_iterator.protocol_context().disable_next_request_compression()
- request_iterator.allow()
- self._servicer.block_until_serviced()
- self.assertIsNotNone(self._servicer.peer())
- self.assertEqual(_RESPONSE, response)
- invocation_metadata = [(metadatum.key, metadatum.value) for metadatum in
- self._servicer._invocation_metadata]
- self.assertIn(
- (_PER_RPC_CREDENTIALS_METADATA_KEY,
- _PER_RPC_CREDENTIALS_METADATA_VALUE),
- invocation_metadata)
+ }
+
+ cardinalities = {
+ _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
+ _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
+ _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
+ _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
+ }
+
+ server_options = implementations.server_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+ self._server = implementations.server(
+ method_implementations, options=server_options)
+ server_credentials = implementations.ssl_server_credentials([(
+ resources.private_key(),
+ resources.certificate_chain(),),])
+ port = self._server.add_secure_port('[::]:0', server_credentials)
+ self._server.start()
+ self._channel_credentials = implementations.ssl_channel_credentials(
+ resources.test_root_certificates())
+ self._call_credentials = implementations.metadata_call_credentials(
+ _metadata_plugin)
+ channel = test_utilities.not_really_secure_channel(
+ 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
+ stub_options = implementations.stub_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+ self._dynamic_stub = implementations.dynamic_stub(
+ channel, _GROUP, cardinalities, options=stub_options)
+
+ def tearDown(self):
+ self._dynamic_stub = None
+ self._server.stop(test_constants.SHORT_TIMEOUT).wait()
+
+ def test_unary_unary(self):
+ call_options = interfaces.grpc_call_options(
+ disable_compression=True, credentials=self._call_credentials)
+ response = getattr(self._dynamic_stub, _UNARY_UNARY)(
+ _REQUEST,
+ test_constants.LONG_TIMEOUT,
+ protocol_options=call_options)
+ self.assertEqual(_RESPONSE, response)
+ self.assertIsNotNone(self._servicer.peer())
+ invocation_metadata = [
+ (metadatum.key, metadatum.value)
+ for metadatum in self._servicer._invocation_metadata
+ ]
+ self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
+ _PER_RPC_CREDENTIALS_METADATA_VALUE),
+ invocation_metadata)
+
+ def test_unary_stream(self):
+ call_options = interfaces.grpc_call_options(
+ disable_compression=True, credentials=self._call_credentials)
+ response_iterator = getattr(self._dynamic_stub, _UNARY_STREAM)(
+ _REQUEST,
+ test_constants.LONG_TIMEOUT,
+ protocol_options=call_options)
+ self._servicer.block_until_serviced()
+ self.assertIsNotNone(self._servicer.peer())
+ invocation_metadata = [
+ (metadatum.key, metadatum.value)
+ for metadatum in self._servicer._invocation_metadata
+ ]
+ self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
+ _PER_RPC_CREDENTIALS_METADATA_VALUE),
+ invocation_metadata)
+
+ def test_stream_unary(self):
+ call_options = interfaces.grpc_call_options(
+ credentials=self._call_credentials)
+ request_iterator = _BlockingIterator(iter((_REQUEST,)))
+ response_future = getattr(self._dynamic_stub, _STREAM_UNARY).future(
+ request_iterator,
+ test_constants.LONG_TIMEOUT,
+ protocol_options=call_options)
+ response_future.protocol_context().disable_next_request_compression()
+ request_iterator.allow()
+ response_future.protocol_context().disable_next_request_compression()
+ request_iterator.allow()
+ self._servicer.block_until_serviced()
+ self.assertIsNotNone(self._servicer.peer())
+ self.assertEqual(_RESPONSE, response_future.result())
+ invocation_metadata = [
+ (metadatum.key, metadatum.value)
+ for metadatum in self._servicer._invocation_metadata
+ ]
+ self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
+ _PER_RPC_CREDENTIALS_METADATA_VALUE),
+ invocation_metadata)
+
+ def test_stream_stream(self):
+ call_options = interfaces.grpc_call_options(
+ credentials=self._call_credentials)
+ request_iterator = _BlockingIterator(iter((_REQUEST,)))
+ response_iterator = getattr(self._dynamic_stub, _STREAM_STREAM)(
+ request_iterator,
+ test_constants.SHORT_TIMEOUT,
+ protocol_options=call_options)
+ response_iterator.protocol_context().disable_next_request_compression()
+ request_iterator.allow()
+ response = next(response_iterator)
+ response_iterator.protocol_context().disable_next_request_compression()
+ request_iterator.allow()
+ self._servicer.block_until_serviced()
+ self.assertIsNotNone(self._servicer.peer())
+ self.assertEqual(_RESPONSE, response)
+ invocation_metadata = [
+ (metadatum.key, metadatum.value)
+ for metadatum in self._servicer._invocation_metadata
+ ]
+ self.assertIn((_PER_RPC_CREDENTIALS_METADATA_KEY,
+ _PER_RPC_CREDENTIALS_METADATA_VALUE),
+ invocation_metadata)
class ContextManagementAndLifecycleTest(unittest.TestCase):
- def setUp(self):
- self._servicer = _Servicer()
- self._method_implementations = {
- (_GROUP, _UNARY_UNARY):
+ def setUp(self):
+ self._servicer = _Servicer()
+ self._method_implementations = {
+ (_GROUP, _UNARY_UNARY):
utilities.unary_unary_inline(self._servicer.unary_unary),
- (_GROUP, _UNARY_STREAM):
+ (_GROUP, _UNARY_STREAM):
utilities.unary_stream_inline(self._servicer.unary_stream),
- (_GROUP, _STREAM_UNARY):
+ (_GROUP, _STREAM_UNARY):
utilities.stream_unary_inline(self._servicer.stream_unary),
- (_GROUP, _STREAM_STREAM):
+ (_GROUP, _STREAM_STREAM):
utilities.stream_stream_inline(self._servicer.stream_stream),
- }
-
- self._cardinalities = {
- _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
- _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
- _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
- _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
- }
-
- self._server_options = implementations.server_options(
- thread_pool_size=test_constants.POOL_SIZE)
- self._server_credentials = implementations.ssl_server_credentials(
- [(resources.private_key(), resources.certificate_chain(),),])
- self._channel_credentials = implementations.ssl_channel_credentials(
- resources.test_root_certificates())
- self._stub_options = implementations.stub_options(
- thread_pool_size=test_constants.POOL_SIZE)
-
- def test_stub_context(self):
- server = implementations.server(
- self._method_implementations, options=self._server_options)
- port = server.add_secure_port('[::]:0', self._server_credentials)
- server.start()
-
- channel = test_utilities.not_really_secure_channel(
- 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
- dynamic_stub = implementations.dynamic_stub(
- channel, _GROUP, self._cardinalities, options=self._stub_options)
- for _ in range(100):
- with dynamic_stub:
- pass
- for _ in range(10):
- with dynamic_stub:
- call_options = interfaces.grpc_call_options(
- disable_compression=True)
- response = getattr(dynamic_stub, _UNARY_UNARY)(
- _REQUEST, test_constants.LONG_TIMEOUT,
- protocol_options=call_options)
- self.assertEqual(_RESPONSE, response)
- self.assertIsNotNone(self._servicer.peer())
-
- server.stop(test_constants.SHORT_TIMEOUT).wait()
-
- def test_server_lifecycle(self):
- for _ in range(100):
- server = implementations.server(
- self._method_implementations, options=self._server_options)
- port = server.add_secure_port('[::]:0', self._server_credentials)
- server.start()
- server.stop(test_constants.SHORT_TIMEOUT).wait()
- for _ in range(100):
- server = implementations.server(
- self._method_implementations, options=self._server_options)
- server.add_secure_port('[::]:0', self._server_credentials)
- server.add_insecure_port('[::]:0')
- with server:
- server.stop(test_constants.SHORT_TIMEOUT)
- server.stop(test_constants.SHORT_TIMEOUT)
+ }
+
+ self._cardinalities = {
+ _UNARY_UNARY: cardinality.Cardinality.UNARY_UNARY,
+ _UNARY_STREAM: cardinality.Cardinality.UNARY_STREAM,
+ _STREAM_UNARY: cardinality.Cardinality.STREAM_UNARY,
+ _STREAM_STREAM: cardinality.Cardinality.STREAM_STREAM,
+ }
+
+ self._server_options = implementations.server_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+ self._server_credentials = implementations.ssl_server_credentials([(
+ resources.private_key(),
+ resources.certificate_chain(),),])
+ self._channel_credentials = implementations.ssl_channel_credentials(
+ resources.test_root_certificates())
+ self._stub_options = implementations.stub_options(
+ thread_pool_size=test_constants.POOL_SIZE)
+
+ def test_stub_context(self):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ port = server.add_secure_port('[::]:0', self._server_credentials)
+ server.start()
+
+ channel = test_utilities.not_really_secure_channel(
+ 'localhost', port, self._channel_credentials, _SERVER_HOST_OVERRIDE)
+ dynamic_stub = implementations.dynamic_stub(
+ channel, _GROUP, self._cardinalities, options=self._stub_options)
+ for _ in range(100):
+ with dynamic_stub:
+ pass
+ for _ in range(10):
+ with dynamic_stub:
+ call_options = interfaces.grpc_call_options(
+ disable_compression=True)
+ response = getattr(dynamic_stub, _UNARY_UNARY)(
+ _REQUEST,
+ test_constants.LONG_TIMEOUT,
+ protocol_options=call_options)
+ self.assertEqual(_RESPONSE, response)
+ self.assertIsNotNone(self._servicer.peer())
+
+ server.stop(test_constants.SHORT_TIMEOUT).wait()
+
+ def test_server_lifecycle(self):
+ for _ in range(100):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ port = server.add_secure_port('[::]:0', self._server_credentials)
+ server.start()
+ server.stop(test_constants.SHORT_TIMEOUT).wait()
+ for _ in range(100):
+ server = implementations.server(
+ self._method_implementations, options=self._server_options)
+ server.add_secure_port('[::]:0', self._server_credentials)
+ server.add_insecure_port('[::]:0')
+ with server:
+ server.stop(test_constants.SHORT_TIMEOUT)
+ server.stop(test_constants.SHORT_TIMEOUT)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
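The _BlockingIterator defined near the top of this file does the interleaving work in the stream tests: iteration blocks until the test calls allow(), which releases exactly one upstream value, or ends iteration once the upstream is exhausted. A standalone sketch of that hand-off, assuming the class shown above is in scope:

    import threading

    requests = _BlockingIterator(iter((b'first', b'second')))
    consumed = []

    def consumer():
        for request in requests:  # each step blocks until allow() is called
            consumed.append(request)

    worker = threading.Thread(target=consumer)
    worker.start()
    requests.allow()  # releases b'first'
    requests.allow()  # releases b'second'
    requests.allow()  # upstream exhausted; iteration stops
    worker.join()
    assert consumed == [b'first', b'second']
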
diff --git a/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py b/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py
index 5d826a269d..49d683b8a6 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_connectivity_channel_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc.beta._connectivity_channel."""
import unittest
@@ -36,13 +35,13 @@ from grpc.beta import interfaces
class ConnectivityStatesTest(unittest.TestCase):
- def testBetaConnectivityStates(self):
- self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE)
- self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING)
- self.assertIsNotNone(interfaces.ChannelConnectivity.READY)
- self.assertIsNotNone(interfaces.ChannelConnectivity.TRANSIENT_FAILURE)
- self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE)
+ def testBetaConnectivityStates(self):
+ self.assertIsNotNone(interfaces.ChannelConnectivity.IDLE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.CONNECTING)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.READY)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.TRANSIENT_FAILURE)
+ self.assertIsNotNone(interfaces.ChannelConnectivity.FATAL_FAILURE)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py b/src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py
index 3a67516906..f421442624 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_face_interface_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests Face interface compliance of the gRPC Python Beta API."""
import collections
@@ -47,94 +46,97 @@ _SERVER_HOST_OVERRIDE = 'foo.test.google.fr'
class _SerializationBehaviors(
- collections.namedtuple(
- '_SerializationBehaviors',
- ('request_serializers', 'request_deserializers', 'response_serializers',
- 'response_deserializers',))):
- pass
+ collections.namedtuple('_SerializationBehaviors', (
+ 'request_serializers',
+ 'request_deserializers',
+ 'response_serializers',
+ 'response_deserializers',))):
+ pass
def _serialization_behaviors_from_test_methods(test_methods):
- request_serializers = {}
- request_deserializers = {}
- response_serializers = {}
- response_deserializers = {}
- for (group, method), test_method in six.iteritems(test_methods):
- request_serializers[group, method] = test_method.serialize_request
- request_deserializers[group, method] = test_method.deserialize_request
- response_serializers[group, method] = test_method.serialize_response
- response_deserializers[group, method] = test_method.deserialize_response
- return _SerializationBehaviors(
- request_serializers, request_deserializers, response_serializers,
- response_deserializers)
+ request_serializers = {}
+ request_deserializers = {}
+ response_serializers = {}
+ response_deserializers = {}
+ for (group, method), test_method in six.iteritems(test_methods):
+ request_serializers[group, method] = test_method.serialize_request
+ request_deserializers[group, method] = test_method.deserialize_request
+ response_serializers[group, method] = test_method.serialize_response
+ response_deserializers[group, method] = test_method.deserialize_response
+ return _SerializationBehaviors(request_serializers, request_deserializers,
+ response_serializers, response_deserializers)
class _Implementation(test_interfaces.Implementation):
- def instantiate(
- self, methods, method_implementations, multi_method_implementation):
- serialization_behaviors = _serialization_behaviors_from_test_methods(
- methods)
- # TODO(nathaniel): Add a "groups" attribute to _digest.TestServiceDigest.
- service = next(iter(methods))[0]
- # TODO(nathaniel): Add a "cardinalities_by_group" attribute to
- # _digest.TestServiceDigest.
- cardinalities = {
- method: method_object.cardinality()
- for (group, method), method_object in six.iteritems(methods)}
-
- server_options = implementations.server_options(
- request_deserializers=serialization_behaviors.request_deserializers,
- response_serializers=serialization_behaviors.response_serializers,
- thread_pool_size=test_constants.POOL_SIZE)
- server = implementations.server(
- method_implementations, options=server_options)
- server_credentials = implementations.ssl_server_credentials(
- [(resources.private_key(), resources.certificate_chain(),),])
- port = server.add_secure_port('[::]:0', server_credentials)
- server.start()
- channel_credentials = implementations.ssl_channel_credentials(
- resources.test_root_certificates())
- channel = test_utilities.not_really_secure_channel(
- 'localhost', port, channel_credentials, _SERVER_HOST_OVERRIDE)
- stub_options = implementations.stub_options(
- request_serializers=serialization_behaviors.request_serializers,
- response_deserializers=serialization_behaviors.response_deserializers,
- thread_pool_size=test_constants.POOL_SIZE)
- generic_stub = implementations.generic_stub(channel, options=stub_options)
- dynamic_stub = implementations.dynamic_stub(
- channel, service, cardinalities, options=stub_options)
- return generic_stub, {service: dynamic_stub}, server
-
- def destantiate(self, memo):
- memo.stop(test_constants.SHORT_TIMEOUT).wait()
-
- def invocation_metadata(self):
- return grpc_test_common.INVOCATION_INITIAL_METADATA
-
- def initial_metadata(self):
- return grpc_test_common.SERVICE_INITIAL_METADATA
-
- def terminal_metadata(self):
- return grpc_test_common.SERVICE_TERMINAL_METADATA
-
- def code(self):
- return interfaces.StatusCode.OK
-
- def details(self):
- return grpc_test_common.DETAILS
-
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- return original_metadata is None or grpc_test_common.metadata_transmitted(
- original_metadata, transmitted_metadata)
+ def instantiate(self, methods, method_implementations,
+ multi_method_implementation):
+ serialization_behaviors = _serialization_behaviors_from_test_methods(
+ methods)
+ # TODO(nathaniel): Add a "groups" attribute to _digest.TestServiceDigest.
+ service = next(iter(methods))[0]
+ # TODO(nathaniel): Add a "cardinalities_by_group" attribute to
+ # _digest.TestServiceDigest.
+ cardinalities = {
+ method: method_object.cardinality()
+ for (group, method), method_object in six.iteritems(methods)
+ }
+
+ server_options = implementations.server_options(
+ request_deserializers=serialization_behaviors.request_deserializers,
+ response_serializers=serialization_behaviors.response_serializers,
+ thread_pool_size=test_constants.POOL_SIZE)
+ server = implementations.server(
+ method_implementations, options=server_options)
+ server_credentials = implementations.ssl_server_credentials([(
+ resources.private_key(),
+ resources.certificate_chain(),),])
+ port = server.add_secure_port('[::]:0', server_credentials)
+ server.start()
+ channel_credentials = implementations.ssl_channel_credentials(
+ resources.test_root_certificates())
+ channel = test_utilities.not_really_secure_channel(
+ 'localhost', port, channel_credentials, _SERVER_HOST_OVERRIDE)
+ stub_options = implementations.stub_options(
+ request_serializers=serialization_behaviors.request_serializers,
+ response_deserializers=serialization_behaviors.
+ response_deserializers,
+ thread_pool_size=test_constants.POOL_SIZE)
+ generic_stub = implementations.generic_stub(
+ channel, options=stub_options)
+ dynamic_stub = implementations.dynamic_stub(
+ channel, service, cardinalities, options=stub_options)
+ return generic_stub, {service: dynamic_stub}, server
+
+ def destantiate(self, memo):
+ memo.stop(test_constants.SHORT_TIMEOUT).wait()
+
+ def invocation_metadata(self):
+ return grpc_test_common.INVOCATION_INITIAL_METADATA
+
+ def initial_metadata(self):
+ return grpc_test_common.SERVICE_INITIAL_METADATA
+
+ def terminal_metadata(self):
+ return grpc_test_common.SERVICE_TERMINAL_METADATA
+
+ def code(self):
+ return interfaces.StatusCode.OK
+
+ def details(self):
+ return grpc_test_common.DETAILS
+
+ def metadata_transmitted(self, original_metadata, transmitted_metadata):
+ return original_metadata is None or grpc_test_common.metadata_transmitted(
+ original_metadata, transmitted_metadata)
def load_tests(loader, tests, pattern):
- return unittest.TestSuite(
- tests=tuple(
- loader.loadTestsFromTestCase(test_case_class)
- for test_case_class in test_cases.test_cases(_Implementation())))
+ return unittest.TestSuite(tests=tuple(
+ loader.loadTestsFromTestCase(test_case_class)
+ for test_case_class in test_cases.test_cases(_Implementation())))
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py b/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py
index 127f93e9bb..69bb5cc2a5 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_implementations_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests the implementations module of the gRPC Python Beta API."""
import datetime
@@ -40,31 +39,32 @@ from tests.unit import resources
class ChannelCredentialsTest(unittest.TestCase):
- def test_runtime_provided_root_certificates(self):
- channel_credentials = implementations.ssl_channel_credentials()
- self.assertIsInstance(
- channel_credentials, implementations.ChannelCredentials)
-
- def test_application_provided_root_certificates(self):
- channel_credentials = implementations.ssl_channel_credentials(
- resources.test_root_certificates())
- self.assertIsInstance(
- channel_credentials, implementations.ChannelCredentials)
+ def test_runtime_provided_root_certificates(self):
+ channel_credentials = implementations.ssl_channel_credentials()
+ self.assertIsInstance(channel_credentials,
+ implementations.ChannelCredentials)
+
+ def test_application_provided_root_certificates(self):
+ channel_credentials = implementations.ssl_channel_credentials(
+ resources.test_root_certificates())
+ self.assertIsInstance(channel_credentials,
+ implementations.ChannelCredentials)
class CallCredentialsTest(unittest.TestCase):
- def test_google_call_credentials(self):
- creds = oauth2client_client.GoogleCredentials(
- 'token', 'client_id', 'secret', 'refresh_token',
- datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/',
- 'user_agent')
- call_creds = implementations.google_call_credentials(creds)
- self.assertIsInstance(call_creds, implementations.CallCredentials)
+ def test_google_call_credentials(self):
+ creds = oauth2client_client.GoogleCredentials(
+ 'token', 'client_id', 'secret', 'refresh_token',
+ datetime.datetime(2008, 6, 24), 'https://refresh.uri.com/',
+ 'user_agent')
+ call_creds = implementations.google_call_credentials(creds)
+ self.assertIsInstance(call_creds, implementations.CallCredentials)
+
+ def test_access_token_call_credentials(self):
+ call_creds = implementations.access_token_call_credentials('token')
+ self.assertIsInstance(call_creds, implementations.CallCredentials)
- def test_access_token_call_credentials(self):
- call_creds = implementations.access_token_call_credentials('token')
- self.assertIsInstance(call_creds, implementations.CallCredentials)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py b/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py
index 37b8c49120..664e47c769 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_not_found_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of RPC-method-not-found behavior."""
import unittest
@@ -39,37 +38,38 @@ from tests.unit.framework.common import test_constants
class NotFoundTest(unittest.TestCase):
- def setUp(self):
- self._server = implementations.server({})
- port = self._server.add_insecure_port('[::]:0')
- channel = implementations.insecure_channel('localhost', port)
- self._generic_stub = implementations.generic_stub(channel)
- self._server.start()
+ def setUp(self):
+ self._server = implementations.server({})
+ port = self._server.add_insecure_port('[::]:0')
+ channel = implementations.insecure_channel('localhost', port)
+ self._generic_stub = implementations.generic_stub(channel)
+ self._server.start()
- def tearDown(self):
- self._server.stop(0).wait()
- self._generic_stub = None
+ def tearDown(self):
+ self._server.stop(0).wait()
+ self._generic_stub = None
- def test_blocking_unary_unary_not_found(self):
- with self.assertRaises(face.LocalError) as exception_assertion_context:
- self._generic_stub.blocking_unary_unary(
- 'groop', 'meffod', b'abc', test_constants.LONG_TIMEOUT,
- with_call=True)
- self.assertIs(
- exception_assertion_context.exception.code,
- interfaces.StatusCode.UNIMPLEMENTED)
+ def test_blocking_unary_unary_not_found(self):
+ with self.assertRaises(face.LocalError) as exception_assertion_context:
+ self._generic_stub.blocking_unary_unary(
+ 'groop',
+ 'meffod',
+ b'abc',
+ test_constants.LONG_TIMEOUT,
+ with_call=True)
+ self.assertIs(exception_assertion_context.exception.code,
+ interfaces.StatusCode.UNIMPLEMENTED)
- def test_future_stream_unary_not_found(self):
- rpc_future = self._generic_stub.future_stream_unary(
- 'grupe', 'mevvod', [b'def'], test_constants.LONG_TIMEOUT)
- with self.assertRaises(face.LocalError) as exception_assertion_context:
- rpc_future.result()
- self.assertIs(
- exception_assertion_context.exception.code,
- interfaces.StatusCode.UNIMPLEMENTED)
- self.assertIs(
- rpc_future.exception().code, interfaces.StatusCode.UNIMPLEMENTED)
+ def test_future_stream_unary_not_found(self):
+ rpc_future = self._generic_stub.future_stream_unary(
+ 'grupe', 'mevvod', [b'def'], test_constants.LONG_TIMEOUT)
+ with self.assertRaises(face.LocalError) as exception_assertion_context:
+ rpc_future.result()
+ self.assertIs(exception_assertion_context.exception.code,
+ interfaces.StatusCode.UNIMPLEMENTED)
+ self.assertIs(rpc_future.exception().code,
+ interfaces.StatusCode.UNIMPLEMENTED)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py b/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py
index 9cce96cc85..e8e62c322a 100644
--- a/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py
+++ b/src/python/grpcio_tests/tests/unit/beta/_utilities_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests of grpc.beta.utilities."""
import threading
@@ -41,68 +40,68 @@ from tests.unit.framework.common import test_constants
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._value = None
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._value = None
- def accept_value(self, value):
- with self._condition:
- self._value = value
- self._condition.notify_all()
+ def accept_value(self, value):
+ with self._condition:
+ self._value = value
+ self._condition.notify_all()
- def block_until_called(self):
- with self._condition:
- while self._value is None:
- self._condition.wait()
- return self._value
+ def block_until_called(self):
+ with self._condition:
+ while self._value is None:
+ self._condition.wait()
+ return self._value
class ChannelConnectivityTest(unittest.TestCase):
- def test_lonely_channel_connectivity(self):
- channel = implementations.insecure_channel('localhost', 12345)
- callback = _Callback()
-
- ready_future = utilities.channel_ready_future(channel)
- ready_future.add_done_callback(callback.accept_value)
- with self.assertRaises(future.TimeoutError):
- ready_future.result(timeout=test_constants.SHORT_TIMEOUT)
- self.assertFalse(ready_future.cancelled())
- self.assertFalse(ready_future.done())
- self.assertTrue(ready_future.running())
- ready_future.cancel()
- value_passed_to_callback = callback.block_until_called()
- self.assertIs(ready_future, value_passed_to_callback)
- self.assertTrue(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
-
- def test_immediately_connectable_channel_connectivity(self):
- server = implementations.server({})
- port = server.add_insecure_port('[::]:0')
- server.start()
- channel = implementations.insecure_channel('localhost', port)
- callback = _Callback()
-
- try:
- ready_future = utilities.channel_ready_future(channel)
- ready_future.add_done_callback(callback.accept_value)
- self.assertIsNone(
- ready_future.result(timeout=test_constants.LONG_TIMEOUT))
- value_passed_to_callback = callback.block_until_called()
- self.assertIs(ready_future, value_passed_to_callback)
- self.assertFalse(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
- # Cancellation after maturity has no effect.
- ready_future.cancel()
- self.assertFalse(ready_future.cancelled())
- self.assertTrue(ready_future.done())
- self.assertFalse(ready_future.running())
- finally:
- ready_future.cancel()
- server.stop(0)
+ def test_lonely_channel_connectivity(self):
+ channel = implementations.insecure_channel('localhost', 12345)
+ callback = _Callback()
+
+ ready_future = utilities.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ with self.assertRaises(future.TimeoutError):
+ ready_future.result(timeout=test_constants.SHORT_TIMEOUT)
+ self.assertFalse(ready_future.cancelled())
+ self.assertFalse(ready_future.done())
+ self.assertTrue(ready_future.running())
+ ready_future.cancel()
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertTrue(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+
+ def test_immediately_connectable_channel_connectivity(self):
+ server = implementations.server({})
+ port = server.add_insecure_port('[::]:0')
+ server.start()
+ channel = implementations.insecure_channel('localhost', port)
+ callback = _Callback()
+
+ try:
+ ready_future = utilities.channel_ready_future(channel)
+ ready_future.add_done_callback(callback.accept_value)
+ self.assertIsNone(
+ ready_future.result(timeout=test_constants.LONG_TIMEOUT))
+ value_passed_to_callback = callback.block_until_called()
+ self.assertIs(ready_future, value_passed_to_callback)
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+ # Cancellation after maturity has no effect.
+ ready_future.cancel()
+ self.assertFalse(ready_future.cancelled())
+ self.assertTrue(ready_future.done())
+ self.assertFalse(ready_future.running())
+ finally:
+ ready_future.cancel()
+ server.stop(0)
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
diff --git a/src/python/grpcio_tests/tests/unit/beta/test_utilities.py b/src/python/grpcio_tests/tests/unit/beta/test_utilities.py
index 692da9c97d..f542420683 100644
--- a/src/python/grpcio_tests/tests/unit/beta/test_utilities.py
+++ b/src/python/grpcio_tests/tests/unit/beta/test_utilities.py
@@ -26,16 +26,15 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test-appropriate entry points into the gRPC Python Beta API."""
import grpc
from grpc.beta import implementations
-def not_really_secure_channel(
- host, port, channel_credentials, server_host_override):
- """Creates an insecure Channel to a remote host.
+def not_really_secure_channel(host, port, channel_credentials,
+ server_host_override):
+ """Creates an insecure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
@@ -48,8 +47,8 @@ def not_really_secure_channel(
An implementations.Channel to the remote host through which RPCs may be
conducted.
"""
- target = '%s:%d' % (host, port)
- channel = grpc.secure_channel(
- target, channel_credentials,
- (('grpc.ssl_target_name_override', server_host_override,),))
- return implementations.Channel(channel)
+ target = '%s:%d' % (host, port)
+ channel = grpc.secure_channel(target, channel_credentials, ((
+ 'grpc.ssl_target_name_override',
+ server_host_override,),))
+ return implementations.Channel(channel)
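In practice the helper above lets a test dial a local server whose certificate was issued for a different hostname, by overriding the name the TLS layer checks. A hedged usage sketch (the port number is a placeholder; the imports assume the test package layout implied by the paths above):

    from grpc.beta import implementations
    from tests.unit import resources
    from tests.unit.beta import test_utilities

    channel_credentials = implementations.ssl_channel_credentials(
        resources.test_root_certificates())
    channel = test_utilities.not_really_secure_channel(
        'localhost', 50051, channel_credentials, 'foo.test.google.fr')
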
diff --git a/src/python/grpcio_tests/tests/unit/framework/__init__.py b/src/python/grpcio_tests/tests/unit/framework/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/framework/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/common/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/framework/common/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py b/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py
index b6682d396c..905483c08d 100644
--- a/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py
+++ b/src/python/grpcio_tests/tests/unit/framework/common/test_constants.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Constants shared among tests throughout RPC Framework."""
# Value for maximum duration in seconds that a test is allowed for its actual
diff --git a/src/python/grpcio_tests/tests/unit/framework/common/test_control.py b/src/python/grpcio_tests/tests/unit/framework/common/test_control.py
index 088e2f8b88..af08731b1e 100644
--- a/src/python/grpcio_tests/tests/unit/framework/common/test_control.py
+++ b/src/python/grpcio_tests/tests/unit/framework/common/test_control.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Code for instructing systems under test to block or fail."""
import abc
@@ -37,7 +36,7 @@ import six
class Defect(Exception):
- """Simulates a programming defect raised into in a system under test.
+ """Simulates a programming defect raised into in a system under test.
Use of a standard exception type is too easily misconstrued as an actual
defect in either the test infrastructure or the system under test.
@@ -45,7 +44,7 @@ class Defect(Exception):
class Control(six.with_metaclass(abc.ABCMeta)):
- """An object that accepts program control from a system under test.
+ """An object that accepts program control from a system under test.
Systems under test passed a Control should call its control() method
frequently during execution. The control() method may block, raise an
@@ -53,61 +52,61 @@ class Control(six.with_metaclass(abc.ABCMeta)):
the system under test to simulate hanging, failing, or functioning.
"""
- @abc.abstractmethod
- def control(self):
- """Potentially does anything."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def control(self):
+ """Potentially does anything."""
+ raise NotImplementedError()
class PauseFailControl(Control):
- """A Control that can be used to pause or fail code under control.
+ """A Control that can be used to pause or fail code under control.
This object is only safe for use from two threads: one of the system under
test calling control and the other from the test system calling pause,
block_until_paused, and fail.
"""
- def __init__(self):
- self._condition = threading.Condition()
- self._pause = False
- self._paused = False
- self._fail = False
-
- def control(self):
- with self._condition:
- if self._fail:
- raise Defect()
-
- while self._pause:
- self._paused = True
- self._condition.notify_all()
- self._condition.wait()
- self._paused = False
-
- @contextlib.contextmanager
- def pause(self):
- """Pauses code under control while controlling code is in context."""
- with self._condition:
- self._pause = True
- yield
- with self._condition:
- self._pause = False
- self._condition.notify_all()
-
- def block_until_paused(self):
- """Blocks controlling code until code under control is paused.
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._pause = False
+ self._paused = False
+ self._fail = False
+
+ def control(self):
+ with self._condition:
+ if self._fail:
+ raise Defect()
+
+ while self._pause:
+ self._paused = True
+ self._condition.notify_all()
+ self._condition.wait()
+ self._paused = False
+
+ @contextlib.contextmanager
+ def pause(self):
+ """Pauses code under control while controlling code is in context."""
+ with self._condition:
+ self._pause = True
+ yield
+ with self._condition:
+ self._pause = False
+ self._condition.notify_all()
+
+ def block_until_paused(self):
+ """Blocks controlling code until code under control is paused.
May only be called within the context of a pause call.
"""
- with self._condition:
- while not self._paused:
- self._condition.wait()
-
- @contextlib.contextmanager
- def fail(self):
- """Fails code under control while controlling code is in context."""
- with self._condition:
- self._fail = True
- yield
- with self._condition:
- self._fail = False
+ with self._condition:
+ while not self._paused:
+ self._condition.wait()
+
+ @contextlib.contextmanager
+ def fail(self):
+ """Fails code under control while controlling code is in context."""
+ with self._condition:
+ self._fail = True
+ yield
+ with self._condition:
+ self._fail = False
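PauseFailControl's contract is easiest to see end to end: the code under test calls control() on its own thread, while the test pauses or fails it from another. A minimal sketch, assuming only the classes defined above:

    import threading

    control = PauseFailControl()

    def system_under_test():
        control.control()  # may block (pause) or raise Defect (fail)

    with control.pause():
        worker = threading.Thread(target=system_under_test)
        worker.start()
        control.block_until_paused()  # worker is now parked inside control()
    # Leaving the pause() context releases the worker.
    worker.join()

    with control.fail():
        try:
            control.control()
        except Defect:
            pass  # the simulated failure surfaces as the Defect defined above
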
diff --git a/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py b/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py
index ea2d2812ce..13ceec31a0 100644
--- a/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py
+++ b/src/python/grpcio_tests/tests/unit/framework/common/test_coverage.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Governs coverage for tests of RPCs throughout RPC Framework."""
import abc
@@ -38,80 +37,80 @@ import six
class Coverage(six.with_metaclass(abc.ABCMeta)):
- """Specification of test coverage."""
+ """Specification of test coverage."""
- @abc.abstractmethod
- def testSuccessfulUnaryRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testSuccessfulUnaryRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testSuccessfulUnaryRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testSuccessfulUnaryRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testSuccessfulStreamRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testSuccessfulStreamRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testSuccessfulStreamRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testSuccessfulStreamRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testSequentialInvocations(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testSequentialInvocations(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testParallelInvocations(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testParallelInvocations(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testWaitingForSomeButNotAllParallelInvocations(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testWaitingForSomeButNotAllParallelInvocations(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testCancelledUnaryRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testCancelledUnaryRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testCancelledUnaryRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testCancelledUnaryRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testCancelledStreamRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testCancelledStreamRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testCancelledStreamRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testCancelledStreamRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testExpiredUnaryRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testExpiredUnaryRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testExpiredUnaryRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testExpiredUnaryRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testExpiredStreamRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testExpiredStreamRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testExpiredStreamRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testExpiredStreamRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testFailedUnaryRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testFailedUnaryRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testFailedUnaryRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testFailedUnaryRequestStreamResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testFailedStreamRequestUnaryResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testFailedStreamRequestUnaryResponse(self):
+ raise NotImplementedError()
- @abc.abstractmethod
- def testFailedStreamRequestStreamResponse(self):
- raise NotImplementedError()
+ @abc.abstractmethod
+ def testFailedStreamRequestStreamResponse(self):
+ raise NotImplementedError()
diff --git a/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py b/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/framework/foundation/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py b/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py
index 330e445d43..19e8cbdd8e 100644
--- a/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py
+++ b/src/python/grpcio_tests/tests/unit/framework/foundation/_logging_pool_test.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tests for grpc.framework.foundation.logging_pool."""
import threading
@@ -39,50 +38,51 @@ _POOL_SIZE = 16
class _CallableObject(object):
- def __init__(self):
- self._lock = threading.Lock()
- self._passed_values = []
+ def __init__(self):
+ self._lock = threading.Lock()
+ self._passed_values = []
- def __call__(self, value):
- with self._lock:
- self._passed_values.append(value)
+ def __call__(self, value):
+ with self._lock:
+ self._passed_values.append(value)
- def passed_values(self):
- with self._lock:
- return tuple(self._passed_values)
+ def passed_values(self):
+ with self._lock:
+ return tuple(self._passed_values)
class LoggingPoolTest(unittest.TestCase):
- def testUpAndDown(self):
- pool = logging_pool.pool(_POOL_SIZE)
- pool.shutdown(wait=True)
+ def testUpAndDown(self):
+ pool = logging_pool.pool(_POOL_SIZE)
+ pool.shutdown(wait=True)
- with logging_pool.pool(_POOL_SIZE) as pool:
- self.assertIsNotNone(pool)
+ with logging_pool.pool(_POOL_SIZE) as pool:
+ self.assertIsNotNone(pool)
- def testTaskExecuted(self):
- test_list = []
+ def testTaskExecuted(self):
+ test_list = []
- with logging_pool.pool(_POOL_SIZE) as pool:
- pool.submit(lambda: test_list.append(object())).result()
+ with logging_pool.pool(_POOL_SIZE) as pool:
+ pool.submit(lambda: test_list.append(object())).result()
- self.assertTrue(test_list)
+ self.assertTrue(test_list)
- def testException(self):
- with logging_pool.pool(_POOL_SIZE) as pool:
- raised_exception = pool.submit(lambda: 1/0).exception()
+ def testException(self):
+ with logging_pool.pool(_POOL_SIZE) as pool:
+ raised_exception = pool.submit(lambda: 1 / 0).exception()
- self.assertIsNotNone(raised_exception)
+ self.assertIsNotNone(raised_exception)
- def testCallableObjectExecuted(self):
- callable_object = _CallableObject()
- passed_object = object()
- with logging_pool.pool(_POOL_SIZE) as pool:
- future = pool.submit(callable_object, passed_object)
- self.assertIsNone(future.result())
- self.assertSequenceEqual((passed_object,), callable_object.passed_values())
+ def testCallableObjectExecuted(self):
+ callable_object = _CallableObject()
+ passed_object = object()
+ with logging_pool.pool(_POOL_SIZE) as pool:
+ future = pool.submit(callable_object, passed_object)
+ self.assertIsNone(future.result())
+ self.assertSequenceEqual((passed_object,),
+ callable_object.passed_values())
if __name__ == '__main__':
- unittest.main(verbosity=2)
+ unittest.main(verbosity=2)
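For readers skimming the diff, the tests above pin down the logging_pool surface they rely on: pool(size) is a futures-style executor and context manager whose submit()/result()/exception()/shutdown() behave as in concurrent.futures, with worker exceptions logged. A minimal sketch under exactly those assumptions (the module path comes from the docstring "Tests for grpc.framework.foundation.logging_pool."):

    from grpc.framework.foundation import logging_pool

    # pool(size) returns an executor whose workers log raised exceptions;
    # the context manager shuts the pool down on exit, as in testUpAndDown.
    with logging_pool.pool(4) as pool:
        assert pool.submit(lambda: 6 * 7).result() == 42
        # Exceptions raised in submitted callables surface via .exception(),
        # mirroring the 1 / 0 case in testException above.
        assert pool.submit(lambda: 1 / 0).exception() is not None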
diff --git a/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py b/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py
index 098a53d5e7..2929e4dd78 100644
--- a/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py
+++ b/src/python/grpcio_tests/tests/unit/framework/foundation/stream_testing.py
@@ -26,48 +26,47 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Utilities for testing stream-related code."""
from grpc.framework.foundation import stream
class TestConsumer(stream.Consumer):
- """A stream.Consumer instrumented for testing.
+ """A stream.Consumer instrumented for testing.
Attributes:
calls: A sequence of value-termination pairs describing the history of calls
made on this object.
"""
- def __init__(self):
- self.calls = []
+ def __init__(self):
+ self.calls = []
- def consume(self, value):
- """See stream.Consumer.consume for specification."""
- self.calls.append((value, False))
+ def consume(self, value):
+ """See stream.Consumer.consume for specification."""
+ self.calls.append((value, False))
- def terminate(self):
- """See stream.Consumer.terminate for specification."""
- self.calls.append((None, True))
+ def terminate(self):
+ """See stream.Consumer.terminate for specification."""
+ self.calls.append((None, True))
- def consume_and_terminate(self, value):
- """See stream.Consumer.consume_and_terminate for specification."""
- self.calls.append((value, True))
+ def consume_and_terminate(self, value):
+ """See stream.Consumer.consume_and_terminate for specification."""
+ self.calls.append((value, True))
- def is_legal(self):
- """Reports whether or not a legal sequence of calls has been made."""
- terminated = False
- for value, terminal in self.calls:
- if terminated:
- return False
- elif terminal:
- terminated = True
- elif value is None:
- return False
- else: # pylint: disable=useless-else-on-loop
- return True
+ def is_legal(self):
+ """Reports whether or not a legal sequence of calls has been made."""
+ terminated = False
+ for value, terminal in self.calls:
+ if terminated:
+ return False
+ elif terminal:
+ terminated = True
+ elif value is None:
+ return False
+ else: # pylint: disable=useless-else-on-loop
+ return True
- def values(self):
- """Returns the sequence of values that have been passed to this Consumer."""
- return [value for value, _ in self.calls if value]
+ def values(self):
+ """Returns the sequence of values that have been passed to this Consumer."""
+ return [value for value, _ in self.calls if value]
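TestConsumer simply records each call as a (value, terminal) pair, so its behaviour can be shown in a short sketch; the import path below follows the file's location in the test tree and is an assumption, not something the diff states:

    from tests.unit.framework.foundation import stream_testing

    consumer = stream_testing.TestConsumer()
    consumer.consume('a')                # recorded as ('a', False)
    consumer.consume_and_terminate('b')  # recorded as ('b', True)
    assert consumer.values() == ['a', 'b']
    assert consumer.is_legal()           # the recorded sequence is legal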
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/__init__.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_3069_test_constant.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_3069_test_constant.py
index 1ea356c0bf..2aec25c9ef 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_3069_test_constant.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_3069_test_constant.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""A test constant working around issue 3069."""
# test_constants is referenced from specification in this module.
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/__init__.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/__init__.py
index 7086519106..b89398809f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/__init__.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/__init__.py
@@ -26,5 +26,3 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py
index e338aaa396..a79834f96f 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_blocking_invocation_inline_service.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test code for the Face layer of RPC Framework."""
from __future__ import division
@@ -50,246 +49,254 @@ from tests.unit.framework.interfaces.face import _stock_service
from tests.unit.framework.interfaces.face import test_interfaces # pylint: disable=unused-import
-class TestCase(six.with_metaclass(abc.ABCMeta, test_coverage.Coverage, unittest.TestCase)):
- """A test of the Face layer of RPC Framework.
+class TestCase(
+ six.with_metaclass(abc.ABCMeta, test_coverage.Coverage,
+ unittest.TestCase)):
+ """A test of the Face layer of RPC Framework.
Concrete subclasses must have an "implementation" attribute of type
test_interfaces.Implementation and an "invoker_constructor" attribute of type
_invocation.InvokerConstructor.
"""
- NAME = 'BlockingInvocationInlineServiceTest'
+ NAME = 'BlockingInvocationInlineServiceTest'
- def setUp(self):
- """See unittest.TestCase.setUp for full specification.
+ def setUp(self):
+ """See unittest.TestCase.setUp for full specification.
Overriding implementations must call this implementation.
"""
- self._control = test_control.PauseFailControl()
- self._digest = _digest.digest(
- _stock_service.STOCK_TEST_SERVICE, self._control, None)
+ self._control = test_control.PauseFailControl()
+ self._digest = _digest.digest(_stock_service.STOCK_TEST_SERVICE,
+ self._control, None)
- generic_stub, dynamic_stubs, self._memo = self.implementation.instantiate(
- self._digest.methods, self._digest.inline_method_implementations, None)
- self._invoker = self.invoker_constructor.construct_invoker(
- generic_stub, dynamic_stubs, self._digest.methods)
+ generic_stub, dynamic_stubs, self._memo = self.implementation.instantiate(
+ self._digest.methods, self._digest.inline_method_implementations,
+ None)
+ self._invoker = self.invoker_constructor.construct_invoker(
+ generic_stub, dynamic_stubs, self._digest.methods)
- def tearDown(self):
- """See unittest.TestCase.tearDown for full specification.
+ def tearDown(self):
+ """See unittest.TestCase.tearDown for full specification.
Overriding implementations must call this implementation.
"""
- self._invoker = None
- self.implementation.destantiate(self._memo)
-
- def testSuccessfulUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- response, call = self._invoker.blocking(group, method)(
- request, test_constants.LONG_TIMEOUT, with_call=True)
-
- test_messages.verify(request, response, self)
-
- def testSuccessfulUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- response_iterator = self._invoker.blocking(group, method)(
- request, test_constants.LONG_TIMEOUT)
- responses = list(response_iterator)
-
- test_messages.verify(request, responses, self)
-
- def testSuccessfulStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- response, call = self._invoker.blocking(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT, with_call=True)
-
- test_messages.verify(requests, response, self)
-
- def testSuccessfulStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- response_iterator = self._invoker.blocking(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT)
- responses = list(response_iterator)
-
- test_messages.verify(requests, responses, self)
-
- def testSequentialInvocations(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- first_request = test_messages.request()
- second_request = test_messages.request()
-
- first_response = self._invoker.blocking(group, method)(
- first_request, test_constants.LONG_TIMEOUT)
-
- test_messages.verify(first_request, first_response, self)
-
- second_response = self._invoker.blocking(group, method)(
- second_request, test_constants.LONG_TIMEOUT)
-
- test_messages.verify(second_request, second_response, self)
-
- def testParallelInvocations(self):
- pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = []
- response_futures = []
- for _ in range(test_constants.THREAD_CONCURRENCY):
- request = test_messages.request()
- response_future = pool.submit(
- self._invoker.blocking(group, method), request,
- test_constants.LONG_TIMEOUT)
- requests.append(request)
- response_futures.append(response_future)
-
- responses = [
- response_future.result() for response_future in response_futures]
-
- for request, response in zip(requests, responses):
- test_messages.verify(request, response, self)
- pool.shutdown(wait=True)
-
- def testWaitingForSomeButNotAllParallelInvocations(self):
- pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = []
- response_futures_to_indices = {}
- for index in range(test_constants.THREAD_CONCURRENCY):
- request = test_messages.request()
- response_future = pool.submit(
- self._invoker.blocking(group, method), request,
- test_constants.LONG_TIMEOUT)
- requests.append(request)
- response_futures_to_indices[response_future] = index
-
- some_completed_response_futures_iterator = itertools.islice(
- futures.as_completed(response_futures_to_indices),
- test_constants.THREAD_CONCURRENCY // 2)
- for response_future in some_completed_response_futures_iterator:
- index = response_futures_to_indices[response_future]
- test_messages.verify(requests[index], response_future.result(), self)
- pool.shutdown(wait=True)
-
- @unittest.skip('Cancellation impossible with blocking control flow!')
- def testCancelledUnaryRequestUnaryResponse(self):
- raise NotImplementedError()
-
- @unittest.skip('Cancellation impossible with blocking control flow!')
- def testCancelledUnaryRequestStreamResponse(self):
- raise NotImplementedError()
-
- @unittest.skip('Cancellation impossible with blocking control flow!')
- def testCancelledStreamRequestUnaryResponse(self):
- raise NotImplementedError()
-
- @unittest.skip('Cancellation impossible with blocking control flow!')
- def testCancelledStreamRequestStreamResponse(self):
- raise NotImplementedError()
-
- def testExpiredUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.pause(), self.assertRaises(
- face.ExpirationError):
- self._invoker.blocking(group, method)(
- request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
-
- def testExpiredUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.pause(), self.assertRaises(
- face.ExpirationError):
- response_iterator = self._invoker.blocking(group, method)(
- request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
- list(response_iterator)
-
- def testExpiredStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.pause(), self.assertRaises(
- face.ExpirationError):
- self._invoker.blocking(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
-
- def testExpiredStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.pause(), self.assertRaises(
- face.ExpirationError):
- response_iterator = self._invoker.blocking(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
- list(response_iterator)
-
- def testFailedUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.fail(), self.assertRaises(face.RemoteError):
- self._invoker.blocking(group, method)(
- request, test_constants.LONG_TIMEOUT)
-
- def testFailedUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.fail(), self.assertRaises(face.RemoteError):
- response_iterator = self._invoker.blocking(group, method)(
- request, test_constants.LONG_TIMEOUT)
- list(response_iterator)
-
- def testFailedStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.fail(), self.assertRaises(face.RemoteError):
- self._invoker.blocking(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT)
-
- def testFailedStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.fail(), self.assertRaises(face.RemoteError):
- response_iterator = self._invoker.blocking(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT)
- list(response_iterator)
+ self._invoker = None
+ self.implementation.destantiate(self._memo)
+
+ def testSuccessfulUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ response, call = self._invoker.blocking(group, method)(
+ request, test_constants.LONG_TIMEOUT, with_call=True)
+
+ test_messages.verify(request, response, self)
+
+ def testSuccessfulUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ response_iterator = self._invoker.blocking(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ responses = list(response_iterator)
+
+ test_messages.verify(request, responses, self)
+
+ def testSuccessfulStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ response, call = self._invoker.blocking(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT, with_call=True)
+
+ test_messages.verify(requests, response, self)
+
+ def testSuccessfulStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ response_iterator = self._invoker.blocking(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT)
+ responses = list(response_iterator)
+
+ test_messages.verify(requests, responses, self)
+
+ def testSequentialInvocations(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ first_request = test_messages.request()
+ second_request = test_messages.request()
+
+ first_response = self._invoker.blocking(group, method)(
+ first_request, test_constants.LONG_TIMEOUT)
+
+ test_messages.verify(first_request, first_response, self)
+
+ second_response = self._invoker.blocking(group, method)(
+ second_request, test_constants.LONG_TIMEOUT)
+
+ test_messages.verify(second_request, second_response, self)
+
+ def testParallelInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = []
+ response_futures = []
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ request = test_messages.request()
+ response_future = pool.submit(
+ self._invoker.blocking(group, method), request,
+ test_constants.LONG_TIMEOUT)
+ requests.append(request)
+ response_futures.append(response_future)
+
+ responses = [
+ response_future.result()
+ for response_future in response_futures
+ ]
+
+ for request, response in zip(requests, responses):
+ test_messages.verify(request, response, self)
+ pool.shutdown(wait=True)
+
+ def testWaitingForSomeButNotAllParallelInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = []
+ response_futures_to_indices = {}
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request = test_messages.request()
+ response_future = pool.submit(
+ self._invoker.blocking(group, method), request,
+ test_constants.LONG_TIMEOUT)
+ requests.append(request)
+ response_futures_to_indices[response_future] = index
+
+ some_completed_response_futures_iterator = itertools.islice(
+ futures.as_completed(response_futures_to_indices),
+ test_constants.THREAD_CONCURRENCY // 2)
+ for response_future in some_completed_response_futures_iterator:
+ index = response_futures_to_indices[response_future]
+ test_messages.verify(requests[index],
+ response_future.result(), self)
+ pool.shutdown(wait=True)
+
+ @unittest.skip('Cancellation impossible with blocking control flow!')
+ def testCancelledUnaryRequestUnaryResponse(self):
+ raise NotImplementedError()
+
+ @unittest.skip('Cancellation impossible with blocking control flow!')
+ def testCancelledUnaryRequestStreamResponse(self):
+ raise NotImplementedError()
+
+ @unittest.skip('Cancellation impossible with blocking control flow!')
+ def testCancelledStreamRequestUnaryResponse(self):
+ raise NotImplementedError()
+
+ @unittest.skip('Cancellation impossible with blocking control flow!')
+ def testCancelledStreamRequestStreamResponse(self):
+ raise NotImplementedError()
+
+ def testExpiredUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.pause(), self.assertRaises(
+ face.ExpirationError):
+ self._invoker.blocking(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+
+ def testExpiredUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.pause(), self.assertRaises(
+ face.ExpirationError):
+ response_iterator = self._invoker.blocking(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ list(response_iterator)
+
+ def testExpiredStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.pause(), self.assertRaises(
+ face.ExpirationError):
+ self._invoker.blocking(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+
+ def testExpiredStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.pause(), self.assertRaises(
+ face.ExpirationError):
+ response_iterator = self._invoker.blocking(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ list(response_iterator)
+
+ def testFailedUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.fail(), self.assertRaises(face.RemoteError):
+ self._invoker.blocking(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+
+ def testFailedUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.fail(), self.assertRaises(face.RemoteError):
+ response_iterator = self._invoker.blocking(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ list(response_iterator)
+
+ def testFailedStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.fail(), self.assertRaises(face.RemoteError):
+ self._invoker.blocking(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT)
+
+ def testFailedStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.fail(), self.assertRaises(face.RemoteError):
+ response_iterator = self._invoker.blocking(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT)
+ list(response_iterator)
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py
index f0befb0b27..0411da0a66 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_digest.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Code for making a service.TestService more amenable to use in tests."""
import collections
@@ -49,17 +48,16 @@ _IDENTITY = lambda x: x
class TestServiceDigest(
- collections.namedtuple(
- 'TestServiceDigest',
- ('methods',
- 'inline_method_implementations',
- 'event_method_implementations',
- 'multi_method_implementation',
- 'unary_unary_messages_sequences',
- 'unary_stream_messages_sequences',
- 'stream_unary_messages_sequences',
- 'stream_stream_messages_sequences',))):
- """A transformation of a service.TestService.
+ collections.namedtuple('TestServiceDigest', (
+ 'methods',
+ 'inline_method_implementations',
+ 'event_method_implementations',
+ 'multi_method_implementation',
+ 'unary_unary_messages_sequences',
+ 'unary_stream_messages_sequences',
+ 'stream_unary_messages_sequences',
+ 'stream_stream_messages_sequences',))):
+ """A transformation of a service.TestService.
Attributes:
methods: A dict from method group-name pair to test_interfaces.Method object
@@ -88,303 +86,308 @@ class TestServiceDigest(
class _BufferingConsumer(stream.Consumer):
- """A trivial Consumer that dumps what it consumes in a user-mutable buffer."""
+ """A trivial Consumer that dumps what it consumes in a user-mutable buffer."""
- def __init__(self):
- self.consumed = []
- self.terminated = False
+ def __init__(self):
+ self.consumed = []
+ self.terminated = False
- def consume(self, value):
- self.consumed.append(value)
+ def consume(self, value):
+ self.consumed.append(value)
- def terminate(self):
- self.terminated = True
+ def terminate(self):
+ self.terminated = True
- def consume_and_terminate(self, value):
- self.consumed.append(value)
- self.terminated = True
+ def consume_and_terminate(self, value):
+ self.consumed.append(value)
+ self.terminated = True
class _InlineUnaryUnaryMethod(face.MethodImplementation):
- def __init__(self, unary_unary_test_method, control):
- self._test_method = unary_unary_test_method
- self._control = control
+ def __init__(self, unary_unary_test_method, control):
+ self._test_method = unary_unary_test_method
+ self._control = control
- self.cardinality = cardinality.Cardinality.UNARY_UNARY
- self.style = style.Service.INLINE
+ self.cardinality = cardinality.Cardinality.UNARY_UNARY
+ self.style = style.Service.INLINE
- def unary_unary_inline(self, request, context):
- response_list = []
- self._test_method.service(
- request, response_list.append, context, self._control)
- return response_list.pop(0)
+ def unary_unary_inline(self, request, context):
+ response_list = []
+ self._test_method.service(request, response_list.append, context,
+ self._control)
+ return response_list.pop(0)
class _EventUnaryUnaryMethod(face.MethodImplementation):
- def __init__(self, unary_unary_test_method, control, pool):
- self._test_method = unary_unary_test_method
- self._control = control
- self._pool = pool
+ def __init__(self, unary_unary_test_method, control, pool):
+ self._test_method = unary_unary_test_method
+ self._control = control
+ self._pool = pool
- self.cardinality = cardinality.Cardinality.UNARY_UNARY
- self.style = style.Service.EVENT
+ self.cardinality = cardinality.Cardinality.UNARY_UNARY
+ self.style = style.Service.EVENT
- def unary_unary_event(self, request, response_callback, context):
- if self._pool is None:
- self._test_method.service(
- request, response_callback, context, self._control)
- else:
- self._pool.submit(
- self._test_method.service, request, response_callback, context,
- self._control)
+ def unary_unary_event(self, request, response_callback, context):
+ if self._pool is None:
+ self._test_method.service(request, response_callback, context,
+ self._control)
+ else:
+ self._pool.submit(self._test_method.service, request,
+ response_callback, context, self._control)
class _InlineUnaryStreamMethod(face.MethodImplementation):
- def __init__(self, unary_stream_test_method, control):
- self._test_method = unary_stream_test_method
- self._control = control
+ def __init__(self, unary_stream_test_method, control):
+ self._test_method = unary_stream_test_method
+ self._control = control
- self.cardinality = cardinality.Cardinality.UNARY_STREAM
- self.style = style.Service.INLINE
+ self.cardinality = cardinality.Cardinality.UNARY_STREAM
+ self.style = style.Service.INLINE
- def unary_stream_inline(self, request, context):
- response_consumer = _BufferingConsumer()
- self._test_method.service(
- request, response_consumer, context, self._control)
- for response in response_consumer.consumed:
- yield response
+ def unary_stream_inline(self, request, context):
+ response_consumer = _BufferingConsumer()
+ self._test_method.service(request, response_consumer, context,
+ self._control)
+ for response in response_consumer.consumed:
+ yield response
class _EventUnaryStreamMethod(face.MethodImplementation):
- def __init__(self, unary_stream_test_method, control, pool):
- self._test_method = unary_stream_test_method
- self._control = control
- self._pool = pool
+ def __init__(self, unary_stream_test_method, control, pool):
+ self._test_method = unary_stream_test_method
+ self._control = control
+ self._pool = pool
- self.cardinality = cardinality.Cardinality.UNARY_STREAM
- self.style = style.Service.EVENT
+ self.cardinality = cardinality.Cardinality.UNARY_STREAM
+ self.style = style.Service.EVENT
- def unary_stream_event(self, request, response_consumer, context):
- if self._pool is None:
- self._test_method.service(
- request, response_consumer, context, self._control)
- else:
- self._pool.submit(
- self._test_method.service, request, response_consumer, context,
- self._control)
+ def unary_stream_event(self, request, response_consumer, context):
+ if self._pool is None:
+ self._test_method.service(request, response_consumer, context,
+ self._control)
+ else:
+ self._pool.submit(self._test_method.service, request,
+ response_consumer, context, self._control)
class _InlineStreamUnaryMethod(face.MethodImplementation):
- def __init__(self, stream_unary_test_method, control):
- self._test_method = stream_unary_test_method
- self._control = control
+ def __init__(self, stream_unary_test_method, control):
+ self._test_method = stream_unary_test_method
+ self._control = control
- self.cardinality = cardinality.Cardinality.STREAM_UNARY
- self.style = style.Service.INLINE
+ self.cardinality = cardinality.Cardinality.STREAM_UNARY
+ self.style = style.Service.INLINE
- def stream_unary_inline(self, request_iterator, context):
- response_list = []
- request_consumer = self._test_method.service(
- response_list.append, context, self._control)
- for request in request_iterator:
- request_consumer.consume(request)
- request_consumer.terminate()
- return response_list.pop(0)
+ def stream_unary_inline(self, request_iterator, context):
+ response_list = []
+ request_consumer = self._test_method.service(response_list.append,
+ context, self._control)
+ for request in request_iterator:
+ request_consumer.consume(request)
+ request_consumer.terminate()
+ return response_list.pop(0)
class _EventStreamUnaryMethod(face.MethodImplementation):
- def __init__(self, stream_unary_test_method, control, pool):
- self._test_method = stream_unary_test_method
- self._control = control
- self._pool = pool
+ def __init__(self, stream_unary_test_method, control, pool):
+ self._test_method = stream_unary_test_method
+ self._control = control
+ self._pool = pool
- self.cardinality = cardinality.Cardinality.STREAM_UNARY
- self.style = style.Service.EVENT
+ self.cardinality = cardinality.Cardinality.STREAM_UNARY
+ self.style = style.Service.EVENT
- def stream_unary_event(self, response_callback, context):
- request_consumer = self._test_method.service(
- response_callback, context, self._control)
- if self._pool is None:
- return request_consumer
- else:
- return stream_util.ThreadSwitchingConsumer(request_consumer, self._pool)
+ def stream_unary_event(self, response_callback, context):
+ request_consumer = self._test_method.service(response_callback, context,
+ self._control)
+ if self._pool is None:
+ return request_consumer
+ else:
+ return stream_util.ThreadSwitchingConsumer(request_consumer,
+ self._pool)
class _InlineStreamStreamMethod(face.MethodImplementation):
- def __init__(self, stream_stream_test_method, control):
- self._test_method = stream_stream_test_method
- self._control = control
+ def __init__(self, stream_stream_test_method, control):
+ self._test_method = stream_stream_test_method
+ self._control = control
- self.cardinality = cardinality.Cardinality.STREAM_STREAM
- self.style = style.Service.INLINE
+ self.cardinality = cardinality.Cardinality.STREAM_STREAM
+ self.style = style.Service.INLINE
- def stream_stream_inline(self, request_iterator, context):
- response_consumer = _BufferingConsumer()
- request_consumer = self._test_method.service(
- response_consumer, context, self._control)
+ def stream_stream_inline(self, request_iterator, context):
+ response_consumer = _BufferingConsumer()
+ request_consumer = self._test_method.service(response_consumer, context,
+ self._control)
- for request in request_iterator:
- request_consumer.consume(request)
- while response_consumer.consumed:
- yield response_consumer.consumed.pop(0)
- response_consumer.terminate()
+ for request in request_iterator:
+ request_consumer.consume(request)
+ while response_consumer.consumed:
+ yield response_consumer.consumed.pop(0)
+ response_consumer.terminate()
class _EventStreamStreamMethod(face.MethodImplementation):
- def __init__(self, stream_stream_test_method, control, pool):
- self._test_method = stream_stream_test_method
- self._control = control
- self._pool = pool
+ def __init__(self, stream_stream_test_method, control, pool):
+ self._test_method = stream_stream_test_method
+ self._control = control
+ self._pool = pool
- self.cardinality = cardinality.Cardinality.STREAM_STREAM
- self.style = style.Service.EVENT
+ self.cardinality = cardinality.Cardinality.STREAM_STREAM
+ self.style = style.Service.EVENT
- def stream_stream_event(self, response_consumer, context):
- request_consumer = self._test_method.service(
- response_consumer, context, self._control)
- if self._pool is None:
- return request_consumer
- else:
- return stream_util.ThreadSwitchingConsumer(request_consumer, self._pool)
+ def stream_stream_event(self, response_consumer, context):
+ request_consumer = self._test_method.service(response_consumer, context,
+ self._control)
+ if self._pool is None:
+ return request_consumer
+ else:
+ return stream_util.ThreadSwitchingConsumer(request_consumer,
+ self._pool)
class _UnaryConsumer(stream.Consumer):
- """A Consumer that only allows consumption of exactly one value."""
-
- def __init__(self, action):
- self._lock = threading.Lock()
- self._action = action
- self._consumed = False
- self._terminated = False
-
- def consume(self, value):
- with self._lock:
- if self._consumed:
- raise ValueError('Unary consumer already consumed!')
- elif self._terminated:
- raise ValueError('Unary consumer already terminated!')
- else:
- self._consumed = True
-
- self._action(value)
-
- def terminate(self):
- with self._lock:
- if not self._consumed:
- raise ValueError('Unary consumer hasn\'t yet consumed!')
- elif self._terminated:
- raise ValueError('Unary consumer already terminated!')
- else:
- self._terminated = True
-
- def consume_and_terminate(self, value):
- with self._lock:
- if self._consumed:
- raise ValueError('Unary consumer already consumed!')
- elif self._terminated:
- raise ValueError('Unary consumer already terminated!')
- else:
- self._consumed = True
- self._terminated = True
-
- self._action(value)
+ """A Consumer that only allows consumption of exactly one value."""
+
+ def __init__(self, action):
+ self._lock = threading.Lock()
+ self._action = action
+ self._consumed = False
+ self._terminated = False
+
+ def consume(self, value):
+ with self._lock:
+ if self._consumed:
+ raise ValueError('Unary consumer already consumed!')
+ elif self._terminated:
+ raise ValueError('Unary consumer already terminated!')
+ else:
+ self._consumed = True
+
+ self._action(value)
+
+ def terminate(self):
+ with self._lock:
+ if not self._consumed:
+ raise ValueError('Unary consumer hasn\'t yet consumed!')
+ elif self._terminated:
+ raise ValueError('Unary consumer already terminated!')
+ else:
+ self._terminated = True
+
+ def consume_and_terminate(self, value):
+ with self._lock:
+ if self._consumed:
+ raise ValueError('Unary consumer already consumed!')
+ elif self._terminated:
+ raise ValueError('Unary consumer already terminated!')
+ else:
+ self._consumed = True
+ self._terminated = True
+
+ self._action(value)
class _UnaryUnaryAdaptation(object):
- def __init__(self, unary_unary_test_method):
- self._method = unary_unary_test_method
+ def __init__(self, unary_unary_test_method):
+ self._method = unary_unary_test_method
+
+ def service(self, response_consumer, context, control):
+
+ def action(request):
+ self._method.service(request,
+ response_consumer.consume_and_terminate,
+ context, control)
- def service(self, response_consumer, context, control):
- def action(request):
- self._method.service(
- request, response_consumer.consume_and_terminate, context, control)
- return _UnaryConsumer(action)
+ return _UnaryConsumer(action)
class _UnaryStreamAdaptation(object):
- def __init__(self, unary_stream_test_method):
- self._method = unary_stream_test_method
+ def __init__(self, unary_stream_test_method):
+ self._method = unary_stream_test_method
+
+ def service(self, response_consumer, context, control):
+
+ def action(request):
+ self._method.service(request, response_consumer, context, control)
- def service(self, response_consumer, context, control):
- def action(request):
- self._method.service(request, response_consumer, context, control)
- return _UnaryConsumer(action)
+ return _UnaryConsumer(action)
class _StreamUnaryAdaptation(object):
- def __init__(self, stream_unary_test_method):
- self._method = stream_unary_test_method
+ def __init__(self, stream_unary_test_method):
+ self._method = stream_unary_test_method
- def service(self, response_consumer, context, control):
- return self._method.service(
- response_consumer.consume_and_terminate, context, control)
+ def service(self, response_consumer, context, control):
+ return self._method.service(response_consumer.consume_and_terminate,
+ context, control)
class _MultiMethodImplementation(face.MultiMethodImplementation):
- def __init__(self, methods, control, pool):
- self._methods = methods
- self._control = control
- self._pool = pool
+ def __init__(self, methods, control, pool):
+ self._methods = methods
+ self._control = control
+ self._pool = pool
- def service(self, group, name, response_consumer, context):
- method = self._methods.get(group, name, None)
- if method is None:
- raise face.NoSuchMethodError(group, name)
- elif self._pool is None:
- return method(response_consumer, context, self._control)
- else:
- request_consumer = method(response_consumer, context, self._control)
- return stream_util.ThreadSwitchingConsumer(request_consumer, self._pool)
+ def service(self, group, name, response_consumer, context):
+ method = self._methods.get(group, name, None)
+ if method is None:
+ raise face.NoSuchMethodError(group, name)
+ elif self._pool is None:
+ return method(response_consumer, context, self._control)
+ else:
+ request_consumer = method(response_consumer, context, self._control)
+ return stream_util.ThreadSwitchingConsumer(request_consumer,
+ self._pool)
class _Assembly(
- collections.namedtuple(
- '_Assembly',
- ['methods', 'inlines', 'events', 'adaptations', 'messages'])):
- """An intermediate structure created when creating a TestServiceDigest."""
-
-
-def _assemble(
- scenarios, identifiers, inline_method_constructor, event_method_constructor,
- adapter, control, pool):
- """Creates an _Assembly from the given scenarios."""
- methods = {}
- inlines = {}
- events = {}
- adaptations = {}
- messages = {}
- for identifier, scenario in six.iteritems(scenarios):
- if identifier in identifiers:
- raise ValueError('Repeated identifier "(%s, %s)"!' % identifier)
-
- test_method = scenario[0]
- inline_method = inline_method_constructor(test_method, control)
- event_method = event_method_constructor(test_method, control, pool)
- adaptation = adapter(test_method)
-
- methods[identifier] = test_method
- inlines[identifier] = inline_method
- events[identifier] = event_method
- adaptations[identifier] = adaptation
- messages[identifier] = scenario[1]
-
- return _Assembly(methods, inlines, events, adaptations, messages)
+ collections.namedtuple(
+ '_Assembly',
+ ['methods', 'inlines', 'events', 'adaptations', 'messages'])):
+ """An intermediate structure created when creating a TestServiceDigest."""
+
+
+def _assemble(scenarios, identifiers, inline_method_constructor,
+ event_method_constructor, adapter, control, pool):
+ """Creates an _Assembly from the given scenarios."""
+ methods = {}
+ inlines = {}
+ events = {}
+ adaptations = {}
+ messages = {}
+ for identifier, scenario in six.iteritems(scenarios):
+ if identifier in identifiers:
+ raise ValueError('Repeated identifier "(%s, %s)"!' % identifier)
+
+ test_method = scenario[0]
+ inline_method = inline_method_constructor(test_method, control)
+ event_method = event_method_constructor(test_method, control, pool)
+ adaptation = adapter(test_method)
+
+ methods[identifier] = test_method
+ inlines[identifier] = inline_method
+ events[identifier] = event_method
+ adaptations[identifier] = adaptation
+ messages[identifier] = scenario[1]
+
+ return _Assembly(methods, inlines, events, adaptations, messages)
def digest(service, control, pool):
- """Creates a TestServiceDigest from a TestService.
+ """Creates a TestServiceDigest from a TestService.
Args:
service: A _service.TestService.
@@ -396,51 +399,48 @@ def digest(service, control, pool):
Returns:
A TestServiceDigest synthesized from the given service.TestService.
"""
- identifiers = set()
-
- unary_unary = _assemble(
- service.unary_unary_scenarios(), identifiers, _InlineUnaryUnaryMethod,
- _EventUnaryUnaryMethod, _UnaryUnaryAdaptation, control, pool)
- identifiers.update(unary_unary.inlines)
-
- unary_stream = _assemble(
- service.unary_stream_scenarios(), identifiers, _InlineUnaryStreamMethod,
- _EventUnaryStreamMethod, _UnaryStreamAdaptation, control, pool)
- identifiers.update(unary_stream.inlines)
-
- stream_unary = _assemble(
- service.stream_unary_scenarios(), identifiers, _InlineStreamUnaryMethod,
- _EventStreamUnaryMethod, _StreamUnaryAdaptation, control, pool)
- identifiers.update(stream_unary.inlines)
-
- stream_stream = _assemble(
- service.stream_stream_scenarios(), identifiers, _InlineStreamStreamMethod,
- _EventStreamStreamMethod, _IDENTITY, control, pool)
- identifiers.update(stream_stream.inlines)
-
- methods = dict(unary_unary.methods)
- methods.update(unary_stream.methods)
- methods.update(stream_unary.methods)
- methods.update(stream_stream.methods)
- adaptations = dict(unary_unary.adaptations)
- adaptations.update(unary_stream.adaptations)
- adaptations.update(stream_unary.adaptations)
- adaptations.update(stream_stream.adaptations)
- inlines = dict(unary_unary.inlines)
- inlines.update(unary_stream.inlines)
- inlines.update(stream_unary.inlines)
- inlines.update(stream_stream.inlines)
- events = dict(unary_unary.events)
- events.update(unary_stream.events)
- events.update(stream_unary.events)
- events.update(stream_stream.events)
-
- return TestServiceDigest(
- methods,
- inlines,
- events,
- _MultiMethodImplementation(adaptations, control, pool),
- unary_unary.messages,
- unary_stream.messages,
- stream_unary.messages,
- stream_stream.messages)
+ identifiers = set()
+
+ unary_unary = _assemble(service.unary_unary_scenarios(), identifiers,
+ _InlineUnaryUnaryMethod, _EventUnaryUnaryMethod,
+ _UnaryUnaryAdaptation, control, pool)
+ identifiers.update(unary_unary.inlines)
+
+ unary_stream = _assemble(service.unary_stream_scenarios(), identifiers,
+ _InlineUnaryStreamMethod, _EventUnaryStreamMethod,
+ _UnaryStreamAdaptation, control, pool)
+ identifiers.update(unary_stream.inlines)
+
+ stream_unary = _assemble(service.stream_unary_scenarios(), identifiers,
+ _InlineStreamUnaryMethod, _EventStreamUnaryMethod,
+ _StreamUnaryAdaptation, control, pool)
+ identifiers.update(stream_unary.inlines)
+
+ stream_stream = _assemble(service.stream_stream_scenarios(), identifiers,
+ _InlineStreamStreamMethod,
+ _EventStreamStreamMethod, _IDENTITY, control,
+ pool)
+ identifiers.update(stream_stream.inlines)
+
+ methods = dict(unary_unary.methods)
+ methods.update(unary_stream.methods)
+ methods.update(stream_unary.methods)
+ methods.update(stream_stream.methods)
+ adaptations = dict(unary_unary.adaptations)
+ adaptations.update(unary_stream.adaptations)
+ adaptations.update(stream_unary.adaptations)
+ adaptations.update(stream_stream.adaptations)
+ inlines = dict(unary_unary.inlines)
+ inlines.update(unary_stream.inlines)
+ inlines.update(stream_unary.inlines)
+ inlines.update(stream_stream.inlines)
+ events = dict(unary_unary.events)
+ events.update(unary_stream.events)
+ events.update(stream_unary.events)
+ events.update(stream_stream.events)
+
+ return TestServiceDigest(
+ methods, inlines, events,
+ _MultiMethodImplementation(adaptations, control, pool),
+ unary_unary.messages, unary_stream.messages, stream_unary.messages,
+ stream_stream.messages)
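digest() is only ever called from the Face-layer test fixtures above; a minimal sketch of that call, mirroring the setUp() bodies in this commit (the test_control import path is an assumption based on the wider test tree):

    from tests.unit.framework.common import test_control
    from tests.unit.framework.interfaces.face import _digest
    from tests.unit.framework.interfaces.face import _stock_service

    control = test_control.PauseFailControl()
    # With pool=None the event-style methods shown above run on the calling thread.
    service_digest = _digest.digest(_stock_service.STOCK_TEST_SERVICE, control, None)
    # service_digest.methods maps (group, method) pairs to test_interfaces.Method
    # objects, as described in the TestServiceDigest docstring.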
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py
index df620b19ba..703eef3a82 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_future_invocation_asynchronous_event_service.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Test code for the Face layer of RPC Framework."""
from __future__ import division
@@ -55,457 +54,470 @@ from tests.unit.framework.interfaces.face import test_interfaces # pylint: disa
class _PauseableIterator(object):
- def __init__(self, upstream):
- self._upstream = upstream
- self._condition = threading.Condition()
- self._paused = False
+ def __init__(self, upstream):
+ self._upstream = upstream
+ self._condition = threading.Condition()
+ self._paused = False
- @contextlib.contextmanager
- def pause(self):
- with self._condition:
- self._paused = True
- yield
- with self._condition:
- self._paused = False
- self._condition.notify_all()
+ @contextlib.contextmanager
+ def pause(self):
+ with self._condition:
+ self._paused = True
+ yield
+ with self._condition:
+ self._paused = False
+ self._condition.notify_all()
- def __iter__(self):
- return self
+ def __iter__(self):
+ return self
- def __next__(self):
- return self.next()
+ def __next__(self):
+ return self.next()
- def next(self):
- with self._condition:
- while self._paused:
- self._condition.wait()
- return next(self._upstream)
+ def next(self):
+ with self._condition:
+ while self._paused:
+ self._condition.wait()
+ return next(self._upstream)
class _Callback(object):
- def __init__(self):
- self._condition = threading.Condition()
- self._called = False
- self._passed_future = None
- self._passed_other_stuff = None
-
- def __call__(self, *args, **kwargs):
- with self._condition:
- self._called = True
- if args:
- self._passed_future = args[0]
- if 1 < len(args) or kwargs:
- self._passed_other_stuff = tuple(args[1:]), dict(kwargs)
- self._condition.notify_all()
-
- def future(self):
- with self._condition:
- while True:
- if self._passed_other_stuff is not None:
- raise ValueError(
- 'Test callback passed unexpected values: %s',
- self._passed_other_stuff)
- elif self._called:
- return self._passed_future
- else:
- self._condition.wait()
-
-
-class TestCase(six.with_metaclass(abc.ABCMeta, test_coverage.Coverage, unittest.TestCase)):
- """A test of the Face layer of RPC Framework.
+ def __init__(self):
+ self._condition = threading.Condition()
+ self._called = False
+ self._passed_future = None
+ self._passed_other_stuff = None
+
+ def __call__(self, *args, **kwargs):
+ with self._condition:
+ self._called = True
+ if args:
+ self._passed_future = args[0]
+ if 1 < len(args) or kwargs:
+ self._passed_other_stuff = tuple(args[1:]), dict(kwargs)
+ self._condition.notify_all()
+
+ def future(self):
+ with self._condition:
+ while True:
+ if self._passed_other_stuff is not None:
+ raise ValueError(
+ 'Test callback passed unexpected values: %s',
+ self._passed_other_stuff)
+ elif self._called:
+ return self._passed_future
+ else:
+ self._condition.wait()
+
+
+class TestCase(
+ six.with_metaclass(abc.ABCMeta, test_coverage.Coverage,
+ unittest.TestCase)):
+ """A test of the Face layer of RPC Framework.
Concrete subclasses must have an "implementation" attribute of type
test_interfaces.Implementation and an "invoker_constructor" attribute of type
_invocation.InvokerConstructor.
"""
- NAME = 'FutureInvocationAsynchronousEventServiceTest'
+ NAME = 'FutureInvocationAsynchronousEventServiceTest'
- def setUp(self):
- """See unittest.TestCase.setUp for full specification.
+ def setUp(self):
+ """See unittest.TestCase.setUp for full specification.
Overriding implementations must call this implementation.
"""
- self._control = test_control.PauseFailControl()
- self._digest_pool = logging_pool.pool(test_constants.POOL_SIZE)
- self._digest = _digest.digest(
- _stock_service.STOCK_TEST_SERVICE, self._control, self._digest_pool)
+ self._control = test_control.PauseFailControl()
+ self._digest_pool = logging_pool.pool(test_constants.POOL_SIZE)
+ self._digest = _digest.digest(_stock_service.STOCK_TEST_SERVICE,
+ self._control, self._digest_pool)
- generic_stub, dynamic_stubs, self._memo = self.implementation.instantiate(
- self._digest.methods, self._digest.event_method_implementations, None)
- self._invoker = self.invoker_constructor.construct_invoker(
- generic_stub, dynamic_stubs, self._digest.methods)
+ generic_stub, dynamic_stubs, self._memo = self.implementation.instantiate(
+ self._digest.methods, self._digest.event_method_implementations,
+ None)
+ self._invoker = self.invoker_constructor.construct_invoker(
+ generic_stub, dynamic_stubs, self._digest.methods)
- def tearDown(self):
- """See unittest.TestCase.tearDown for full specification.
+ def tearDown(self):
+ """See unittest.TestCase.tearDown for full specification.
Overriding implementations must call this implementation.
"""
- self._invoker = None
- self.implementation.destantiate(self._memo)
- self._digest_pool.shutdown(wait=True)
-
- def testSuccessfulUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
- callback = _Callback()
-
- response_future = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- response_future.add_done_callback(callback)
- response = response_future.result()
-
- test_messages.verify(request, response, self)
- self.assertIs(callback.future(), response_future)
- self.assertIsNone(response_future.exception())
- self.assertIsNone(response_future.traceback())
-
- def testSuccessfulUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- response_iterator = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- responses = list(response_iterator)
-
- test_messages.verify(request, responses, self)
-
- def testSuccessfulStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
- request_iterator = _PauseableIterator(iter(requests))
- callback = _Callback()
-
- # Use of a paused iterator of requests allows us to test that control is
- # returned to calling code before the iterator yields any requests.
- with request_iterator.pause():
- response_future = self._invoker.future(group, method)(
- request_iterator, test_constants.LONG_TIMEOUT)
- response_future.add_done_callback(callback)
- future_passed_to_callback = callback.future()
- response = future_passed_to_callback.result()
-
- test_messages.verify(requests, response, self)
- self.assertIs(future_passed_to_callback, response_future)
- self.assertIsNone(response_future.exception())
- self.assertIsNone(response_future.traceback())
-
- def testSuccessfulStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
- request_iterator = _PauseableIterator(iter(requests))
-
- # Use of a paused iterator of requests allows us to test that control is
- # returned to calling code before the iterator yields any requests.
- with request_iterator.pause():
- response_iterator = self._invoker.future(group, method)(
- request_iterator, test_constants.LONG_TIMEOUT)
- responses = list(response_iterator)
-
- test_messages.verify(requests, responses, self)
-
- def testSequentialInvocations(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- first_request = test_messages.request()
- second_request = test_messages.request()
-
- first_response_future = self._invoker.future(group, method)(
- first_request, test_constants.LONG_TIMEOUT)
- first_response = first_response_future.result()
-
- test_messages.verify(first_request, first_response, self)
-
- second_response_future = self._invoker.future(group, method)(
- second_request, test_constants.LONG_TIMEOUT)
- second_response = second_response_future.result()
-
- test_messages.verify(second_request, second_response, self)
-
- def testParallelInvocations(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- first_request = test_messages.request()
- second_request = test_messages.request()
-
- first_response_future = self._invoker.future(group, method)(
- first_request, test_constants.LONG_TIMEOUT)
- second_response_future = self._invoker.future(group, method)(
- second_request, test_constants.LONG_TIMEOUT)
- first_response = first_response_future.result()
- second_response = second_response_future.result()
-
- test_messages.verify(first_request, first_response, self)
- test_messages.verify(second_request, second_response, self)
-
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = []
- response_futures = []
- for _ in range(test_constants.THREAD_CONCURRENCY):
- request = test_messages.request()
- response_future = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- requests.append(request)
- response_futures.append(response_future)
-
- responses = [
- response_future.result() for response_future in response_futures]
-
- for request, response in zip(requests, responses):
- test_messages.verify(request, response, self)
-
- def testWaitingForSomeButNotAllParallelInvocations(self):
- pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = []
- response_futures_to_indices = {}
- for index in range(test_constants.THREAD_CONCURRENCY):
- request = test_messages.request()
- inner_response_future = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- outer_response_future = pool.submit(inner_response_future.result)
- requests.append(request)
- response_futures_to_indices[outer_response_future] = index
-
- some_completed_response_futures_iterator = itertools.islice(
- futures.as_completed(response_futures_to_indices),
- test_constants.THREAD_CONCURRENCY // 2)
- for response_future in some_completed_response_futures_iterator:
- index = response_futures_to_indices[response_future]
- test_messages.verify(requests[index], response_future.result(), self)
- pool.shutdown(wait=True)
-
- def testCancelledUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
- callback = _Callback()
-
- with self._control.pause():
- response_future = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- response_future.add_done_callback(callback)
- cancel_method_return_value = response_future.cancel()
-
- self.assertIs(callback.future(), response_future)
- self.assertFalse(cancel_method_return_value)
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(future.CancelledError):
- response_future.result()
- with self.assertRaises(future.CancelledError):
- response_future.exception()
- with self.assertRaises(future.CancelledError):
- response_future.traceback()
-
- def testCancelledUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.pause():
- response_iterator = self._invoker.future(group, method)(
- request, test_constants.LONG_TIMEOUT)
- response_iterator.cancel()
-
- with self.assertRaises(face.CancellationError):
- next(response_iterator)
-
- def testCancelledStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
- callback = _Callback()
-
- with self._control.pause():
- response_future = self._invoker.future(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT)
- response_future.add_done_callback(callback)
- cancel_method_return_value = response_future.cancel()
-
- self.assertIs(callback.future(), response_future)
- self.assertFalse(cancel_method_return_value)
- self.assertTrue(response_future.cancelled())
- with self.assertRaises(future.CancelledError):
- response_future.result()
- with self.assertRaises(future.CancelledError):
- response_future.exception()
- with self.assertRaises(future.CancelledError):
- response_future.traceback()
-
- def testCancelledStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.pause():
- response_iterator = self._invoker.future(group, method)(
- iter(requests), test_constants.LONG_TIMEOUT)
- response_iterator.cancel()
-
- with self.assertRaises(face.CancellationError):
- next(response_iterator)
-
- def testExpiredUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
- callback = _Callback()
-
- with self._control.pause():
- response_future = self._invoker.future(
- group, method)(request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
- response_future.add_done_callback(callback)
- self.assertIs(callback.future(), response_future)
- self.assertIsInstance(
- response_future.exception(), face.ExpirationError)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
- self.assertIsInstance(
- response_future.exception(), face.AbortionError)
- self.assertIsNotNone(response_future.traceback())
-
- def testExpiredUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- with self._control.pause():
- response_iterator = self._invoker.future(group, method)(
- request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- list(response_iterator)
-
- def testExpiredStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
- callback = _Callback()
-
- with self._control.pause():
- response_future = self._invoker.future(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
- response_future.add_done_callback(callback)
- self.assertIs(callback.future(), response_future)
- self.assertIsInstance(
- response_future.exception(), face.ExpirationError)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
- self.assertIsInstance(
- response_future.exception(), face.AbortionError)
- self.assertIsNotNone(response_future.traceback())
-
- def testExpiredStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- with self._control.pause():
- response_iterator = self._invoker.future(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
- with self.assertRaises(face.ExpirationError):
- list(response_iterator)
-
- def testFailedUnaryRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
- callback = _Callback()
- abortion_callback = _Callback()
-
- with self._control.fail():
- response_future = self._invoker.future(group, method)(
- request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
- response_future.add_done_callback(callback)
- response_future.add_abortion_callback(abortion_callback)
-
- self.assertIs(callback.future(), response_future)
- # Because the servicer fails outside of the thread from which the
-          # servicer-side runtime called into it, its failure is
- # indistinguishable from simply not having called its
- # response_callback before the expiration of the RPC.
- self.assertIsInstance(
- response_future.exception(), face.ExpirationError)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
- self.assertIsNotNone(response_future.traceback())
- self.assertIsNotNone(abortion_callback.future())
-
- def testFailedUnaryRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.unary_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- request = test_messages.request()
-
- # Because the servicer fails outside of the thread from which the
-        # servicer-side runtime called into it, its failure is indistinguishable
- # from simply not having called its response_consumer before the
- # expiration of the RPC.
- with self._control.fail(), self.assertRaises(face.ExpirationError):
- response_iterator = self._invoker.future(group, method)(
- request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
- list(response_iterator)
-
- def testFailedStreamRequestUnaryResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_unary_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
- callback = _Callback()
- abortion_callback = _Callback()
-
- with self._control.fail():
- response_future = self._invoker.future(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
- response_future.add_done_callback(callback)
- response_future.add_abortion_callback(abortion_callback)
-
- self.assertIs(callback.future(), response_future)
- # Because the servicer fails outside of the thread from which the
-          # servicer-side runtime called into it, its failure is
- # indistinguishable from simply not having called its
- # response_callback before the expiration of the RPC.
- self.assertIsInstance(
- response_future.exception(), face.ExpirationError)
- with self.assertRaises(face.ExpirationError):
- response_future.result()
- self.assertIsNotNone(response_future.traceback())
- self.assertIsNotNone(abortion_callback.future())
-
- def testFailedStreamRequestStreamResponse(self):
- for (group, method), test_messages_sequence in (
- six.iteritems(self._digest.stream_stream_messages_sequences)):
- for test_messages in test_messages_sequence:
- requests = test_messages.requests()
-
- # Because the servicer fails outside of the thread from which the
-        # servicer-side runtime called into it, its failure is indistinguishable
- # from simply not having called its response_consumer before the
- # expiration of the RPC.
- with self._control.fail(), self.assertRaises(face.ExpirationError):
- response_iterator = self._invoker.future(group, method)(
- iter(requests), _3069_test_constant.REALLY_SHORT_TIMEOUT)
- list(response_iterator)
+ self._invoker = None
+ self.implementation.destantiate(self._memo)
+ self._digest_pool.shutdown(wait=True)
+
+ def testSuccessfulUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+ callback = _Callback()
+
+ response_future = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ response_future.add_done_callback(callback)
+ response = response_future.result()
+
+ test_messages.verify(request, response, self)
+ self.assertIs(callback.future(), response_future)
+ self.assertIsNone(response_future.exception())
+ self.assertIsNone(response_future.traceback())
+
+ def testSuccessfulUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ response_iterator = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ responses = list(response_iterator)
+
+ test_messages.verify(request, responses, self)
+
+ def testSuccessfulStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+ request_iterator = _PauseableIterator(iter(requests))
+ callback = _Callback()
+
+ # Use of a paused iterator of requests allows us to test that control is
+ # returned to calling code before the iterator yields any requests.
+ with request_iterator.pause():
+ response_future = self._invoker.future(group, method)(
+ request_iterator, test_constants.LONG_TIMEOUT)
+ response_future.add_done_callback(callback)
+ future_passed_to_callback = callback.future()
+ response = future_passed_to_callback.result()
+
+ test_messages.verify(requests, response, self)
+ self.assertIs(future_passed_to_callback, response_future)
+ self.assertIsNone(response_future.exception())
+ self.assertIsNone(response_future.traceback())
+
+ def testSuccessfulStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+ request_iterator = _PauseableIterator(iter(requests))
+
+ # Use of a paused iterator of requests allows us to test that control is
+ # returned to calling code before the iterator yields any requests.
+ with request_iterator.pause():
+ response_iterator = self._invoker.future(group, method)(
+ request_iterator, test_constants.LONG_TIMEOUT)
+ responses = list(response_iterator)
+
+ test_messages.verify(requests, responses, self)
+
+ def testSequentialInvocations(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ first_request = test_messages.request()
+ second_request = test_messages.request()
+
+ first_response_future = self._invoker.future(group, method)(
+ first_request, test_constants.LONG_TIMEOUT)
+ first_response = first_response_future.result()
+
+ test_messages.verify(first_request, first_response, self)
+
+ second_response_future = self._invoker.future(group, method)(
+ second_request, test_constants.LONG_TIMEOUT)
+ second_response = second_response_future.result()
+
+ test_messages.verify(second_request, second_response, self)
+
+ def testParallelInvocations(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ first_request = test_messages.request()
+ second_request = test_messages.request()
+
+ first_response_future = self._invoker.future(group, method)(
+ first_request, test_constants.LONG_TIMEOUT)
+ second_response_future = self._invoker.future(group, method)(
+ second_request, test_constants.LONG_TIMEOUT)
+ first_response = first_response_future.result()
+ second_response = second_response_future.result()
+
+ test_messages.verify(first_request, first_response, self)
+ test_messages.verify(second_request, second_response, self)
+
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = []
+ response_futures = []
+ for _ in range(test_constants.THREAD_CONCURRENCY):
+ request = test_messages.request()
+ response_future = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ requests.append(request)
+ response_futures.append(response_future)
+
+ responses = [
+ response_future.result()
+ for response_future in response_futures
+ ]
+
+ for request, response in zip(requests, responses):
+ test_messages.verify(request, response, self)
+
+ def testWaitingForSomeButNotAllParallelInvocations(self):
+ pool = logging_pool.pool(test_constants.THREAD_CONCURRENCY)
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = []
+ response_futures_to_indices = {}
+ for index in range(test_constants.THREAD_CONCURRENCY):
+ request = test_messages.request()
+ inner_response_future = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ outer_response_future = pool.submit(
+ inner_response_future.result)
+ requests.append(request)
+ response_futures_to_indices[outer_response_future] = index
+
+ some_completed_response_futures_iterator = itertools.islice(
+ futures.as_completed(response_futures_to_indices),
+ test_constants.THREAD_CONCURRENCY // 2)
+ for response_future in some_completed_response_futures_iterator:
+ index = response_futures_to_indices[response_future]
+ test_messages.verify(requests[index],
+ response_future.result(), self)
+ pool.shutdown(wait=True)
+
+ def testCancelledUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+ callback = _Callback()
+
+ with self._control.pause():
+ response_future = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ response_future.add_done_callback(callback)
+ cancel_method_return_value = response_future.cancel()
+
+ self.assertIs(callback.future(), response_future)
+ self.assertFalse(cancel_method_return_value)
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(future.CancelledError):
+ response_future.result()
+ with self.assertRaises(future.CancelledError):
+ response_future.exception()
+ with self.assertRaises(future.CancelledError):
+ response_future.traceback()
+
+ def testCancelledUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.pause():
+ response_iterator = self._invoker.future(group, method)(
+ request, test_constants.LONG_TIMEOUT)
+ response_iterator.cancel()
+
+ with self.assertRaises(face.CancellationError):
+ next(response_iterator)
+
+ def testCancelledStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+ callback = _Callback()
+
+ with self._control.pause():
+ response_future = self._invoker.future(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT)
+ response_future.add_done_callback(callback)
+ cancel_method_return_value = response_future.cancel()
+
+ self.assertIs(callback.future(), response_future)
+ self.assertFalse(cancel_method_return_value)
+ self.assertTrue(response_future.cancelled())
+ with self.assertRaises(future.CancelledError):
+ response_future.result()
+ with self.assertRaises(future.CancelledError):
+ response_future.exception()
+ with self.assertRaises(future.CancelledError):
+ response_future.traceback()
+
+ def testCancelledStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.pause():
+ response_iterator = self._invoker.future(group, method)(
+ iter(requests), test_constants.LONG_TIMEOUT)
+ response_iterator.cancel()
+
+ with self.assertRaises(face.CancellationError):
+ next(response_iterator)
+
+ def testExpiredUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+ callback = _Callback()
+
+ with self._control.pause():
+ response_future = self._invoker.future(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ response_future.add_done_callback(callback)
+ self.assertIs(callback.future(), response_future)
+ self.assertIsInstance(response_future.exception(),
+ face.ExpirationError)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+ self.assertIsInstance(response_future.exception(),
+ face.AbortionError)
+ self.assertIsNotNone(response_future.traceback())
+
+ def testExpiredUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ with self._control.pause():
+ response_iterator = self._invoker.future(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ list(response_iterator)
+
+ def testExpiredStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+ callback = _Callback()
+
+ with self._control.pause():
+ response_future = self._invoker.future(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ response_future.add_done_callback(callback)
+ self.assertIs(callback.future(), response_future)
+ self.assertIsInstance(response_future.exception(),
+ face.ExpirationError)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+ self.assertIsInstance(response_future.exception(),
+ face.AbortionError)
+ self.assertIsNotNone(response_future.traceback())
+
+ def testExpiredStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ with self._control.pause():
+ response_iterator = self._invoker.future(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ with self.assertRaises(face.ExpirationError):
+ list(response_iterator)
+
+ def testFailedUnaryRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+ callback = _Callback()
+ abortion_callback = _Callback()
+
+ with self._control.fail():
+ response_future = self._invoker.future(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ response_future.add_done_callback(callback)
+ response_future.add_abortion_callback(abortion_callback)
+
+ self.assertIs(callback.future(), response_future)
+ # Because the servicer fails outside of the thread from which the
+                # servicer-side runtime called into it, its failure is
+ # indistinguishable from simply not having called its
+ # response_callback before the expiration of the RPC.
+ self.assertIsInstance(response_future.exception(),
+ face.ExpirationError)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIsNotNone(abortion_callback.future())
+
+ def testFailedUnaryRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.unary_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ request = test_messages.request()
+
+ # Because the servicer fails outside of the thread from which the
+                # servicer-side runtime called into it, its failure is indistinguishable
+ # from simply not having called its response_consumer before the
+ # expiration of the RPC.
+ with self._control.fail(), self.assertRaises(
+ face.ExpirationError):
+ response_iterator = self._invoker.future(group, method)(
+ request, _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ list(response_iterator)
+
+ def testFailedStreamRequestUnaryResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_unary_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+ callback = _Callback()
+ abortion_callback = _Callback()
+
+ with self._control.fail():
+ response_future = self._invoker.future(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ response_future.add_done_callback(callback)
+ response_future.add_abortion_callback(abortion_callback)
+
+ self.assertIs(callback.future(), response_future)
+ # Because the servicer fails outside of the thread from which the
+                # servicer-side runtime called into it, its failure is
+ # indistinguishable from simply not having called its
+ # response_callback before the expiration of the RPC.
+ self.assertIsInstance(response_future.exception(),
+ face.ExpirationError)
+ with self.assertRaises(face.ExpirationError):
+ response_future.result()
+ self.assertIsNotNone(response_future.traceback())
+ self.assertIsNotNone(abortion_callback.future())
+
+ def testFailedStreamRequestStreamResponse(self):
+ for (group, method), test_messages_sequence in (
+ six.iteritems(self._digest.stream_stream_messages_sequences)):
+ for test_messages in test_messages_sequence:
+ requests = test_messages.requests()
+
+ # Because the servicer fails outside of the thread from which the
+                # servicer-side runtime called into it, its failure is indistinguishable
+ # from simply not having called its response_consumer before the
+ # expiration of the RPC.
+ with self._control.fail(), self.assertRaises(
+ face.ExpirationError):
+ response_iterator = self._invoker.future(group, method)(
+ iter(requests),
+ _3069_test_constant.REALLY_SHORT_TIMEOUT)
+ list(response_iterator)
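The paused-iterator comments above rely on a `_PauseableIterator` helper whose definition is not shown in this hunk. An illustrative sketch of such a wrapper, assuming a condition-variable design and Python 2/3 compatibility in the style of these six-based tests (names and details here are assumptions, not the module's actual implementation):

    import contextlib
    import threading


    class PauseableIterator(object):
        """Wraps an iterator and blocks iteration while paused (illustrative only)."""

        def __init__(self, upstream):
            self._upstream = upstream
            self._condition = threading.Condition()
            self._paused = False

        @contextlib.contextmanager
        def pause(self):
            # While this context is active, __next__ blocks instead of yielding.
            with self._condition:
                self._paused = True
            try:
                yield
            finally:
                with self._condition:
                    self._paused = False
                    self._condition.notify_all()

        def __iter__(self):
            return self

        def __next__(self):
            with self._condition:
                while self._paused:
                    self._condition.wait()
            return next(self._upstream)

        next = __next__  # Python 2 spelling, matching the six-based tests above.

Entering pause() before handing the iterator to the RPC lets a test assert that invocation returns control to the caller before any request is pulled, which is exactly what the comments above describe.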
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py
index ac487bed4f..4e144a3635 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_invocation.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Coverage across the Face layer's generic-to-dynamic range for invocation."""
import abc
@@ -65,149 +64,149 @@ _CARDINALITY_TO_MULTI_CALLABLE_ATTRIBUTE = {
class Invoker(six.with_metaclass(abc.ABCMeta)):
- """A type used to invoke test RPCs."""
+ """A type used to invoke test RPCs."""
- @abc.abstractmethod
- def blocking(self, group, name):
- """Invokes an RPC with blocking control flow."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def blocking(self, group, name):
+ """Invokes an RPC with blocking control flow."""
+ raise NotImplementedError()
- @abc.abstractmethod
- def future(self, group, name):
- """Invokes an RPC with future control flow."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def future(self, group, name):
+ """Invokes an RPC with future control flow."""
+ raise NotImplementedError()
- @abc.abstractmethod
- def event(self, group, name):
- """Invokes an RPC with event control flow."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def event(self, group, name):
+ """Invokes an RPC with event control flow."""
+ raise NotImplementedError()
class InvokerConstructor(six.with_metaclass(abc.ABCMeta)):
- """A type used to create Invokers."""
+ """A type used to create Invokers."""
- @abc.abstractmethod
- def name(self):
- """Specifies the name of the Invoker constructed by this object."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def name(self):
+ """Specifies the name of the Invoker constructed by this object."""
+ raise NotImplementedError()
- @abc.abstractmethod
- def construct_invoker(self, generic_stub, dynamic_stubs, methods):
- """Constructs an Invoker for the given stubs and methods."""
- raise NotImplementedError()
+ @abc.abstractmethod
+ def construct_invoker(self, generic_stub, dynamic_stubs, methods):
+ """Constructs an Invoker for the given stubs and methods."""
+ raise NotImplementedError()
class _GenericInvoker(Invoker):
- def __init__(self, generic_stub, methods):
- self._stub = generic_stub
- self._methods = methods
+ def __init__(self, generic_stub, methods):
+ self._stub = generic_stub
+ self._methods = methods
- def _behavior(self, group, name, cardinality_to_generic_method):
- method_cardinality = self._methods[group, name].cardinality()
- behavior = getattr(
- self._stub, cardinality_to_generic_method[method_cardinality])
- return lambda *args, **kwargs: behavior(group, name, *args, **kwargs)
+ def _behavior(self, group, name, cardinality_to_generic_method):
+ method_cardinality = self._methods[group, name].cardinality()
+ behavior = getattr(self._stub,
+ cardinality_to_generic_method[method_cardinality])
+ return lambda *args, **kwargs: behavior(group, name, *args, **kwargs)
- def blocking(self, group, name):
- return self._behavior(
- group, name, _CARDINALITY_TO_GENERIC_BLOCKING_BEHAVIOR)
+ def blocking(self, group, name):
+ return self._behavior(group, name,
+ _CARDINALITY_TO_GENERIC_BLOCKING_BEHAVIOR)
- def future(self, group, name):
- return self._behavior(group, name, _CARDINALITY_TO_GENERIC_FUTURE_BEHAVIOR)
+ def future(self, group, name):
+ return self._behavior(group, name,
+ _CARDINALITY_TO_GENERIC_FUTURE_BEHAVIOR)
- def event(self, group, name):
- return self._behavior(group, name, _CARDINALITY_TO_GENERIC_EVENT_BEHAVIOR)
+ def event(self, group, name):
+ return self._behavior(group, name,
+ _CARDINALITY_TO_GENERIC_EVENT_BEHAVIOR)
class _GenericInvokerConstructor(InvokerConstructor):
- def name(self):
- return 'GenericInvoker'
+ def name(self):
+ return 'GenericInvoker'
- def construct_invoker(self, generic_stub, dynamic_stub, methods):
- return _GenericInvoker(generic_stub, methods)
+ def construct_invoker(self, generic_stub, dynamic_stub, methods):
+ return _GenericInvoker(generic_stub, methods)
class _MultiCallableInvoker(Invoker):
- def __init__(self, generic_stub, methods):
- self._stub = generic_stub
- self._methods = methods
+ def __init__(self, generic_stub, methods):
+ self._stub = generic_stub
+ self._methods = methods
- def _multi_callable(self, group, name):
- method_cardinality = self._methods[group, name].cardinality()
- behavior = getattr(
- self._stub,
- _CARDINALITY_TO_MULTI_CALLABLE_ATTRIBUTE[method_cardinality])
- return behavior(group, name)
+ def _multi_callable(self, group, name):
+ method_cardinality = self._methods[group, name].cardinality()
+ behavior = getattr(
+ self._stub,
+ _CARDINALITY_TO_MULTI_CALLABLE_ATTRIBUTE[method_cardinality])
+ return behavior(group, name)
- def blocking(self, group, name):
- return self._multi_callable(group, name)
+ def blocking(self, group, name):
+ return self._multi_callable(group, name)
- def future(self, group, name):
- method_cardinality = self._methods[group, name].cardinality()
- behavior = getattr(
- self._stub,
- _CARDINALITY_TO_MULTI_CALLABLE_ATTRIBUTE[method_cardinality])
- if method_cardinality in (
- cardinality.Cardinality.UNARY_UNARY,
- cardinality.Cardinality.STREAM_UNARY):
- return behavior(group, name).future
- else:
- return behavior(group, name)
+ def future(self, group, name):
+ method_cardinality = self._methods[group, name].cardinality()
+ behavior = getattr(
+ self._stub,
+ _CARDINALITY_TO_MULTI_CALLABLE_ATTRIBUTE[method_cardinality])
+ if method_cardinality in (cardinality.Cardinality.UNARY_UNARY,
+ cardinality.Cardinality.STREAM_UNARY):
+ return behavior(group, name).future
+ else:
+ return behavior(group, name)
- def event(self, group, name):
- return self._multi_callable(group, name).event
+ def event(self, group, name):
+ return self._multi_callable(group, name).event
class _MultiCallableInvokerConstructor(InvokerConstructor):
- def name(self):
- return 'MultiCallableInvoker'
+ def name(self):
+ return 'MultiCallableInvoker'
- def construct_invoker(self, generic_stub, dynamic_stub, methods):
- return _MultiCallableInvoker(generic_stub, methods)
+ def construct_invoker(self, generic_stub, dynamic_stub, methods):
+ return _MultiCallableInvoker(generic_stub, methods)
class _DynamicInvoker(Invoker):
- def __init__(self, dynamic_stubs, methods):
- self._stubs = dynamic_stubs
- self._methods = methods
+ def __init__(self, dynamic_stubs, methods):
+ self._stubs = dynamic_stubs
+ self._methods = methods
- def blocking(self, group, name):
- return getattr(self._stubs[group], name)
+ def blocking(self, group, name):
+ return getattr(self._stubs[group], name)
- def future(self, group, name):
- if self._methods[group, name].cardinality() in (
- cardinality.Cardinality.UNARY_UNARY,
- cardinality.Cardinality.STREAM_UNARY):
- return getattr(self._stubs[group], name).future
- else:
- return getattr(self._stubs[group], name)
+ def future(self, group, name):
+ if self._methods[group, name].cardinality() in (
+ cardinality.Cardinality.UNARY_UNARY,
+ cardinality.Cardinality.STREAM_UNARY):
+ return getattr(self._stubs[group], name).future
+ else:
+ return getattr(self._stubs[group], name)
- def event(self, group, name):
- return getattr(self._stubs[group], name).event
+ def event(self, group, name):
+ return getattr(self._stubs[group], name).event
class _DynamicInvokerConstructor(InvokerConstructor):
- def name(self):
- return 'DynamicInvoker'
+ def name(self):
+ return 'DynamicInvoker'
- def construct_invoker(self, generic_stub, dynamic_stubs, methods):
- return _DynamicInvoker(dynamic_stubs, methods)
+ def construct_invoker(self, generic_stub, dynamic_stubs, methods):
+ return _DynamicInvoker(dynamic_stubs, methods)
def invoker_constructors():
- """Creates a sequence of InvokerConstructors to use in tests of RPCs.
+ """Creates a sequence of InvokerConstructors to use in tests of RPCs.
Returns:
A sequence of InvokerConstructors.
"""
- return (
- _GenericInvokerConstructor(),
- _MultiCallableInvokerConstructor(),
- _DynamicInvokerConstructor(),
- )
+ return (
+ _GenericInvokerConstructor(),
+ _MultiCallableInvokerConstructor(),
+ _DynamicInvokerConstructor(),)
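An illustrative usage sketch of how a test drives an Invoker produced by one of these constructors; the parameter names are placeholders, and the blocking callable is assumed to take (request, timeout) just as the future callable does in the tests shown earlier:

    def drive_unary_unary(invoker, group, name, request, timeout):
        """Exercises one unary-unary method through both control flows."""
        # Blocking control flow: the returned callable waits for the response.
        blocking_response = invoker.blocking(group, name)(request, timeout)
        # Future control flow: the returned callable yields a future whose
        # result() is the response message.
        response_future = invoker.future(group, name)(request, timeout)
        return blocking_response, response_future.result()

For stream-response cardinalities the future-flow callable returns a response iterator rather than a future, as _MultiCallableInvoker.future and _DynamicInvoker.future above make explicit.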
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_service.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_service.py
index f13dff0558..f14ac6a987 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_service.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_service.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Private interfaces implemented by data sets used in Face-layer tests."""
import abc
@@ -38,12 +37,13 @@ from grpc.framework.interfaces.face import face # pylint: disable=unused-import
from tests.unit.framework.interfaces.face import test_interfaces
-class UnaryUnaryTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
- """A controllable implementation of a unary-unary method."""
+class UnaryUnaryTestMethodImplementation(
+ six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
+ """A controllable implementation of a unary-unary method."""
- @abc.abstractmethod
- def service(self, request, response_callback, context, control):
- """Services an RPC that accepts one message and produces one message.
+ @abc.abstractmethod
+ def service(self, request, response_callback, context, control):
+ """Services an RPC that accepts one message and produces one message.
Args:
request: The single request message for the RPC.
@@ -56,15 +56,15 @@ class UnaryUnaryTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_in
abandonment.Abandoned: May or may not be raised when the RPC has been
aborted.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class UnaryUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
- """A type for unary-request-unary-response message pairings."""
+ """A type for unary-request-unary-response message pairings."""
- @abc.abstractmethod
- def request(self):
- """Affords a request message.
+ @abc.abstractmethod
+ def request(self):
+ """Affords a request message.
Implementations of this method should return a different message with each
call so that multiple test executions of the test method may be made with
@@ -73,11 +73,11 @@ class UnaryUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
Returns:
A request message.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def verify(self, request, response, test_case):
- """Verifies that the computed response matches the given request.
+ @abc.abstractmethod
+ def verify(self, request, response, test_case):
+ """Verifies that the computed response matches the given request.
Args:
request: A request message.
@@ -88,15 +88,16 @@ class UnaryUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
AssertionError: If the request and response do not match, indicating that
there was some problem executing the RPC under test.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
-class UnaryStreamTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
- """A controllable implementation of a unary-stream method."""
+class UnaryStreamTestMethodImplementation(
+ six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
+ """A controllable implementation of a unary-stream method."""
- @abc.abstractmethod
- def service(self, request, response_consumer, context, control):
- """Services an RPC that takes one message and produces a stream of messages.
+ @abc.abstractmethod
+ def service(self, request, response_consumer, context, control):
+ """Services an RPC that takes one message and produces a stream of messages.
Args:
request: The single request message for the RPC.
@@ -109,15 +110,15 @@ class UnaryStreamTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_i
abandonment.Abandoned: May or may not be raised when the RPC has been
aborted.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class UnaryStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
- """A type for unary-request-stream-response message pairings."""
+ """A type for unary-request-stream-response message pairings."""
- @abc.abstractmethod
- def request(self):
- """Affords a request message.
+ @abc.abstractmethod
+ def request(self):
+ """Affords a request message.
Implementations of this method should return a different message with each
call so that multiple test executions of the test method may be made with
@@ -126,11 +127,11 @@ class UnaryStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
Returns:
A request message.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def verify(self, request, responses, test_case):
- """Verifies that the computed responses match the given request.
+ @abc.abstractmethod
+ def verify(self, request, responses, test_case):
+ """Verifies that the computed responses match the given request.
Args:
request: A request message.
@@ -141,15 +142,16 @@ class UnaryStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
AssertionError: If the request and responses do not match, indicating that
there was some problem executing the RPC under test.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
-class StreamUnaryTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
- """A controllable implementation of a stream-unary method."""
+class StreamUnaryTestMethodImplementation(
+ six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
+ """A controllable implementation of a stream-unary method."""
- @abc.abstractmethod
- def service(self, response_callback, context, control):
- """Services an RPC that takes a stream of messages and produces one message.
+ @abc.abstractmethod
+ def service(self, response_callback, context, control):
+ """Services an RPC that takes a stream of messages and produces one message.
Args:
response_callback: A callback to be called to accept the response message
@@ -169,15 +171,15 @@ class StreamUnaryTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_i
abandonment.Abandoned: May or may not be raised when the RPC has been
aborted.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class StreamUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
- """A type for stream-request-unary-response message pairings."""
+ """A type for stream-request-unary-response message pairings."""
- @abc.abstractmethod
- def requests(self):
- """Affords a sequence of request messages.
+ @abc.abstractmethod
+ def requests(self):
+ """Affords a sequence of request messages.
Implementations of this method should return a different sequence with each

call so that multiple test executions of the test method may be made with
@@ -186,11 +188,11 @@ class StreamUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
Returns:
A sequence of request messages.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def verify(self, requests, response, test_case):
- """Verifies that the computed response matches the given requests.
+ @abc.abstractmethod
+ def verify(self, requests, response, test_case):
+ """Verifies that the computed response matches the given requests.
Args:
requests: A sequence of request messages.
@@ -201,15 +203,16 @@ class StreamUnaryTestMessages(six.with_metaclass(abc.ABCMeta)):
AssertionError: If the requests and response do not match, indicating that
there was some problem executing the RPC under test.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
-class StreamStreamTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
- """A controllable implementation of a stream-stream method."""
+class StreamStreamTestMethodImplementation(
+ six.with_metaclass(abc.ABCMeta, test_interfaces.Method)):
+ """A controllable implementation of a stream-stream method."""
- @abc.abstractmethod
- def service(self, response_consumer, context, control):
- """Services an RPC that accepts and produces streams of messages.
+ @abc.abstractmethod
+ def service(self, response_consumer, context, control):
+ """Services an RPC that accepts and produces streams of messages.
Args:
response_consumer: A stream.Consumer to be called to accept the response
@@ -229,15 +232,15 @@ class StreamStreamTestMethodImplementation(six.with_metaclass(abc.ABCMeta, test_
abandonment.Abandoned: May or may not be raised when the RPC has been
aborted.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class StreamStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
- """A type for stream-request-stream-response message pairings."""
+ """A type for stream-request-stream-response message pairings."""
- @abc.abstractmethod
- def requests(self):
- """Affords a sequence of request messages.
+ @abc.abstractmethod
+ def requests(self):
+ """Affords a sequence of request messages.
Implementations of this method should return a different sequence with each
call so that multiple test executions of the test method may be made with
@@ -246,11 +249,11 @@ class StreamStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
Returns:
A sequence of request messages.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def verify(self, requests, responses, test_case):
- """Verifies that the computed response matches the given requests.
+ @abc.abstractmethod
+ def verify(self, requests, responses, test_case):
+ """Verifies that the computed response matches the given requests.
Args:
requests: A sequence of request messages.
@@ -261,15 +264,15 @@ class StreamStreamTestMessages(six.with_metaclass(abc.ABCMeta)):
AssertionError: If the requests and responses do not match, indicating
that there was some problem executing the RPC under test.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class TestService(six.with_metaclass(abc.ABCMeta)):
- """A specification of implemented methods to use in tests."""
+ """A specification of implemented methods to use in tests."""
- @abc.abstractmethod
- def unary_unary_scenarios(self):
- """Affords unary-request-unary-response test methods and their messages.
+ @abc.abstractmethod
+ def unary_unary_scenarios(self):
+ """Affords unary-request-unary-response test methods and their messages.
Returns:
A dict from method group-name pair to implementation/messages pair. The
@@ -277,11 +280,11 @@ class TestService(six.with_metaclass(abc.ABCMeta)):
and the second element is a sequence of UnaryUnaryTestMethodMessages
objects.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def unary_stream_scenarios(self):
- """Affords unary-request-stream-response test methods and their messages.
+ @abc.abstractmethod
+ def unary_stream_scenarios(self):
+ """Affords unary-request-stream-response test methods and their messages.
Returns:
A dict from method group-name pair to implementation/messages pair. The
@@ -289,11 +292,11 @@ class TestService(six.with_metaclass(abc.ABCMeta)):
object and the second element is a sequence of
UnaryStreamTestMethodMessages objects.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def stream_unary_scenarios(self):
- """Affords stream-request-unary-response test methods and their messages.
+ @abc.abstractmethod
+ def stream_unary_scenarios(self):
+ """Affords stream-request-unary-response test methods and their messages.
Returns:
A dict from method group-name pair to implementation/messages pair. The
@@ -301,11 +304,11 @@ class TestService(six.with_metaclass(abc.ABCMeta)):
object and the second element is a sequence of
StreamUnaryTestMethodMessages objects.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def stream_stream_scenarios(self):
- """Affords stream-request-stream-response test methods and their messages.
+ @abc.abstractmethod
+ def stream_stream_scenarios(self):
+ """Affords stream-request-stream-response test methods and their messages.
Returns:
A dict from method group-name pair to implementation/messages pair. The
@@ -313,4 +316,4 @@ class TestService(six.with_metaclass(abc.ABCMeta)):
object and the second element is a sequence of
StreamStreamTestMethodMessages objects.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
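An illustrative sketch of the dict shape these *_scenarios docstrings describe, wired up with the stock-service classes defined in the next file; the import path is an assumption taken from the file layout in the diff headers:

    from tests.unit.framework.interfaces.face import _stock_service


    def unary_unary_scenarios_sketch():
        """Shows the shape described by TestService.unary_unary_scenarios."""
        implementation = _stock_service.GetLastTradePrice()
        messages = [_stock_service.GetLastTradePriceMessages()]
        # Keys are (group, name) pairs; values pair an implementation with a
        # sequence of messages objects, as the docstrings above describe.
        return {
            (implementation.group(), implementation.name()):
                (implementation, messages),
        }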
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py
index 5299655bb3..41a55c13f4 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/_stock_service.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Examples of Python implementations of the stock.proto Stock service."""
from grpc.framework.common import cardinality
@@ -44,353 +43,363 @@ _price = lambda symbol_name: float(hash(symbol_name) % 4096)
def _get_last_trade_price(stock_request, stock_reply_callback, control, active):
- """A unary-request, unary-response test method."""
- control.control()
- if active():
- stock_reply_callback(
- stock_pb2.StockReply(
- symbol=stock_request.symbol, price=_price(stock_request.symbol)))
- else:
- raise abandonment.Abandoned()
-
-
-def _get_last_trade_price_multiple(stock_reply_consumer, control, active):
- """A stream-request, stream-response test method."""
- def stock_reply_for_stock_request(stock_request):
+ """A unary-request, unary-response test method."""
control.control()
if active():
- return stock_pb2.StockReply(
- symbol=stock_request.symbol, price=_price(stock_request.symbol))
+ stock_reply_callback(
+ stock_pb2.StockReply(
+ symbol=stock_request.symbol, price=_price(
+ stock_request.symbol)))
else:
- raise abandonment.Abandoned()
-
- class StockRequestConsumer(stream.Consumer):
+ raise abandonment.Abandoned()
- def consume(self, stock_request):
- stock_reply_consumer.consume(stock_reply_for_stock_request(stock_request))
- def terminate(self):
- control.control()
- stock_reply_consumer.terminate()
+def _get_last_trade_price_multiple(stock_reply_consumer, control, active):
+ """A stream-request, stream-response test method."""
- def consume_and_terminate(self, stock_request):
- stock_reply_consumer.consume_and_terminate(
- stock_reply_for_stock_request(stock_request))
+ def stock_reply_for_stock_request(stock_request):
+ control.control()
+ if active():
+ return stock_pb2.StockReply(
+ symbol=stock_request.symbol, price=_price(stock_request.symbol))
+ else:
+ raise abandonment.Abandoned()
- return StockRequestConsumer()
+ class StockRequestConsumer(stream.Consumer):
+ def consume(self, stock_request):
+ stock_reply_consumer.consume(
+ stock_reply_for_stock_request(stock_request))
-def _watch_future_trades(stock_request, stock_reply_consumer, control, active):
- """A unary-request, stream-response test method."""
- base_price = _price(stock_request.symbol)
- for index in range(stock_request.num_trades_to_watch):
- control.control()
- if active():
- stock_reply_consumer.consume(
- stock_pb2.StockReply(
- symbol=stock_request.symbol, price=base_price + index))
- else:
- raise abandonment.Abandoned()
- stock_reply_consumer.terminate()
+ def terminate(self):
+ control.control()
+ stock_reply_consumer.terminate()
+ def consume_and_terminate(self, stock_request):
+ stock_reply_consumer.consume_and_terminate(
+ stock_reply_for_stock_request(stock_request))
-def _get_highest_trade_price(stock_reply_callback, control, active):
- """A stream-request, unary-response test method."""
+ return StockRequestConsumer()
- class StockRequestConsumer(stream.Consumer):
- """Keeps an ongoing record of the most valuable symbol yet consumed."""
- def __init__(self):
- self._symbol = None
- self._price = None
-
- def consume(self, stock_request):
- control.control()
- if active():
- if self._price is None:
- self._symbol = stock_request.symbol
- self._price = _price(stock_request.symbol)
- else:
- candidate_price = _price(stock_request.symbol)
- if self._price < candidate_price:
- self._symbol = stock_request.symbol
- self._price = candidate_price
-
- def terminate(self):
- control.control()
- if active():
- if self._symbol is None:
- raise ValueError()
- else:
- stock_reply_callback(
- stock_pb2.StockReply(symbol=self._symbol, price=self._price))
- self._symbol = None
- self._price = None
-
- def consume_and_terminate(self, stock_request):
- control.control()
- if active():
- if self._price is None:
- stock_reply_callback(
- stock_pb2.StockReply(
- symbol=stock_request.symbol,
- price=_price(stock_request.symbol)))
- else:
- candidate_price = _price(stock_request.symbol)
- if self._price < candidate_price:
- stock_reply_callback(
- stock_pb2.StockReply(
- symbol=stock_request.symbol, price=candidate_price))
- else:
- stock_reply_callback(
+def _watch_future_trades(stock_request, stock_reply_consumer, control, active):
+ """A unary-request, stream-response test method."""
+ base_price = _price(stock_request.symbol)
+ for index in range(stock_request.num_trades_to_watch):
+ control.control()
+ if active():
+ stock_reply_consumer.consume(
stock_pb2.StockReply(
- symbol=self._symbol, price=self._price))
+ symbol=stock_request.symbol, price=base_price + index))
+ else:
+ raise abandonment.Abandoned()
+ stock_reply_consumer.terminate()
- self._symbol = None
- self._price = None
- return StockRequestConsumer()
+def _get_highest_trade_price(stock_reply_callback, control, active):
+ """A stream-request, unary-response test method."""
+
+ class StockRequestConsumer(stream.Consumer):
+ """Keeps an ongoing record of the most valuable symbol yet consumed."""
+
+ def __init__(self):
+ self._symbol = None
+ self._price = None
+
+ def consume(self, stock_request):
+ control.control()
+ if active():
+ if self._price is None:
+ self._symbol = stock_request.symbol
+ self._price = _price(stock_request.symbol)
+ else:
+ candidate_price = _price(stock_request.symbol)
+ if self._price < candidate_price:
+ self._symbol = stock_request.symbol
+ self._price = candidate_price
+
+ def terminate(self):
+ control.control()
+ if active():
+ if self._symbol is None:
+ raise ValueError()
+ else:
+ stock_reply_callback(
+ stock_pb2.StockReply(
+ symbol=self._symbol, price=self._price))
+ self._symbol = None
+ self._price = None
+
+ def consume_and_terminate(self, stock_request):
+ control.control()
+ if active():
+ if self._price is None:
+ stock_reply_callback(
+ stock_pb2.StockReply(
+ symbol=stock_request.symbol,
+ price=_price(stock_request.symbol)))
+ else:
+ candidate_price = _price(stock_request.symbol)
+ if self._price < candidate_price:
+ stock_reply_callback(
+ stock_pb2.StockReply(
+ symbol=stock_request.symbol,
+ price=candidate_price))
+ else:
+ stock_reply_callback(
+ stock_pb2.StockReply(
+ symbol=self._symbol, price=self._price))
+
+ self._symbol = None
+ self._price = None
+
+ return StockRequestConsumer()
class GetLastTradePrice(_service.UnaryUnaryTestMethodImplementation):
- """GetLastTradePrice for use in tests."""
+ """GetLastTradePrice for use in tests."""
- def group(self):
- return _STOCK_GROUP_NAME
+ def group(self):
+ return _STOCK_GROUP_NAME
- def name(self):
- return 'GetLastTradePrice'
+ def name(self):
+ return 'GetLastTradePrice'
- def cardinality(self):
- return cardinality.Cardinality.UNARY_UNARY
+ def cardinality(self):
+ return cardinality.Cardinality.UNARY_UNARY
- def request_class(self):
- return stock_pb2.StockRequest
+ def request_class(self):
+ return stock_pb2.StockRequest
- def response_class(self):
- return stock_pb2.StockReply
+ def response_class(self):
+ return stock_pb2.StockReply
- def serialize_request(self, request):
- return request.SerializeToString()
+ def serialize_request(self, request):
+ return request.SerializeToString()
- def deserialize_request(self, serialized_request):
- return stock_pb2.StockRequest.FromString(serialized_request)
+ def deserialize_request(self, serialized_request):
+ return stock_pb2.StockRequest.FromString(serialized_request)
- def serialize_response(self, response):
- return response.SerializeToString()
+ def serialize_response(self, response):
+ return response.SerializeToString()
- def deserialize_response(self, serialized_response):
- return stock_pb2.StockReply.FromString(serialized_response)
+ def deserialize_response(self, serialized_response):
+ return stock_pb2.StockReply.FromString(serialized_response)
- def service(self, request, response_callback, context, control):
- _get_last_trade_price(
- request, response_callback, control, context.is_active)
+ def service(self, request, response_callback, context, control):
+ _get_last_trade_price(request, response_callback, control,
+ context.is_active)
class GetLastTradePriceMessages(_service.UnaryUnaryTestMessages):
- def __init__(self):
- self._index = 0
+ def __init__(self):
+ self._index = 0
- def request(self):
- symbol = _SYMBOL_FORMAT % self._index
- self._index += 1
- return stock_pb2.StockRequest(symbol=symbol)
+ def request(self):
+ symbol = _SYMBOL_FORMAT % self._index
+ self._index += 1
+ return stock_pb2.StockRequest(symbol=symbol)
- def verify(self, request, response, test_case):
- test_case.assertEqual(request.symbol, response.symbol)
- test_case.assertEqual(_price(request.symbol), response.price)
+ def verify(self, request, response, test_case):
+ test_case.assertEqual(request.symbol, response.symbol)
+ test_case.assertEqual(_price(request.symbol), response.price)
class GetLastTradePriceMultiple(_service.StreamStreamTestMethodImplementation):
- """GetLastTradePriceMultiple for use in tests."""
+ """GetLastTradePriceMultiple for use in tests."""
- def group(self):
- return _STOCK_GROUP_NAME
+ def group(self):
+ return _STOCK_GROUP_NAME
- def name(self):
- return 'GetLastTradePriceMultiple'
+ def name(self):
+ return 'GetLastTradePriceMultiple'
- def cardinality(self):
- return cardinality.Cardinality.STREAM_STREAM
+ def cardinality(self):
+ return cardinality.Cardinality.STREAM_STREAM
- def request_class(self):
- return stock_pb2.StockRequest
+ def request_class(self):
+ return stock_pb2.StockRequest
- def response_class(self):
- return stock_pb2.StockReply
+ def response_class(self):
+ return stock_pb2.StockReply
- def serialize_request(self, request):
- return request.SerializeToString()
+ def serialize_request(self, request):
+ return request.SerializeToString()
- def deserialize_request(self, serialized_request):
- return stock_pb2.StockRequest.FromString(serialized_request)
+ def deserialize_request(self, serialized_request):
+ return stock_pb2.StockRequest.FromString(serialized_request)
- def serialize_response(self, response):
- return response.SerializeToString()
+ def serialize_response(self, response):
+ return response.SerializeToString()
- def deserialize_response(self, serialized_response):
- return stock_pb2.StockReply.FromString(serialized_response)
+ def deserialize_response(self, serialized_response):
+ return stock_pb2.StockReply.FromString(serialized_response)
- def service(self, response_consumer, context, control):
- return _get_last_trade_price_multiple(
- response_consumer, control, context.is_active)
+ def service(self, response_consumer, context, control):
+ return _get_last_trade_price_multiple(response_consumer, control,
+ context.is_active)
class GetLastTradePriceMultipleMessages(_service.StreamStreamTestMessages):
- """Pairs of message streams for use with GetLastTradePriceMultiple."""
+ """Pairs of message streams for use with GetLastTradePriceMultiple."""
- def __init__(self):
- self._index = 0
+ def __init__(self):
+ self._index = 0
- def requests(self):
- base_index = self._index
- self._index += 1
- return [
- stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % (base_index + index))
- for index in range(test_constants.STREAM_LENGTH)]
+ def requests(self):
+ base_index = self._index
+ self._index += 1
+ return [
+ stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % (base_index + index))
+ for index in range(test_constants.STREAM_LENGTH)
+ ]
- def verify(self, requests, responses, test_case):
- test_case.assertEqual(len(requests), len(responses))
- for stock_request, stock_reply in zip(requests, responses):
- test_case.assertEqual(stock_request.symbol, stock_reply.symbol)
- test_case.assertEqual(_price(stock_request.symbol), stock_reply.price)
+ def verify(self, requests, responses, test_case):
+ test_case.assertEqual(len(requests), len(responses))
+ for stock_request, stock_reply in zip(requests, responses):
+ test_case.assertEqual(stock_request.symbol, stock_reply.symbol)
+ test_case.assertEqual(
+ _price(stock_request.symbol), stock_reply.price)
class WatchFutureTrades(_service.UnaryStreamTestMethodImplementation):
- """WatchFutureTrades for use in tests."""
+ """WatchFutureTrades for use in tests."""
- def group(self):
- return _STOCK_GROUP_NAME
+ def group(self):
+ return _STOCK_GROUP_NAME
- def name(self):
- return 'WatchFutureTrades'
+ def name(self):
+ return 'WatchFutureTrades'
- def cardinality(self):
- return cardinality.Cardinality.UNARY_STREAM
+ def cardinality(self):
+ return cardinality.Cardinality.UNARY_STREAM
- def request_class(self):
- return stock_pb2.StockRequest
+ def request_class(self):
+ return stock_pb2.StockRequest
- def response_class(self):
- return stock_pb2.StockReply
+ def response_class(self):
+ return stock_pb2.StockReply
- def serialize_request(self, request):
- return request.SerializeToString()
+ def serialize_request(self, request):
+ return request.SerializeToString()
- def deserialize_request(self, serialized_request):
- return stock_pb2.StockRequest.FromString(serialized_request)
+ def deserialize_request(self, serialized_request):
+ return stock_pb2.StockRequest.FromString(serialized_request)
- def serialize_response(self, response):
- return response.SerializeToString()
+ def serialize_response(self, response):
+ return response.SerializeToString()
- def deserialize_response(self, serialized_response):
- return stock_pb2.StockReply.FromString(serialized_response)
+ def deserialize_response(self, serialized_response):
+ return stock_pb2.StockReply.FromString(serialized_response)
- def service(self, request, response_consumer, context, control):
- _watch_future_trades(request, response_consumer, control, context.is_active)
+ def service(self, request, response_consumer, context, control):
+ _watch_future_trades(request, response_consumer, control,
+ context.is_active)
class WatchFutureTradesMessages(_service.UnaryStreamTestMessages):
- """Pairs of a single request message and a sequence of response messages."""
+ """Pairs of a single request message and a sequence of response messages."""
- def __init__(self):
- self._index = 0
+ def __init__(self):
+ self._index = 0
- def request(self):
- symbol = _SYMBOL_FORMAT % self._index
- self._index += 1
- return stock_pb2.StockRequest(
- symbol=symbol, num_trades_to_watch=test_constants.STREAM_LENGTH)
+ def request(self):
+ symbol = _SYMBOL_FORMAT % self._index
+ self._index += 1
+ return stock_pb2.StockRequest(
+ symbol=symbol, num_trades_to_watch=test_constants.STREAM_LENGTH)
- def verify(self, request, responses, test_case):
- test_case.assertEqual(test_constants.STREAM_LENGTH, len(responses))
- base_price = _price(request.symbol)
- for index, response in enumerate(responses):
- test_case.assertEqual(base_price + index, response.price)
+ def verify(self, request, responses, test_case):
+ test_case.assertEqual(test_constants.STREAM_LENGTH, len(responses))
+ base_price = _price(request.symbol)
+ for index, response in enumerate(responses):
+ test_case.assertEqual(base_price + index, response.price)
class GetHighestTradePrice(_service.StreamUnaryTestMethodImplementation):
- """GetHighestTradePrice for use in tests."""
+ """GetHighestTradePrice for use in tests."""
- def group(self):
- return _STOCK_GROUP_NAME
+ def group(self):
+ return _STOCK_GROUP_NAME
- def name(self):
- return 'GetHighestTradePrice'
+ def name(self):
+ return 'GetHighestTradePrice'
- def cardinality(self):
- return cardinality.Cardinality.STREAM_UNARY
+ def cardinality(self):
+ return cardinality.Cardinality.STREAM_UNARY
- def request_class(self):
- return stock_pb2.StockRequest
+ def request_class(self):
+ return stock_pb2.StockRequest
- def response_class(self):
- return stock_pb2.StockReply
+ def response_class(self):
+ return stock_pb2.StockReply
- def serialize_request(self, request):
- return request.SerializeToString()
+ def serialize_request(self, request):
+ return request.SerializeToString()
- def deserialize_request(self, serialized_request):
- return stock_pb2.StockRequest.FromString(serialized_request)
+ def deserialize_request(self, serialized_request):
+ return stock_pb2.StockRequest.FromString(serialized_request)
- def serialize_response(self, response):
- return response.SerializeToString()
+ def serialize_response(self, response):
+ return response.SerializeToString()
- def deserialize_response(self, serialized_response):
- return stock_pb2.StockReply.FromString(serialized_response)
+ def deserialize_response(self, serialized_response):
+ return stock_pb2.StockReply.FromString(serialized_response)
- def service(self, response_callback, context, control):
- return _get_highest_trade_price(
- response_callback, control, context.is_active)
+ def service(self, response_callback, context, control):
+ return _get_highest_trade_price(response_callback, control,
+ context.is_active)
class GetHighestTradePriceMessages(_service.StreamUnaryTestMessages):
- def requests(self):
- return [
- stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % index)
- for index in range(test_constants.STREAM_LENGTH)]
-
- def verify(self, requests, response, test_case):
- price = None
- symbol = None
- for stock_request in requests:
- current_symbol = stock_request.symbol
- current_price = _price(current_symbol)
- if price is None or price < current_price:
- price = current_price
- symbol = current_symbol
- test_case.assertEqual(price, response.price)
- test_case.assertEqual(symbol, response.symbol)
+ def requests(self):
+ return [
+ stock_pb2.StockRequest(symbol=_SYMBOL_FORMAT % index)
+ for index in range(test_constants.STREAM_LENGTH)
+ ]
+
+ def verify(self, requests, response, test_case):
+ price = None
+ symbol = None
+ for stock_request in requests:
+ current_symbol = stock_request.symbol
+ current_price = _price(current_symbol)
+ if price is None or price < current_price:
+ price = current_price
+ symbol = current_symbol
+ test_case.assertEqual(price, response.price)
+ test_case.assertEqual(symbol, response.symbol)
class StockTestService(_service.TestService):
- """A corpus of test data with one method of each RPC cardinality."""
-
- def unary_unary_scenarios(self):
- return {
- (_STOCK_GROUP_NAME, 'GetLastTradePrice'): (
- GetLastTradePrice(), [GetLastTradePriceMessages()]),
- }
-
- def unary_stream_scenarios(self):
- return {
- (_STOCK_GROUP_NAME, 'WatchFutureTrades'): (
- WatchFutureTrades(), [WatchFutureTradesMessages()]),
- }
-
- def stream_unary_scenarios(self):
- return {
- (_STOCK_GROUP_NAME, 'GetHighestTradePrice'): (
- GetHighestTradePrice(), [GetHighestTradePriceMessages()])
- }
-
- def stream_stream_scenarios(self):
- return {
- (_STOCK_GROUP_NAME, 'GetLastTradePriceMultiple'): (
- GetLastTradePriceMultiple(), [GetLastTradePriceMultipleMessages()]),
- }
+ """A corpus of test data with one method of each RPC cardinality."""
+
+ def unary_unary_scenarios(self):
+ return {
+ (_STOCK_GROUP_NAME, 'GetLastTradePrice'):
+ (GetLastTradePrice(), [GetLastTradePriceMessages()]),
+ }
+
+ def unary_stream_scenarios(self):
+ return {
+ (_STOCK_GROUP_NAME, 'WatchFutureTrades'):
+ (WatchFutureTrades(), [WatchFutureTradesMessages()]),
+ }
+
+ def stream_unary_scenarios(self):
+ return {
+ (_STOCK_GROUP_NAME, 'GetHighestTradePrice'):
+ (GetHighestTradePrice(), [GetHighestTradePriceMessages()])
+ }
+
+ def stream_stream_scenarios(self):
+ return {
+ (_STOCK_GROUP_NAME, 'GetLastTradePriceMultiple'):
+ (GetLastTradePriceMultiple(),
+ [GetLastTradePriceMultipleMessages()]),
+ }
STOCK_TEST_SERVICE = StockTestService()
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py
index 71de9d835e..d84e1fc136 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_cases.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Tools for creating tests of implementations of the Face layer."""
# unittest is referenced from specification in this module.
@@ -40,12 +39,11 @@ from tests.unit.framework.interfaces.face import test_interfaces # pylint: disa
_TEST_CASE_SUPERCLASSES = (
_blocking_invocation_inline_service.TestCase,
- _future_invocation_asynchronous_event_service.TestCase,
-)
+ _future_invocation_asynchronous_event_service.TestCase,)
def test_cases(implementation):
- """Creates unittest.TestCase classes for a given Face layer implementation.
+ """Creates unittest.TestCase classes for a given Face layer implementation.
Args:
implementation: A test_interfaces.Implementation specifying creation and
@@ -55,13 +53,14 @@ def test_cases(implementation):
A sequence of subclasses of unittest.TestCase defining tests of the
specified Face layer implementation.
"""
- test_case_classes = []
- for invoker_constructor in _invocation.invoker_constructors():
- for super_class in _TEST_CASE_SUPERCLASSES:
- test_case_classes.append(
- type(invoker_constructor.name() + super_class.NAME, (super_class,),
- {'implementation': implementation,
- 'invoker_constructor': invoker_constructor,
- '__module__': implementation.__module__,
- }))
- return test_case_classes
+ test_case_classes = []
+ for invoker_constructor in _invocation.invoker_constructors():
+ for super_class in _TEST_CASE_SUPERCLASSES:
+ test_case_classes.append(
+ type(invoker_constructor.name() + super_class.NAME, (
+ super_class,), {
+ 'implementation': implementation,
+ 'invoker_constructor': invoker_constructor,
+ '__module__': implementation.__module__,
+ }))
+ return test_case_classes
diff --git a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_interfaces.py b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_interfaces.py
index 40f38e68ba..a789d435b4 100644
--- a/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_interfaces.py
+++ b/src/python/grpcio_tests/tests/unit/framework/interfaces/face/test_interfaces.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Interfaces used in tests of implementations of the Face layer."""
import abc
@@ -38,103 +37,102 @@ from grpc.framework.interfaces.face import face # pylint: disable=unused-import
class Method(six.with_metaclass(abc.ABCMeta)):
- """Specifies a method to be used in tests."""
+ """Specifies a method to be used in tests."""
- @abc.abstractmethod
- def group(self):
- """Identify the group of the method.
+ @abc.abstractmethod
+ def group(self):
+ """Identify the group of the method.
Returns:
The group of the method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def name(self):
- """Identify the name of the method.
+ @abc.abstractmethod
+ def name(self):
+ """Identify the name of the method.
Returns:
The name of the method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def cardinality(self):
- """Identify the cardinality of the method.
+ @abc.abstractmethod
+ def cardinality(self):
+ """Identify the cardinality of the method.
Returns:
A cardinality.Cardinality value describing the streaming semantics of the
method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def request_class(self):
- """Identify the class used for the method's request objects.
+ @abc.abstractmethod
+ def request_class(self):
+ """Identify the class used for the method's request objects.
Returns:
The class object of the class to which the method's request objects
belong.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def response_class(self):
- """Identify the class used for the method's response objects.
+ @abc.abstractmethod
+ def response_class(self):
+ """Identify the class used for the method's response objects.
Returns:
The class object of the class to which the method's response objects
belong.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def serialize_request(self, request):
- """Serialize the given request object.
+ @abc.abstractmethod
+ def serialize_request(self, request):
+ """Serialize the given request object.
Args:
request: A request object appropriate for this method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def deserialize_request(self, serialized_request):
- """Synthesize a request object from a given bytestring.
+ @abc.abstractmethod
+ def deserialize_request(self, serialized_request):
+ """Synthesize a request object from a given bytestring.
Args:
serialized_request: A bytestring deserializable into a request object
appropriate for this method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def serialize_response(self, response):
- """Serialize the given response object.
+ @abc.abstractmethod
+ def serialize_response(self, response):
+ """Serialize the given response object.
Args:
response: A response object appropriate for this method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def deserialize_response(self, serialized_response):
- """Synthesize a response object from a given bytestring.
+ @abc.abstractmethod
+ def deserialize_response(self, serialized_response):
+ """Synthesize a response object from a given bytestring.
Args:
serialized_response: A bytestring deserializable into a response object
appropriate for this method.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
class Implementation(six.with_metaclass(abc.ABCMeta)):
- """Specifies an implementation of the Face layer."""
+ """Specifies an implementation of the Face layer."""
- @abc.abstractmethod
- def instantiate(
- self, methods, method_implementations,
- multi_method_implementation):
- """Instantiates the Face layer implementation to be used in a test.
+ @abc.abstractmethod
+ def instantiate(self, methods, method_implementations,
+ multi_method_implementation):
+ """Instantiates the Face layer implementation to be used in a test.
Args:
methods: A sequence of Method objects describing the methods available to
@@ -151,69 +149,69 @@ class Implementation(six.with_metaclass(abc.ABCMeta)):
passed to destantiate at the conclusion of the test. The returned stubs
must be backed by the provided implementations.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def destantiate(self, memo):
- """Destroys the Face layer implementation under test.
+ @abc.abstractmethod
+ def destantiate(self, memo):
+ """Destroys the Face layer implementation under test.
Args:
memo: The object from the third position of the return value of a call to
instantiate.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def invocation_metadata(self):
- """Provides the metadata to be used when invoking a test RPC.
+ @abc.abstractmethod
+ def invocation_metadata(self):
+ """Provides the metadata to be used when invoking a test RPC.
Returns:
An object to use as the supplied-at-invocation-time metadata in a test
RPC.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def initial_metadata(self):
- """Provides the metadata for use as a test RPC's first servicer metadata.
+ @abc.abstractmethod
+ def initial_metadata(self):
+ """Provides the metadata for use as a test RPC's first servicer metadata.
Returns:
An object to use as the from-the-servicer-before-responses metadata in a
test RPC.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def terminal_metadata(self):
- """Provides the metadata for use as a test RPC's second servicer metadata.
+ @abc.abstractmethod
+ def terminal_metadata(self):
+ """Provides the metadata for use as a test RPC's second servicer metadata.
Returns:
An object to use as the from-the-servicer-after-all-responses metadata in
a test RPC.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def code(self):
- """Provides the value for use as a test RPC's code.
+ @abc.abstractmethod
+ def code(self):
+ """Provides the value for use as a test RPC's code.
Returns:
An object to use as the from-the-servicer code in a test RPC.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def details(self):
- """Provides the value for use as a test RPC's details.
+ @abc.abstractmethod
+ def details(self):
+ """Provides the value for use as a test RPC's details.
Returns:
An object to use as the from-the-servicer details in a test RPC.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
- @abc.abstractmethod
- def metadata_transmitted(self, original_metadata, transmitted_metadata):
- """Identifies whether or not metadata was properly transmitted.
+ @abc.abstractmethod
+ def metadata_transmitted(self, original_metadata, transmitted_metadata):
+ """Identifies whether or not metadata was properly transmitted.
Args:
original_metadata: A metadata value passed to the Face interface
@@ -226,4 +224,4 @@ class Implementation(six.with_metaclass(abc.ABCMeta)):
Whether or not the metadata was properly transmitted by the Face interface
implementation under test.
"""
- raise NotImplementedError()
+ raise NotImplementedError()
diff --git a/src/python/grpcio_tests/tests/unit/resources.py b/src/python/grpcio_tests/tests/unit/resources.py
index 023cdb155f..55a2fff979 100644
--- a/src/python/grpcio_tests/tests/unit/resources.py
+++ b/src/python/grpcio_tests/tests/unit/resources.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Constants and functions for data used in interoperability testing."""
import os
@@ -39,14 +38,14 @@ _CERTIFICATE_CHAIN_RESOURCE_PATH = 'credentials/server1.pem'
def test_root_certificates():
- return pkg_resources.resource_string(
- __name__, _ROOT_CERTIFICATES_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__,
+ _ROOT_CERTIFICATES_RESOURCE_PATH)
def private_key():
- return pkg_resources.resource_string(__name__, _PRIVATE_KEY_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__, _PRIVATE_KEY_RESOURCE_PATH)
def certificate_chain():
- return pkg_resources.resource_string(
- __name__, _CERTIFICATE_CHAIN_RESOURCE_PATH)
+ return pkg_resources.resource_string(__name__,
+ _CERTIFICATE_CHAIN_RESOURCE_PATH)
diff --git a/src/python/grpcio_tests/tests/unit/test_common.py b/src/python/grpcio_tests/tests/unit/test_common.py
index cd71bd80d7..00fbe0567a 100644
--- a/src/python/grpcio_tests/tests/unit/test_common.py
+++ b/src/python/grpcio_tests/tests/unit/test_common.py
@@ -26,7 +26,6 @@
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
"""Common code used throughout tests of gRPC."""
import collections
@@ -34,14 +33,23 @@ import collections
import grpc
import six
-INVOCATION_INITIAL_METADATA = (('0', 'abc'), ('1', 'def'), ('2', 'ghi'),)
-SERVICE_INITIAL_METADATA = (('3', 'jkl'), ('4', 'mno'), ('5', 'pqr'),)
-SERVICE_TERMINAL_METADATA = (('6', 'stu'), ('7', 'vwx'), ('8', 'yza'),)
+INVOCATION_INITIAL_METADATA = (
+ ('0', 'abc'),
+ ('1', 'def'),
+ ('2', 'ghi'),)
+SERVICE_INITIAL_METADATA = (
+ ('3', 'jkl'),
+ ('4', 'mno'),
+ ('5', 'pqr'),)
+SERVICE_TERMINAL_METADATA = (
+ ('6', 'stu'),
+ ('7', 'vwx'),
+ ('8', 'yza'),)
DETAILS = 'test details'
def metadata_transmitted(original_metadata, transmitted_metadata):
- """Judges whether or not metadata was acceptably transmitted.
+ """Judges whether or not metadata was acceptably transmitted.
gRPC is allowed to insert key-value pairs into the metadata values given by
applications and to reorder key-value pairs with different keys but it is not
@@ -59,31 +67,30 @@ def metadata_transmitted(original_metadata, transmitted_metadata):
A boolean indicating whether transmitted_metadata accurately reflects
original_metadata after having been transmitted via gRPC.
"""
- original = collections.defaultdict(list)
- for key, value in original_metadata:
- original[key].append(value)
- transmitted = collections.defaultdict(list)
- for key, value in transmitted_metadata:
- transmitted[key].append(value)
+ original = collections.defaultdict(list)
+ for key, value in original_metadata:
+ original[key].append(value)
+ transmitted = collections.defaultdict(list)
+ for key, value in transmitted_metadata:
+ transmitted[key].append(value)
- for key, values in six.iteritems(original):
- transmitted_values = transmitted[key]
- transmitted_iterator = iter(transmitted_values)
- try:
- for value in values:
- while True:
- transmitted_value = next(transmitted_iterator)
- if value == transmitted_value:
- break
- except StopIteration:
- return False
- else:
- return True
+ for key, values in six.iteritems(original):
+ transmitted_values = transmitted[key]
+ transmitted_iterator = iter(transmitted_values)
+ try:
+ for value in values:
+ while True:
+ transmitted_value = next(transmitted_iterator)
+ if value == transmitted_value:
+ break
+ except StopIteration:
+ return False
+ else:
+ return True
-def test_secure_channel(
- target, channel_credentials, server_host_override):
- """Creates an insecure Channel to a remote host.
+def test_secure_channel(target, channel_credentials, server_host_override):
+ """Creates an insecure Channel to a remote host.
Args:
host: The name of the remote host to which to connect.
@@ -96,7 +103,7 @@ def test_secure_channel(
An implementations.Channel to the remote host through which RPCs may be
conducted.
"""
- channel = grpc.secure_channel(
- target, channel_credentials,
- (('grpc.ssl_target_name_override', server_host_override,),))
- return channel
+ channel = grpc.secure_channel(target, channel_credentials, ((
+ 'grpc.ssl_target_name_override',
+ server_host_override,),))
+ return channel