aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/py/abseil/absl/testing/absltest.py
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/py/abseil/absl/testing/absltest.py')
-rw-r--r--third_party/py/abseil/absl/testing/absltest.py1715
1 files changed, 1715 insertions, 0 deletions
diff --git a/third_party/py/abseil/absl/testing/absltest.py b/third_party/py/abseil/absl/testing/absltest.py
new file mode 100644
index 0000000000..8702bfd9d7
--- /dev/null
+++ b/third_party/py/abseil/absl/testing/absltest.py
@@ -0,0 +1,1715 @@
+# Copyright 2017 The Abseil Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Base functionality for Abseil Python tests.
+
+This module contains base classes and high-level functions for Abseil-style
+tests.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import difflib
+import errno
+import getpass
+import inspect
+import itertools
+import json
+import os
+import random
+import re
+import shlex
+import signal
+import subprocess
+import sys
+import tempfile
+import textwrap
+import unittest
+
+try:
+ import faulthandler
+except ImportError:
+ # We use faulthandler if it is available.
+ faulthandler = None
+
+from absl import app
+from absl import flags
+from absl import logging
+from absl.testing import xml_reporter
+import six
+from six.moves import urllib
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+FLAGS = flags.FLAGS
+
+_TEXT_OR_BINARY_TYPES = (six.text_type, six.binary_type)
+
+
+# Many of the methods in this module have names like assertSameElements.
+# This kind of name does not comply with PEP8 style,
+# but it is consistent with the naming of methods in unittest.py.
+# pylint: disable=invalid-name
+
+
+def _get_default_test_random_seed():
+ random_seed = 301
+ value = os.environ.get('TEST_RANDOM_SEED', '')
+ try:
+ random_seed = int(value)
+ except ValueError:
+ pass
+ return random_seed
+
+
+def get_default_test_srcdir():
+ """Returns default test source dir."""
+ return os.environ.get('TEST_SRCDIR', '')
+
+
+def get_default_test_tmpdir():
+ """Returns default test temp dir."""
+ tmpdir = os.environ.get('TEST_TMPDIR', '')
+ if not tmpdir:
+ tmpdir = os.path.join(tempfile.gettempdir(), 'absl_testing')
+
+ return tmpdir
+
+
+def _get_default_randomize_ordering_seed():
+ """Returns default seed to use for randomizing test order.
+
+ This function first checks the --test_randomize_ordering_seed flag, and then
+ the TEST_RANDOMIZE_ORDERING_SEED environment variable. If the first value
+ we find is:
+ * (not set): disable test randomization
+ * 0: disable test randomization
+ * 'random': choose a random seed in [1, 4294967295] for test order
+ randomization
+ * positive integer: use this seed for test order randomization
+
+ (The values used are patterned after
+ https://docs.python.org/3/using/cmdline.html#envvar-PYTHONHASHSEED).
+
+ In principle, it would be simpler to return None if no override is provided;
+ however, the python random module has no `get_seed()`, only `getstate()`,
+ which returns far more data than we want to pass via an environment variable
+ or flag.
+
+ Returns:
+ A default value for test case randomization (int). 0 means do not randomize.
+
+ Raises:
+ ValueError: Raised when the flag or env value is not one of the options
+ above.
+ """
+ if FLAGS.test_randomize_ordering_seed is not None:
+ randomize = FLAGS.test_randomize_ordering_seed
+ else:
+ randomize = os.environ.get('TEST_RANDOMIZE_ORDERING_SEED')
+ if randomize is None:
+ return 0
+ if randomize == 'random':
+ return random.Random().randint(1, 4294967295)
+ if randomize == '0':
+ return 0
+ try:
+ seed = int(randomize)
+ if seed > 0:
+ return seed
+ except ValueError:
+ pass
+ raise ValueError(
+ 'Unknown test randomization seed value: {}'.format(randomize))
+
+
+flags.DEFINE_integer('test_random_seed', _get_default_test_random_seed(),
+ 'Random seed for testing. Some test frameworks may '
+ 'change the default value of this flag between runs, so '
+ 'it is not appropriate for seeding probabilistic tests.',
+ allow_override_cpp=True)
+flags.DEFINE_string('test_srcdir',
+ get_default_test_srcdir(),
+ 'Root of directory tree where source files live',
+ allow_override_cpp=True)
+flags.DEFINE_string('test_tmpdir', get_default_test_tmpdir(),
+ 'Directory for temporary testing files',
+ allow_override_cpp=True)
+flags.DEFINE_string('test_randomize_ordering_seed', None,
+ 'If positive, use this as a seed to randomize the '
+ 'execution order for test cases. If "random", pick a '
+ 'random seed to use. If 0 or not set, do not randomize '
+ 'test case execution order. This flag also overrides '
+ 'the TEST_RANDOMIZE_ORDERING_SEED environment variable.')
+flags.DEFINE_string('xml_output_file', '',
+ 'File to store XML test results')
+
+
+# We might need to monkey-patch TestResult so that it stops considering an
+# unexpected pass as a as a "successful result". For details, see
+# http://bugs.python.org/issue20165
+def _monkey_patch_test_result_for_unexpected_passes():
+ """Workaround for <http://bugs.python.org/issue20165>."""
+
+ def wasSuccessful(self):
+ """Tells whether or not this result was a success.
+
+ Any unexpected pass is to be counted as a non-success.
+
+ Args:
+ self: The TestResult instance.
+
+ Returns:
+ Whether or not this result was a success.
+ """
+ return (len(self.failures) == len(self.errors) ==
+ len(self.unexpectedSuccesses) == 0)
+
+ test_result = unittest.result.TestResult()
+ test_result.addUnexpectedSuccess('test')
+ if test_result.wasSuccessful(): # The bug is present.
+ unittest.result.TestResult.wasSuccessful = wasSuccessful
+ if test_result.wasSuccessful(): # Warn the user if our hot-fix failed.
+ sys.stderr.write('unittest.result.TestResult monkey patch to report'
+ ' unexpected passes as failures did not work.\n')
+
+
+_monkey_patch_test_result_for_unexpected_passes()
+
+
+class TestCase(unittest.TestCase):
+ """Extension of unittest.TestCase providing more powerful assertions."""
+
+ maxDiff = 80 * 20
+
+ def shortDescription(self):
+ """Formats both the test method name and the first line of its docstring.
+
+ If no docstring is given, only returns the method name.
+
+ This method overrides unittest.TestCase.shortDescription(), which
+ only returns the first line of the docstring, obscuring the name
+ of the test upon failure.
+
+ Returns:
+ desc: A short description of a test method.
+ """
+ desc = str(self)
+ # NOTE: super() is used here instead of directly invoking
+ # unittest.TestCase.shortDescription(self), because of the
+ # following line that occurs later on:
+ # unittest.TestCase = TestCase
+ # Because of this, direct invocation of what we think is the
+ # superclass will actually cause infinite recursion.
+ doc_first_line = super(TestCase, self).shortDescription()
+ if doc_first_line is not None:
+ desc = '\n'.join((desc, doc_first_line))
+ return desc
+
+ def assertStartsWith(self, actual, expected_start, msg=None):
+ """Asserts that actual.startswith(expected_start) is True.
+
+ Args:
+ actual: str
+ expected_start: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.startswith(expected_start):
+ self.fail('%r does not start with %r' % (actual, expected_start), msg)
+
+ def assertNotStartsWith(self, actual, unexpected_start, msg=None):
+ """Asserts that actual.startswith(unexpected_start) is False.
+
+ Args:
+ actual: str
+ unexpected_start: str
+ msg: Optional message to report on failure.
+ """
+ if actual.startswith(unexpected_start):
+ self.fail('%r does start with %r' % (actual, unexpected_start), msg)
+
+ def assertEndsWith(self, actual, expected_end, msg=None):
+ """Asserts that actual.endswith(expected_end) is True.
+
+ Args:
+ actual: str
+ expected_end: str
+ msg: Optional message to report on failure.
+ """
+ if not actual.endswith(expected_end):
+ self.fail('%r does not end with %r' % (actual, expected_end), msg)
+
+ def assertNotEndsWith(self, actual, unexpected_end, msg=None):
+ """Asserts that actual.endswith(unexpected_end) is False.
+
+ Args:
+ actual: str
+ unexpected_end: str
+ msg: Optional message to report on failure.
+ """
+ if actual.endswith(unexpected_end):
+ self.fail('%r does end with %r' % (actual, unexpected_end), msg)
+
+ def assertSequenceStartsWith(self, prefix, whole, msg=None):
+ """An equality assertion for the beginning of ordered sequences.
+
+ If prefix is an empty sequence, it will raise an error unless whole is also
+ an empty sequence.
+
+ If prefix is not a sequence, it will raise an error if the first element of
+ whole does not match.
+
+ Args:
+ prefix: A sequence expected at the beginning of the whole parameter.
+ whole: The sequence in which to look for prefix.
+ msg: Optional message to report on failure.
+ """
+ try:
+ prefix_len = len(prefix)
+ except (TypeError, NotImplementedError):
+ prefix = [prefix]
+ prefix_len = 1
+
+ try:
+ whole_len = len(whole)
+ except (TypeError, NotImplementedError):
+ self.fail('For whole: len(%s) is not supported, it appears to be type: '
+ '%s' % (whole, type(whole)), msg)
+
+ assert prefix_len <= whole_len, self._formatMessage(
+ msg,
+ 'Prefix length (%d) is longer than whole length (%d).' %
+ (prefix_len, whole_len)
+ )
+
+ if not prefix_len and whole_len:
+ self.fail('Prefix length is 0 but whole length is %d: %s' %
+ (len(whole), whole), msg)
+
+ try:
+ self.assertSequenceEqual(prefix, whole[:prefix_len], msg)
+ except AssertionError:
+ self.fail('prefix: %s not found at start of whole: %s.' %
+ (prefix, whole), msg)
+
+ def assertEmpty(self, container, msg=None):
+ """Asserts that an object has zero length.
+
+ Args:
+ container: Anything that implements the collections.Sized interface.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, collections.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+
+ # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
+ # have strange __nonzero__/__bool__ behavior.
+ if len(container): # pylint: disable=g-explicit-length-test
+ self.fail('{!r} has length of {}.'.format(container, len(container)), msg)
+
+ def assertNotEmpty(self, container, msg=None):
+ """Asserts that an object has non-zero length.
+
+ Args:
+ container: Anything that implements the collections.Sized interface.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, collections.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+
+ # explicitly check the length since some Sized objects (e.g. numpy.ndarray)
+ # have strange __nonzero__/__bool__ behavior.
+ if not len(container): # pylint: disable=g-explicit-length-test
+ self.fail('{!r} has length of 0.'.format(container), msg)
+
+ def assertLen(self, container, expected_len, msg=None):
+ """Asserts that an object has the expected length.
+
+ Args:
+ container: Anything that implements the collections.Sized interface.
+ expected_len: The expected length of the container.
+ msg: Optional message to report on failure.
+ """
+ if not isinstance(container, collections.Sized):
+ self.fail('Expected a Sized object, got: '
+ '{!r}'.format(type(container).__name__), msg)
+ if len(container) != expected_len:
+ container_repr = unittest.util.safe_repr(container)
+ self.fail('{} has length of {}, expected {}.'.format(
+ container_repr, len(container), expected_len), msg)
+
+ def assertSequenceAlmostEqual(self, expected_seq, actual_seq, places=None,
+ msg=None, delta=None):
+ """An approximate equality assertion for ordered sequences.
+
+ Fail if the two sequences are unequal as determined by their value
+ differences rounded to the given number of decimal places (default 7) and
+ comparing to zero, or by comparing that the difference between each value
+ in the two sequences is more than the given delta.
+
+ Note that decimal places (from zero) are usually not the same as significant
+ digits (measured from the most signficant digit).
+
+ If the two sequences compare equal then they will automatically compare
+ almost equal.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ places: The number of decimal places to compare.
+ msg: The message to be printed if the test fails.
+ delta: The OK difference between compared values.
+ """
+ if len(expected_seq) != len(actual_seq):
+ self.fail('Sequence size mismatch: {} vs {}'.format(
+ len(expected_seq), len(actual_seq)), msg)
+
+ err_list = []
+ for idx, (exp_elem, act_elem) in enumerate(zip(expected_seq, actual_seq)):
+ try:
+ self.assertAlmostEqual(exp_elem, act_elem, places=places, msg=msg,
+ delta=delta)
+ except self.failureException as err:
+ err_list.append('At index {}: {}'.format(idx, err))
+
+ if err_list:
+ if len(err_list) > 30:
+ err_list = err_list[:30] + ['...']
+ msg = self._formatMessage(msg, '\n'.join(err_list))
+ self.fail(msg)
+
+ def assertContainsSubset(self, expected_subset, actual_set, msg=None):
+ """Checks whether actual iterable is a superset of expected iterable."""
+ missing = set(expected_subset) - set(actual_set)
+ if not missing:
+ return
+
+ self.fail('Missing elements %s\nExpected: %s\nActual: %s' % (
+ missing, expected_subset, actual_set), msg)
+
+ def assertNoCommonElements(self, expected_seq, actual_seq, msg=None):
+ """Checks whether actual iterable and expected iterable are disjoint."""
+ common = set(expected_seq) & set(actual_seq)
+ if not common:
+ return
+
+ self.fail('Common elements %s\nExpected: %s\nActual: %s' % (
+ common, expected_seq, actual_seq), msg)
+
+ def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
+ """An unordered sequence specific comparison.
+
+ Equivalent to assertCountEqual(). This method is a compatibility layer
+ for Python 3k, since 2to3 does not convert assertItemsEqual() calls into
+ assertCountEqual() calls.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ msg: The message to be printed if the test fails.
+ """
+
+ if not hasattr(super(TestCase, self), 'assertItemsEqual'):
+ # The assertItemsEqual method was renamed assertCountEqual in Python 3.2
+ super(TestCase, self).assertCountEqual(expected_seq, actual_seq, msg)
+ return
+
+ super(TestCase, self).assertItemsEqual(expected_seq, actual_seq, msg)
+
+ def assertCountEqual(self, expected_seq, actual_seq, msg=None):
+ """An unordered sequence specific comparison.
+
+ It asserts that actual_seq and expected_seq have the same element counts.
+ Equivalent to::
+
+ self.assertEqual(Counter(iter(actual_seq)),
+ Counter(iter(expected_seq)))
+
+ Asserts that each element has the same count in both sequences.
+ Example:
+ - [0, 1, 1] and [1, 0, 1] compare equal.
+ - [0, 0, 1] and [0, 1] compare unequal.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ msg: The message to be printed if the test fails.
+
+ """
+ self.assertItemsEqual(expected_seq, actual_seq, msg)
+
+ def assertSameElements(self, expected_seq, actual_seq, msg=None):
+ """Asserts that two sequences have the same elements (in any order).
+
+ This method, unlike assertCountEqual, doesn't care about any
+ duplicates in the expected and actual sequences.
+
+ >> assertSameElements([1, 1, 1, 0, 0, 0], [0, 1])
+ # Doesn't raise an AssertionError
+
+ If possible, you should use assertCountEqual instead of
+ assertSameElements.
+
+ Args:
+ expected_seq: A sequence containing elements we are expecting.
+ actual_seq: The sequence that we are testing.
+ msg: The message to be printed if the test fails.
+ """
+ # `unittest2.TestCase` used to have assertSameElements, but it was
+ # removed in favor of assertItemsEqual. As there's a unit test
+ # that explicitly checks this behavior, I am leaving this method
+ # alone.
+ # Fail on strings: empirically, passing strings to this test method
+ # is almost always a bug. If comparing the character sets of two strings
+ # is desired, cast the inputs to sets or lists explicitly.
+ if (isinstance(expected_seq, _TEXT_OR_BINARY_TYPES) or
+ isinstance(actual_seq, _TEXT_OR_BINARY_TYPES)):
+ self.fail('Passing string/bytes to assertSameElements is usually a bug. '
+ 'Did you mean to use assertEqual?\n'
+ 'Expected: %s\nActual: %s' % (expected_seq, actual_seq))
+ try:
+ expected = dict([(element, None) for element in expected_seq])
+ actual = dict([(element, None) for element in actual_seq])
+ missing = [element for element in expected if element not in actual]
+ unexpected = [element for element in actual if element not in expected]
+ missing.sort()
+ unexpected.sort()
+ except TypeError:
+ # Fall back to slower list-compare if any of the objects are
+ # not hashable.
+ expected = list(expected_seq)
+ actual = list(actual_seq)
+ expected.sort()
+ actual.sort()
+ missing, unexpected = _sorted_list_difference(expected, actual)
+ errors = []
+ if msg:
+ errors.extend((msg, ':\n'))
+ if missing:
+ errors.append('Expected, but missing:\n %r\n' % missing)
+ if unexpected:
+ errors.append('Unexpected, but present:\n %r\n' % unexpected)
+ if missing or unexpected:
+ self.fail(''.join(errors))
+
+ # unittest.TestCase.assertMultiLineEqual works very similarly, but it
+ # has a different error format. However, I find this slightly more readable.
+ def assertMultiLineEqual(self, first, second, msg=None, **kwargs):
+ """Asserts that two multi-line strings are equal."""
+ assert isinstance(first, six.string_types), (
+ 'First argument is not a string: %r' % (first,))
+ assert isinstance(second, six.string_types), (
+ 'Second argument is not a string: %r' % (second,))
+ line_limit = kwargs.pop('line_limit', 0)
+ if kwargs:
+ raise TypeError('Unexpected keyword args {}'.format(tuple(kwargs)))
+
+ if first == second:
+ return
+ if msg:
+ failure_message = [msg + ':\n']
+ else:
+ failure_message = ['\n']
+ if line_limit:
+ line_limit += len(failure_message)
+ for line in difflib.ndiff(first.splitlines(True), second.splitlines(True)):
+ failure_message.append(line)
+ if not line.endswith('\n'):
+ failure_message.append('\n')
+ if line_limit and len(failure_message) > line_limit:
+ n_omitted = len(failure_message) - line_limit
+ failure_message = failure_message[:line_limit]
+ failure_message.append(
+ '(... and {} more delta lines omitted for brevity.)\n'.format(
+ n_omitted))
+
+ raise self.failureException(''.join(failure_message))
+
+ def assertBetween(self, value, minv, maxv, msg=None):
+ """Asserts that value is between minv and maxv (inclusive)."""
+ msg = self._formatMessage(msg,
+ '"%r" unexpectedly not between "%r" and "%r"' %
+ (value, minv, maxv))
+ self.assertTrue(minv <= value, msg)
+ self.assertTrue(maxv >= value, msg)
+
+ def assertRegexMatch(self, actual_str, regexes, message=None):
+ r"""Asserts that at least one regex in regexes matches str.
+
+ If possible you should use assertRegexpMatches, which is a simpler
+ version of this method. assertRegexpMatches takes a single regular
+ expression (a string or re compiled object) instead of a list.
+
+ Notes:
+ 1. This function uses substring matching, i.e. the matching
+ succeeds if *any* substring of the error message matches *any*
+ regex in the list. This is more convenient for the user than
+ full-string matching.
+
+ 2. If regexes is the empty list, the matching will always fail.
+
+ 3. Use regexes=[''] for a regex that will always pass.
+
+ 4. '.' matches any single character *except* the newline. To
+ match any character, use '(.|\n)'.
+
+ 5. '^' matches the beginning of each line, not just the beginning
+ of the string. Similarly, '$' matches the end of each line.
+
+ 6. An exception will be thrown if regexes contains an invalid
+ regex.
+
+ Args:
+ actual_str: The string we try to match with the items in regexes.
+ regexes: The regular expressions we want to match against str.
+ See "Notes" above for detailed notes on how this is interpreted.
+ message: The message to be printed if the test fails.
+ """
+ if isinstance(regexes, _TEXT_OR_BINARY_TYPES):
+ self.fail('regexes is string or bytes; use assertRegexpMatches instead.',
+ message)
+ if not regexes:
+ self.fail('No regexes specified.', message)
+
+ regex_type = type(regexes[0])
+ for regex in regexes[1:]:
+ if type(regex) is not regex_type: # pylint: disable=unidiomatic-typecheck
+ self.fail('regexes list must all be the same type.', message)
+
+ if regex_type is bytes and isinstance(actual_str, six.text_type):
+ regexes = [regex.decode('utf-8') for regex in regexes]
+ regex_type = six.text_type
+ elif regex_type is six.text_type and isinstance(actual_str, bytes):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+ regex_type = bytes
+
+ if regex_type is six.text_type:
+ regex = u'(?:%s)' % u')|(?:'.join(regexes)
+ elif regex_type is bytes:
+ regex = b'(?:' + (b')|(?:'.join(regexes)) + b')'
+ else:
+ self.fail('Only know how to deal with unicode str or bytes regexes.',
+ message)
+
+ if not re.search(regex, actual_str, re.MULTILINE):
+ self.fail('"%s" does not contain any of these regexes: %s.' %
+ (actual_str, regexes), message)
+
+ def assertCommandSucceeds(self, command, regexes=(b'',), env=None,
+ close_fds=True, msg=None):
+ """Asserts that a shell command succeeds (i.e. exits with code 0).
+
+ Args:
+ command: List or string representing the command to run.
+ regexes: List of regular expression byte strings that match success.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after
+ forking.
+ msg: Optional message to report on failure.
+ """
+ (ret_code, err) = get_command_stderr(command, env, close_fds)
+
+ # We need bytes regexes here because `err` is bytes.
+ # Accommodate code which listed their output regexes w/o the b'' prefix by
+ # converting them to bytes for the user.
+ if isinstance(regexes[0], six.text_type):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+
+ command_string = get_command_string(command)
+ self.assertEqual(
+ ret_code, 0,
+ self._formatMessage(msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s' % (_quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err)))
+ )
+ self.assertRegexMatch(
+ err,
+ regexes,
+ message=self._formatMessage(
+ msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s which matches no regex in %s' % (
+ _quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err),
+ regexes)))
+
+ def assertCommandFails(self, command, regexes, env=None, close_fds=True,
+ msg=None):
+ """Asserts a shell command fails and the error matches a regex in a list.
+
+ Args:
+ command: List or string representing the command to run.
+ regexes: the list of regular expression strings.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after
+ forking.
+ msg: Optional message to report on failure.
+ """
+ (ret_code, err) = get_command_stderr(command, env, close_fds)
+
+ # We need bytes regexes here because `err` is bytes.
+ # Accommodate code which listed their output regexes w/o the b'' prefix by
+ # converting them to bytes for the user.
+ if isinstance(regexes[0], six.text_type):
+ regexes = [regex.encode('utf-8') for regex in regexes]
+
+ command_string = get_command_string(command)
+ self.assertNotEqual(
+ ret_code, 0,
+ self._formatMessage(msg, 'The following command succeeded '
+ 'while expected to fail:\n%s' %
+ _quote_long_string(command_string)))
+ self.assertRegexMatch(
+ err,
+ regexes,
+ message=self._formatMessage(
+ msg,
+ 'Running command\n'
+ '%s failed with error code %s and message\n'
+ '%s which matches no regex in %s' % (
+ _quote_long_string(command_string),
+ ret_code,
+ _quote_long_string(err),
+ regexes)))
+
+ class _AssertRaisesContext(object):
+
+ def __init__(self, expected_exception, test_case, test_func, msg=None):
+ self.expected_exception = expected_exception
+ self.test_case = test_case
+ self.test_func = test_func
+ self.msg = msg
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is None:
+ self.test_case.fail(self.expected_exception.__name__ + ' not raised',
+ self.msg)
+ if not issubclass(exc_type, self.expected_exception):
+ return False
+ self.test_func(exc_value)
+ return True
+
+ def assertRaisesWithPredicateMatch(self, expected_exception, predicate,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that exception is thrown and predicate(exception) is true.
+
+ Args:
+ expected_exception: Exception class expected to be raised.
+ predicate: Function of one argument that inspects the passed-in exception
+ and returns True (success) or False (please fail the test).
+ callable_obj: Function to be called.
+ *args: Extra args.
+ **kwargs: Extra keyword args.
+
+ Returns:
+ A context manager if callable_obj is None. Otherwise, None.
+
+ Raises:
+ self.failureException if callable_obj does not raise a matching exception.
+ """
+ def Check(err):
+ self.assertTrue(predicate(err),
+ '%r does not match predicate %r' % (err, predicate))
+
+ context = self._AssertRaisesContext(expected_exception, self, Check)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
+ def assertRaisesWithLiteralMatch(self, expected_exception,
+ expected_exception_message,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that the message in a raised exception equals the given string.
+
+ Unlike assertRaisesRegexp, this method takes a literal string, not
+ a regular expression.
+
+ with self.assertRaisesWithLiteralMatch(ExType, 'message'):
+ DoSomething()
+
+ Args:
+ expected_exception: Exception class expected to be raised.
+ expected_exception_message: String message expected in the raised
+ exception. For a raise exception e, expected_exception_message must
+ equal str(e).
+ callable_obj: Function to be called, or None to return a context.
+ *args: Extra args.
+ **kwargs: Extra kwargs.
+
+ Returns:
+ A context manager if callable_obj is None. Otherwise, None.
+
+ Raises:
+ self.failureException if callable_obj does not raise a matching exception.
+ """
+ def Check(err):
+ actual_exception_message = str(err)
+ self.assertTrue(expected_exception_message == actual_exception_message,
+ 'Exception message does not match.\n'
+ 'Expected: %r\n'
+ 'Actual: %r' % (expected_exception_message,
+ actual_exception_message))
+
+ context = self._AssertRaisesContext(expected_exception, self, Check)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
+ def assertContainsInOrder(self, strings, target, msg=None):
+ """Asserts that the strings provided are found in the target in order.
+
+ This may be useful for checking HTML output.
+
+ Args:
+ strings: A list of strings, such as [ 'fox', 'dog' ]
+ target: A target string in which to look for the strings, such as
+ 'The quick brown fox jumped over the lazy dog'.
+ msg: Optional message to report on failure.
+ """
+ if isinstance(strings, (bytes, unicode if str is bytes else str)):
+ strings = (strings,)
+
+ current_index = 0
+ last_string = None
+ for string in strings:
+ index = target.find(str(string), current_index)
+ if index == -1 and current_index == 0:
+ self.fail("Did not find '%s' in '%s'" %
+ (string, target), msg)
+ elif index == -1:
+ self.fail("Did not find '%s' after '%s' in '%s'" %
+ (string, last_string, target), msg)
+ last_string = string
+ current_index = index
+
+ def assertContainsSubsequence(self, container, subsequence, msg=None):
+ """Asserts that "container" contains "subsequence" as a subsequence.
+
+ Asserts that "container" contains all the elements of "subsequence", in
+ order, but possibly with other elements interspersed. For example, [1, 2, 3]
+ is a subsequence of [0, 0, 1, 2, 0, 3, 0] but not of [0, 0, 1, 3, 0, 2, 0].
+
+ Args:
+ container: the list we're testing for subsequence inclusion.
+ subsequence: the list we hope will be a subsequence of container.
+ msg: Optional message to report on failure.
+ """
+ first_nonmatching = None
+ reversed_container = list(reversed(container))
+ subsequence = list(subsequence)
+
+ for e in subsequence:
+ if e not in reversed_container:
+ first_nonmatching = e
+ break
+ while e != reversed_container.pop():
+ pass
+
+ if first_nonmatching is not None:
+ self.fail('%s not a subsequence of %s. First non-matching element: %s' %
+ (subsequence, container, first_nonmatching), msg)
+
+ def assertContainsExactSubsequence(self, container, subsequence, msg=None):
+ """Asserts that "container" contains "subsequence" as an exact subsequence.
+
+ Asserts that "container" contains all the elements of "subsequence", in
+ order, and without other elements interspersed. For example, [1, 2, 3] is an
+ exact subsequence of [0, 0, 1, 2, 3, 0] but not of [0, 0, 1, 2, 0, 3, 0].
+
+ Args:
+ container: the list we're testing for subsequence inclusion.
+ subsequence: the list we hope will be an exact subsequence of container.
+ msg: Optional message to report on failure.
+ """
+ container = list(container)
+ subsequence = list(subsequence)
+ longest_match = 0
+
+ for start in xrange(1 + len(container) - len(subsequence)):
+ if longest_match == len(subsequence):
+ break
+ index = 0
+ while (index < len(subsequence) and
+ subsequence[index] == container[start + index]):
+ index += 1
+ longest_match = max(longest_match, index)
+
+ if longest_match < len(subsequence):
+ self.fail('%s not an exact subsequence of %s. '
+ 'Longest matching prefix: %s' %
+ (subsequence, container, subsequence[:longest_match]), msg)
+
+ def assertTotallyOrdered(self, *groups, **kwargs):
+ """Asserts that total ordering has been implemented correctly.
+
+ For example, say you have a class A that compares only on its attribute x.
+ Comparators other than __lt__ are omitted for brevity.
+
+ class A(object):
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __hash__(self):
+ return hash(self.x)
+
+ def __lt__(self, other):
+ try:
+ return self.x < other.x
+ except AttributeError:
+ return NotImplemented
+
+ assertTotallyOrdered will check that instances can be ordered correctly.
+ For example,
+
+ self.assertTotallyOrdered(
+ [None], # None should come before everything else.
+ [1], # Integers sort earlier.
+ [A(1, 'a')],
+ [A(2, 'b')], # 2 is after 1.
+ [A(3, 'c'), A(3, 'd')], # The second argument is irrelevant.
+ [A(4, 'z')],
+ ['foo']) # Strings sort last.
+
+ Args:
+ *groups: A list of groups of elements. Each group of elements is a list
+ of objects that are equal. The elements in each group must be less
+ than the elements in the group after it. For example, these groups are
+ totally ordered: [None], [1], [2, 2], [3].
+ **kwargs: optional msg keyword argument can be passed.
+ """
+
+ def CheckOrder(small, big):
+ """Ensures small is ordered before big."""
+ self.assertFalse(small == big,
+ self._formatMessage(msg, '%r unexpectedly equals %r' %
+ (small, big)))
+ self.assertTrue(small != big,
+ self._formatMessage(msg, '%r unexpectedly equals %r' %
+ (small, big)))
+ self.assertLess(small, big, msg)
+ self.assertFalse(big < small,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (big, small)))
+ self.assertLessEqual(small, big, msg)
+ self.assertFalse(big <= small, self._formatMessage(
+ '%r unexpectedly less than or equal to %r' % (big, small), msg
+ ))
+ self.assertGreater(big, small, msg)
+ self.assertFalse(small > big,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (small, big)))
+ self.assertGreaterEqual(big, small)
+ self.assertFalse(small >= big, self._formatMessage(
+ msg,
+ '%r unexpectedly greater than or equal to %r' % (small, big)))
+
+ def CheckEqual(a, b):
+ """Ensures that a and b are equal."""
+ self.assertEqual(a, b, msg)
+ self.assertFalse(a != b,
+ self._formatMessage(msg, '%r unexpectedly unequals %r' %
+ (a, b)))
+ self.assertEqual(hash(a), hash(b), self._formatMessage(
+ msg,
+ 'hash %d of %r unexpectedly not equal to hash %d of %r' %
+ (hash(a), a, hash(b), b)))
+ self.assertFalse(a < b,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (a, b)))
+ self.assertFalse(b < a,
+ self._formatMessage(msg,
+ '%r unexpectedly less than %r' %
+ (b, a)))
+ self.assertLessEqual(a, b, msg)
+ self.assertLessEqual(b, a, msg)
+ self.assertFalse(a > b,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (a, b)))
+ self.assertFalse(b > a,
+ self._formatMessage(msg,
+ '%r unexpectedly greater than %r' %
+ (b, a)))
+ self.assertGreaterEqual(a, b, msg)
+ self.assertGreaterEqual(b, a, msg)
+
+ msg = kwargs.get('msg')
+
+ # For every combination of elements, check the order of every pair of
+ # elements.
+ for elements in itertools.product(*groups):
+ elements = list(elements)
+ for index, small in enumerate(elements[:-1]):
+ for big in elements[index + 1:]:
+ CheckOrder(small, big)
+
+ # Check that every element in each group is equal.
+ for group in groups:
+ for a in group:
+ CheckEqual(a, a)
+ for a, b in itertools.product(group, group):
+ CheckEqual(a, b)
+
+ def assertDictEqual(self, a, b, msg=None):
+ """Raises AssertionError if a and b are not equal dictionaries.
+
+ Args:
+ a: A dict, the expected value.
+ b: A dict, the actual value.
+ msg: An optional str, the associated message.
+
+ Raises:
+ AssertionError: if the dictionaries are not equal.
+ """
+ self.assertIsInstance(a, dict, self._formatMessage(
+ msg,
+ 'First argument is not a dictionary'
+ ))
+ self.assertIsInstance(b, dict, self._formatMessage(
+ msg,
+ 'Second argument is not a dictionary'
+ ))
+
+ def Sorted(list_of_items):
+ try:
+ return sorted(list_of_items) # In 3.3, unordered are possible.
+ except TypeError:
+ return list_of_items
+
+ if a == b:
+ return
+ a_items = Sorted(list(six.iteritems(a)))
+ b_items = Sorted(list(six.iteritems(b)))
+
+ unexpected = []
+ missing = []
+ different = []
+
+ safe_repr = unittest.util.safe_repr
+
+ def Repr(dikt):
+ """Deterministic repr for dict."""
+ # Sort the entries based on their repr, not based on their sort order,
+ # which will be non-deterministic across executions, for many types.
+ entries = sorted((safe_repr(k), safe_repr(v))
+ for k, v in six.iteritems(dikt))
+ return '{%s}' % (', '.join('%s: %s' % pair for pair in entries))
+
+ message = ['%s != %s%s' % (Repr(a), Repr(b), ' (%s)' % msg if msg else '')]
+
+ # The standard library default output confounds lexical difference with
+ # value difference; treat them separately.
+ for a_key, a_value in a_items:
+ if a_key not in b:
+ missing.append((a_key, a_value))
+ elif a_value != b[a_key]:
+ different.append((a_key, a_value, b[a_key]))
+
+ for b_key, b_value in b_items:
+ if b_key not in a:
+ unexpected.append((b_key, b_value))
+
+ if unexpected:
+ message.append(
+ 'Unexpected, but present entries:\n%s' % ''.join(
+ '%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in unexpected))
+
+ if different:
+ message.append(
+ 'repr() of differing entries:\n%s' % ''.join(
+ '%s: %s != %s\n' % (safe_repr(k), safe_repr(a_value),
+ safe_repr(b_value))
+ for k, a_value, b_value in different))
+
+ if missing:
+ message.append(
+ 'Missing entries:\n%s' % ''.join(
+ ('%s: %s\n' % (safe_repr(k), safe_repr(v)) for k, v in missing)))
+
+ raise self.failureException('\n'.join(message))
+
+ def assertUrlEqual(self, a, b, msg=None):
+ """Asserts that urls are equal, ignoring ordering of query params."""
+ parsed_a = urllib.parse.urlparse(a)
+ parsed_b = urllib.parse.urlparse(b)
+ self.assertEqual(parsed_a.scheme, parsed_b.scheme, msg)
+ self.assertEqual(parsed_a.netloc, parsed_b.netloc, msg)
+ self.assertEqual(parsed_a.path, parsed_b.path, msg)
+ self.assertEqual(parsed_a.fragment, parsed_b.fragment, msg)
+ self.assertEqual(sorted(parsed_a.params.split(';')),
+ sorted(parsed_b.params.split(';')), msg)
+ self.assertDictEqual(
+ urllib.parse.parse_qs(parsed_a.query, keep_blank_values=True),
+ urllib.parse.parse_qs(parsed_b.query, keep_blank_values=True), msg)
+
+ def assertSameStructure(self, a, b, aname='a', bname='b', msg=None):
+ """Asserts that two values contain the same structural content.
+
+ The two arguments should be data trees consisting of trees of dicts and
+ lists. They will be deeply compared by walking into the contents of dicts
+ and lists; other items will be compared using the == operator.
+ If the two structures differ in content, the failure message will indicate
+ the location within the structures where the first difference is found.
+ This may be helpful when comparing large structures.
+
+ Args:
+ a: The first structure to compare.
+ b: The second structure to compare.
+ aname: Variable name to use for the first structure in assertion messages.
+ bname: Variable name to use for the second structure.
+ msg: Additional text to include in the failure message.
+ """
+
+ # Accumulate all the problems found so we can report all of them at once
+ # rather than just stopping at the first
+ problems = []
+
+ _walk_structure_for_problems(a, b, aname, bname, problems)
+
+ # Avoid spamming the user toooo much
+ if self.maxDiff is not None:
+ max_problems_to_show = self.maxDiff // 80
+ if len(problems) > max_problems_to_show:
+ problems = problems[0:max_problems_to_show-1] + ['...']
+
+ if problems:
+ self.fail('; '.join(problems), msg)
+
+ def assertJsonEqual(self, first, second, msg=None):
+ """Asserts that the JSON objects defined in two strings are equal.
+
+ A summary of the differences will be included in the failure message
+ using assertSameStructure.
+
+ Args:
+ first: A string contining JSON to decode and compare to second.
+ second: A string contining JSON to decode and compare to first.
+ msg: Additional text to include in the failure message.
+ """
+ try:
+ first_structured = json.loads(first)
+ except ValueError as e:
+ raise ValueError(self._formatMessage(
+ msg,
+ 'could not decode first JSON value %s: %s' % (first, e)))
+
+ try:
+ second_structured = json.loads(second)
+ except ValueError as e:
+ raise ValueError(self._formatMessage(
+ msg,
+ 'could not decode second JSON value %s: %s' % (second, e)))
+
+ self.assertSameStructure(first_structured, second_structured,
+ aname='first', bname='second', msg=msg)
+
+ def _getAssertEqualityFunc(self, first, second):
+ try:
+ return super(TestCase, self)._getAssertEqualityFunc(first, second)
+ except AttributeError:
+ # This is a workaround if unittest.TestCase.__init__ was never run.
+ # It usually means that somebody created a subclass just for the
+ # assertions and has overridden __init__. "assertTrue" is a safe
+ # value that will not make __init__ raise a ValueError.
+ test_method = getattr(self, '_testMethodName', 'assertTrue')
+ super(TestCase, self).__init__(test_method)
+
+ return super(TestCase, self)._getAssertEqualityFunc(first, second)
+
+ def fail(self, msg=None, prefix=None):
+ """Fail immediately with the given message, optionally prefixed."""
+ return super(TestCase, self).fail(self._formatMessage(prefix, msg))
+
+
+def _sorted_list_difference(expected, actual):
+ """Finds elements in only one or the other of two, sorted input lists.
+
+ Returns a two-element tuple of lists. The first list contains those
+ elements in the "expected" list but not in the "actual" list, and the
+ second contains those elements in the "actual" list but not in the
+ "expected" list. Duplicate elements in either input list are ignored.
+
+ Args:
+ expected: The list we expected.
+ actual: The list we actualy got.
+ Returns:
+ (missing, unexpected)
+ missing: items in expected that are not in actual.
+ unexpected: items in actual that are not in expected.
+ """
+ i = j = 0
+ missing = []
+ unexpected = []
+ while True:
+ try:
+ e = expected[i]
+ a = actual[j]
+ if e < a:
+ missing.append(e)
+ i += 1
+ while expected[i] == e:
+ i += 1
+ elif e > a:
+ unexpected.append(a)
+ j += 1
+ while actual[j] == a:
+ j += 1
+ else:
+ i += 1
+ try:
+ while expected[i] == e:
+ i += 1
+ finally:
+ j += 1
+ while actual[j] == a:
+ j += 1
+ except IndexError:
+ missing.extend(expected[i:])
+ unexpected.extend(actual[j:])
+ break
+ return missing, unexpected
+
+
+def _walk_structure_for_problems(a, b, aname, bname, problem_list):
+ """The recursive comparison behind assertSameStructure."""
+ if type(a) != type(b) and not ( # pylint: disable=unidiomatic-typecheck
+ isinstance(a, six.integer_types) and isinstance(b, six.integer_types)):
+ # We do not distinguish between int and long types as 99.99% of Python 2
+ # code should never care. They collapse into a single type in Python 3.
+ problem_list.append('%s is a %r but %s is a %r' %
+ (aname, type(a), bname, type(b)))
+ # If they have different types there's no point continuing
+ return
+
+ if isinstance(a, collections.Mapping):
+ for k in a:
+ if k in b:
+ _walk_structure_for_problems(
+ a[k], b[k], '%s[%r]' % (aname, k), '%s[%r]' % (bname, k),
+ problem_list)
+ else:
+ problem_list.append(
+ "%s has [%r] with value %r but it's missing in %s" %
+ (aname, k, a[k], bname))
+ for k in b:
+ if k not in a:
+ problem_list.append(
+ '%s lacks [%r] but %s has it with value %r' %
+ (aname, k, bname, b[k]))
+
+ # Strings/bytes are Sequences but we'll just do those with regular !=
+ elif (isinstance(a, collections.Sequence) and
+ not isinstance(a, _TEXT_OR_BINARY_TYPES)):
+ minlen = min(len(a), len(b))
+ for i in xrange(minlen):
+ _walk_structure_for_problems(
+ a[i], b[i], '%s[%d]' % (aname, i), '%s[%d]' % (bname, i),
+ problem_list)
+ for i in xrange(minlen, len(a)):
+ problem_list.append('%s has [%i] with value %r but %s does not' %
+ (aname, i, a[i], bname))
+ for i in xrange(minlen, len(b)):
+ problem_list.append('%s lacks [%i] but %s has it with value %r' %
+ (aname, i, bname, b[i]))
+
+ else:
+ if a != b:
+ problem_list.append('%s is %r but %s is %r' % (aname, a, bname, b))
+
+
+def get_command_string(command):
+ """Returns an escaped string that can be used as a shell command.
+
+ Args:
+ command: List or string representing the command to run.
+ Returns:
+ A string suitable for use as a shell command.
+ """
+ if isinstance(command, six.string_types):
+ return command
+ else:
+ if os.name == 'nt':
+ return ' '.join(command)
+ else:
+ # The following is identical to Python 3's shlex.quote function.
+ command_string = ''
+ for word in command:
+ # Single quote word, and replace each ' in word with '"'"'
+ command_string += "'" + word.replace("'", "'\"'\"'") + "' "
+ return command_string[:-1]
+
+
+def get_command_stderr(command, env=None, close_fds=True):
+ """Runs the given shell command and returns a tuple.
+
+ Args:
+ command: List or string representing the command to run.
+ env: Dictionary of environment variable settings. If None, no environment
+ variables will be set for the child process. This is to make tests
+ more hermetic. NOTE: this behavior is different than the standard
+ subprocess module.
+ close_fds: Whether or not to close all open fd's in the child after forking.
+ On Windows, this is ignored and close_fds is always False.
+
+ Returns:
+ Tuple of (exit status, text printed to stdout and stderr by the command).
+ """
+ if env is None: env = {}
+ if os.name == 'nt':
+ # Windows does not support setting close_fds to True while also redirecting
+ # standard handles.
+ close_fds = False
+
+ use_shell = isinstance(command, six.string_types)
+ process = subprocess.Popen(
+ command,
+ close_fds=close_fds,
+ env=env,
+ shell=use_shell,
+ stderr=subprocess.STDOUT,
+ stdout=subprocess.PIPE)
+ output = process.communicate()[0]
+ exit_status = process.wait()
+ return (exit_status, output)
+
+
+def _quote_long_string(s):
+ """Quotes a potentially multi-line string to make the start and end obvious.
+
+ Args:
+ s: A string.
+
+ Returns:
+ The quoted string.
+ """
+ if isinstance(s, (bytes, bytearray)):
+ try:
+ s = s.decode('utf-8')
+ except UnicodeDecodeError:
+ s = str(s)
+ return ('8<-----------\n' +
+ s + '\n' +
+ '----------->8\n')
+
+
+class _TestProgramManualRun(unittest.TestProgram):
+ """A TestProgram which runs the tests manually."""
+
+ def runTests(self, do_run=False):
+ """Runs the tests."""
+ if do_run:
+ unittest.TestProgram.runTests(self)
+
+
+def print_python_version():
+ # Having this in the test output logs by default helps debugging when all
+ # you've got is the log and no other idea of which Python was used.
+ sys.stderr.write('Running tests under Python {0[0]}.{0[1]}.{0[2]}: '
+ '{1}\n'.format(
+ sys.version_info,
+ sys.executable if sys.executable else 'embedded.'))
+
+
+def main(*args, **kwargs):
+ """Executes a set of Python unit tests.
+
+ Usually this function is called without arguments, so the
+ unittest.TestProgram instance will get created with the default settings,
+ so it will run all test methods of all TestCase classes in the __main__
+ module.
+
+ Args:
+ *args: Positional arguments passed through to unittest.TestProgram.__init__.
+ **kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ print_python_version()
+ _run_in_app(run_tests, args, kwargs)
+
+
+def _is_in_app_main():
+ """Returns True iff app.run is active."""
+ f = sys._getframe().f_back # pylint: disable=protected-access
+ while f:
+ if f.f_code == six.get_function_code(app.run):
+ return True
+ f = f.f_back
+ return False
+
+
+class _SavedFlag(object):
+ """Helper class for saving and restoring a flag value."""
+
+ def __init__(self, flag):
+ self.flag = flag
+ self.value = flag.value
+ self.present = flag.present
+
+ def restore_flag(self):
+ self.flag.value = self.value
+ self.flag.present = self.present
+
+
+def _register_sigterm_with_faulthandler():
+ """Have faulthandler dump stacks on SIGTERM. Useful to diagnose timeouts."""
+ if faulthandler and getattr(faulthandler, 'register', None):
+ # faulthandler.register is not avaiable on Windows.
+ # faulthandler.enable() is already called by app.run.
+ try:
+ faulthandler.register(signal.SIGTERM, chain=True)
+ except Exception as e: # pylint: disable=broad-except
+ sys.stderr.write('faulthandler.register(SIGTERM) failed '
+ '%r; ignoring.\n' % e)
+
+
+def _run_in_app(function, args, kwargs):
+ """Executes a set of Python unit tests, ensuring app.run.
+
+ This is a private function, users should call absltest.main().
+
+ _run_in_app calculates argv to be the command-line arguments of this program
+ (without the flags), sets the default of FLAGS.alsologtostderr to True,
+ then it calls function(argv, args, kwargs), making sure that `function'
+ will get called within app.run(). _run_in_app does this by checking whether
+ it is called by app.run(), or by calling app.run() explicitly.
+
+ The reason why app.run has to be ensured is to make sure that
+ flags are parsed and stripped properly, and other initializations done by
+ the app module are also carried out, no matter if absltest.run() is called
+ from within or outside app.run().
+
+ If _run_in_app is called from within app.run(), then it will reparse
+ sys.argv and pass the result without command-line flags into the argv
+ argument of `function'. The reason why this parsing is needed is that
+ __main__.main() calls absltest.main() without passing its argv. So the
+ only way _run_in_app could get to know the argv without the flags is that
+ it reparses sys.argv.
+
+ _run_in_app changes the default of FLAGS.alsologtostderr to True so that the
+ test program's stderr will contain all the log messages unless otherwise
+ specified on the command-line. This overrides any explicit assignment to
+ FLAGS.alsologtostderr by the test program prior to the call to _run_in_app()
+ (e.g. in __main__.main).
+
+ Please note that _run_in_app (and the function it calls) is allowed to make
+ changes to kwargs.
+
+ Args:
+ function: absltest.run_tests or a similar function. It will be called as
+ function(argv, args, kwargs) where argv is a list containing the
+ elements of sys.argv without the command-line flags.
+ args: Positional arguments passed through to unittest.TestProgram.__init__.
+ kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ if _is_in_app_main():
+ _register_sigterm_with_faulthandler()
+
+ # Save command-line flags so the side effects of FLAGS(sys.argv) can be
+ # undone.
+ flag_objects = (FLAGS[name] for name in FLAGS)
+ saved_flags = dict((f.name, _SavedFlag(f)) for f in flag_objects)
+
+ # Change the default of alsologtostderr from False to True, so the test
+ # programs's stderr will contain all the log messages.
+ # If --alsologtostderr=false is specified in the command-line, or user
+ # has called FLAGS.alsologtostderr = False before, then the value is kept
+ # False.
+ FLAGS.set_default('alsologtostderr', True)
+ # Remove it from saved flags so it doesn't get restored later.
+ del saved_flags['alsologtostderr']
+
+ # The call FLAGS(sys.argv) parses sys.argv, returns the arguments
+ # without the flags, and -- as a side effect -- modifies flag values in
+ # FLAGS. We don't want the side effect, because we don't want to
+ # override flag changes the program did (e.g. in __main__.main)
+ # after the command-line has been parsed. So we have the for loop below
+ # to change back flags to their old values.
+ argv = FLAGS(sys.argv)
+ for saved_flag in six.itervalues(saved_flags):
+ saved_flag.restore_flag()
+
+
+ function(argv, args, kwargs)
+ else:
+ # Send logging to stderr. Use --alsologtostderr instead of --logtostderr
+ # in case tests are reading their own logs.
+ FLAGS.set_default('alsologtostderr', True)
+
+ def main_function(argv):
+ _register_sigterm_with_faulthandler()
+ function(argv, args, kwargs)
+
+ app.run(main=main_function)
+
+
+def _is_suspicious_attribute(testCaseClass, name):
+ """Returns True if an attribute is a method named like a test method."""
+ if name.startswith('Test') and len(name) > 4 and name[4].isupper():
+ attr = getattr(testCaseClass, name)
+ if inspect.isfunction(attr) or inspect.ismethod(attr):
+ args = inspect.getargspec(attr)
+ return (len(args.args) == 1 and args.args[0] == 'self'
+ and args.varargs is None and args.keywords is None)
+ return False
+
+
+class TestLoader(unittest.TestLoader):
+ """A test loader which supports common test features.
+
+ Supported features include:
+ * Banning untested methods with test-like names: methods attached to this
+ testCase with names starting with `Test` are ignored by the test runner,
+ and often represent mistakenly-omitted test cases. This loader will raise
+ a TypeError when attempting to load a TestCase with such methods.
+ * Randomization of test case execution order (optional).
+ """
+
+ _ERROR_MSG = textwrap.dedent("""Method '%s' is named like a test case but
+ is not one. This is often a bug. If you want it to be a test method,
+ name it with 'test' in lowercase. If not, rename the method to not begin
+ with 'Test'.""")
+
+ def __init__(self, *args, **kwds):
+ super(TestLoader, self).__init__(*args, **kwds)
+ seed = _get_default_randomize_ordering_seed()
+ if seed:
+ self._seed = seed
+ self._random = random.Random(self._seed)
+ else:
+ self._seed = None
+ self._random = None
+
+ def getTestCaseNames(self, testCaseClass): # pylint:disable=invalid-name
+ """Validates and returns a (possibly randomized) list of test case names."""
+ for name in dir(testCaseClass):
+ if _is_suspicious_attribute(testCaseClass, name):
+ raise TypeError(TestLoader._ERROR_MSG % name)
+ names = super(TestLoader, self).getTestCaseNames(testCaseClass)
+ if self._seed is not None:
+ logging.info('Randomizing test order with seed: %d', self._seed)
+ logging.info('To reproduce this order, re-run with '
+ '--test_randomize_ordering_seed=%d', self._seed)
+ self._random.shuffle(names)
+ return names
+
+
+def get_default_xml_output_filename():
+ if os.environ.get('XML_OUTPUT_FILE'):
+ return os.environ['XML_OUTPUT_FILE']
+ elif os.environ.get('RUNNING_UNDER_TEST_DAEMON'):
+ return os.path.join(os.path.dirname(FLAGS.test_tmpdir), 'test_detail.xml')
+ elif os.environ.get('TEST_XMLOUTPUTDIR'):
+ return os.path.join(
+ os.environ['TEST_XMLOUTPUTDIR'],
+ os.path.splitext(os.path.basename(sys.argv[0]))[0] + '.xml')
+
+
+def _setup_filtering(argv):
+ """Implements the bazel test filtering protocol.
+
+ The following environment variable is used in this method:
+
+ TESTBRIDGE_TEST_ONLY: string, if set, is forwarded to the unittest
+ framework to use as a test filter. Its value is split with shlex
+ before being passed as positional arguments on argv.
+
+ Args:
+ argv: the argv to mutate in-place.
+ """
+ test_filter = os.environ.get('TESTBRIDGE_TEST_ONLY')
+ if argv is None or not test_filter:
+ return
+
+ argv[1:1] = shlex.split(test_filter)
+
+
+def _setup_sharding(custom_loader=None):
+ """Implements the bazel sharding protocol.
+
+ The following environment variables are used in this method:
+
+ TEST_SHARD_STATUS_FILE: string, if set, points to a file. We write a blank
+ file to tell the test runner that this test implements the test sharding
+ protocol.
+
+ TEST_TOTAL_SHARDS: int, if set, sharding is requested.
+
+ TEST_SHARD_INDEX: int, must be set if TEST_TOTAL_SHARDS is set. Specifies
+ the shard index for this instance of the test process. Must satisfy:
+ 0 <= TEST_SHARD_INDEX < TEST_TOTAL_SHARDS.
+
+ Args:
+ custom_loader: A TestLoader to be made sharded.
+
+ Returns:
+ The test loader for shard-filtering or the standard test loader, depending
+ on the sharding environment variables.
+ """
+
+ # It may be useful to write the shard file even if the other sharding
+ # environment variables are not set. Test runners may use this functionality
+ # to query whether a test binary implements the test sharding protocol.
+ if 'TEST_SHARD_STATUS_FILE' in os.environ:
+ try:
+ f = None
+ try:
+ f = open(os.environ['TEST_SHARD_STATUS_FILE'], 'w')
+ f.write('')
+ except IOError:
+ sys.stderr.write('Error opening TEST_SHARD_STATUS_FILE (%s). Exiting.'
+ % os.environ['TEST_SHARD_STATUS_FILE'])
+ sys.exit(1)
+ finally:
+ if f is not None: f.close()
+
+ base_loader = custom_loader or TestLoader()
+ if 'TEST_TOTAL_SHARDS' not in os.environ:
+ # Not using sharding, use the expected test loader.
+ return base_loader
+
+ total_shards = int(os.environ['TEST_TOTAL_SHARDS'])
+ shard_index = int(os.environ['TEST_SHARD_INDEX'])
+
+ if shard_index < 0 or shard_index >= total_shards:
+ sys.stderr.write('ERROR: Bad sharding values. index=%d, total=%d\n' %
+ (shard_index, total_shards))
+ sys.exit(1)
+
+ # Replace the original getTestCaseNames with one that returns
+ # the test case names for this shard.
+ delegate_get_names = base_loader.getTestCaseNames
+
+ bucket_iterator = itertools.cycle(xrange(total_shards))
+
+ def getShardedTestCaseNames(testCaseClass):
+ filtered_names = []
+ for testcase in sorted(delegate_get_names(testCaseClass)):
+ bucket = next(bucket_iterator)
+ if bucket == shard_index:
+ filtered_names.append(testcase)
+ return filtered_names
+
+ base_loader.getTestCaseNames = getShardedTestCaseNames
+ return base_loader
+
+
+def _run_and_get_tests_result(argv, args, kwargs, xml_test_runner_class):
+ """Executes a set of Python unit tests and returns the result."""
+
+ # Set up test filtering if requested in environment.
+ _setup_filtering(argv)
+
+ # Shard the (default or custom) loader if sharding is turned on.
+ kwargs['testLoader'] = _setup_sharding(kwargs.get('testLoader', None))
+
+ # XML file name is based upon (sorted by priority):
+ # --xml_output_file flag, XML_OUTPUT_FILE variable,
+ # TEST_XMLOUTPUTDIR variable or RUNNING_UNDER_TEST_DAEMON variable.
+ if not FLAGS.xml_output_file:
+ FLAGS.xml_output_file = get_default_xml_output_filename()
+ xml_output_file = FLAGS.xml_output_file
+
+ xml_output = None
+ if xml_output_file:
+ xml_output_dir = os.path.dirname(xml_output_file)
+ if xml_output_dir and not os.path.isdir(xml_output_dir):
+ try:
+ os.makedirs(xml_output_dir)
+ except OSError as e:
+ # File exists error can occur with concurrent tests
+ if e.errno != errno.EEXIST:
+ raise
+ if sys.version_info.major == 2:
+ xml_output = open(xml_output_file, 'w')
+ else:
+ xml_output = open(xml_output_file, 'w', encoding='utf-8')
+ # We can reuse testRunner if it supports XML output (e. g. by inheriting
+ # from xml_reporter.TextAndXMLTestRunner). Otherwise we need to use
+ # xml_reporter.TextAndXMLTestRunner.
+ if (kwargs.get('testRunner') is not None
+ and not hasattr(kwargs['testRunner'], 'set_default_xml_stream')):
+ sys.stderr.write('WARNING: XML_OUTPUT_FILE or --xml_output_file setting '
+ 'overrides testRunner=%r setting (possibly from --pdb)'
+ % (kwargs['testRunner']))
+ # Passing a class object here allows TestProgram to initialize
+ # instances based on its kwargs and/or parsed command-line args.
+ kwargs['testRunner'] = xml_test_runner_class
+ if kwargs.get('testRunner') is None:
+ kwargs['testRunner'] = xml_test_runner_class
+ kwargs['testRunner'].set_default_xml_stream(xml_output)
+
+ # Make sure tmpdir exists.
+ if not os.path.isdir(FLAGS.test_tmpdir):
+ try:
+ os.makedirs(FLAGS.test_tmpdir)
+ except OSError as e:
+ # Concurrent test might have created the directory.
+ if e.errno != errno.EEXIST:
+ raise
+
+ # Let unittest.TestProgram.__init__ do its own argv parsing, e.g. for '-v',
+ # on argv, which is sys.argv without the command-line flags.
+ kwargs.setdefault('argv', argv)
+
+ try:
+ test_program = unittest.TestProgram(*args, **kwargs)
+ return test_program.result
+ finally:
+ if xml_output:
+ xml_output.close()
+
+
+def run_tests(argv, args, kwargs):
+ """Executes a set of Python unit tests.
+
+ Most users should call absltest.main() instead of run_tests.
+
+ Please note that run_tests should be called from app.run.
+ Calling absltest.main() would ensure that.
+
+ Please note that run_tests is allowed to make changes to kwargs.
+
+ Args:
+ argv: sys.argv with the command-line flags removed from the front, i.e. the
+ argv with which app.run() has called __main__.main.
+ args: Positional arguments passed through to unittest.TestProgram.__init__.
+ kwargs: Keyword arguments passed through to unittest.TestProgram.__init__.
+ """
+ result = _run_and_get_tests_result(
+ argv, args, kwargs, xml_reporter.TextAndXMLTestRunner)
+ sys.exit(not result.wasSuccessful())