aboutsummaryrefslogtreecommitdiffhomepage
path: root/src
diff options
context:
space:
mode:
authorGravatar Masood Malekghassemi <soltanmm@users.noreply.github.com>2015-09-17 14:50:36 -0700
committerGravatar Masood Malekghassemi <soltanmm@users.noreply.github.com>2015-09-30 15:11:24 -0700
commitd43ad333b8847b596b3e1205e85635f0c513aa08 (patch)
treef30d11f60b6c944d6759115dc3a928c083a98266 /src
parentf37cac7dd2b87bc2e9d44e6311149ff8acb29396 (diff)
Make load_tests protocol tests run via py.test
Diffstat (limited to 'src')
-rw-r--r--src/python/grpcio_test/grpc_test/conftest.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/python/grpcio_test/grpc_test/conftest.py b/src/python/grpcio_test/grpc_test/conftest.py
new file mode 100644
index 0000000000..357320ec64
--- /dev/null
+++ b/src/python/grpcio_test/grpc_test/conftest.py
@@ -0,0 +1,60 @@
+import types
+import unittest
+
+import pytest
+
+
+class LoadTestsSuiteCollector(pytest.Collector):
+
+ def __init__(self, name, parent, suite):
+ super(LoadTestsSuiteCollector, self).__init__(name, parent=parent)
+ self.suite = suite
+ self.obj = suite
+
+ def collect(self):
+ collected = []
+ for case in self.suite:
+ if isinstance(case, unittest.TestCase):
+ collected.append(LoadTestsCase(case.id(), self, case))
+ elif isinstance(case, unittest.TestSuite):
+ collected.append(
+ LoadTestsSuiteCollector('suite_child_of_mine', self, case))
+ return collected
+
+ def reportinfo(self):
+ return str(self.suite)
+
+
+class LoadTestsCase(pytest.Function):
+
+ def __init__(self, name, parent, item):
+ super(LoadTestsCase, self).__init__(name, parent, callobj=self._item_run)
+ self.item = item
+
+ def _item_run(self):
+ result = unittest.TestResult()
+ self.item(result)
+ if result.failures:
+ test_method, trace = result.failures[0]
+ pytest.fail(trace, False)
+ elif result.errors:
+ test_method, trace = result.errors[0]
+ pytest.fail(trace, False)
+ elif result.skipped:
+ test_method, reason = result.skipped[0]
+ pytest.skip(reason)
+
+
+def pytest_pycollect_makeitem(collector, name, obj):
+ if name == 'load_tests' and isinstance(obj, types.FunctionType):
+ suite = unittest.TestSuite()
+ loader = unittest.TestLoader()
+ pattern = '*'
+ try:
+ # Check that the 'load_tests' object is actually a callable that actually
+ # accepts the arguments expected for the load_tests protocol.
+ suite = obj(loader, suite, pattern)
+ except Exception as e:
+ return None
+ else:
+ return LoadTestsSuiteCollector(name, collector, suite)