aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/framework/test_util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/framework/test_util.py')
-rw-r--r--tensorflow/python/framework/test_util.py78
1 files changed, 76 insertions, 2 deletions
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 2bc2a189fa..fc47b1cca5 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -19,6 +19,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
+from collections import OrderedDict
import contextlib
import gc
import itertools
@@ -571,6 +573,78 @@ def assert_no_garbage_created(f):
return decorator
+def _combine_named_parameters(**kwargs):
+ """Generate combinations based on its keyword arguments.
+
+ Two sets of returned combinations can be concatenated using +. Their product
+ can be computed using `times()`.
+
+ Args:
+ **kwargs: keyword arguments of form `option=[possibilities, ...]`
+ or `option=the_only_possibility`.
+
+ Returns:
+ a list of dictionaries for each combination. Keys in the dictionaries are
+ the keyword argument names. Each key has one value - one of the
+ corresponding keyword argument values.
+ """
+ if not kwargs:
+ return [OrderedDict()]
+
+ sort_by_key = lambda k: k[0][0]
+ kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
+ first = list(kwargs.items())[0]
+
+ rest = dict(list(kwargs.items())[1:])
+ rest_combined = _combine_named_parameters(**rest)
+
+ key = first[0]
+ values = first[1]
+ if not isinstance(values, list):
+ values = [values]
+
+ combinations = [
+ OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
+ for v in values
+ for combined in rest_combined
+ ]
+ return combinations
+
+
+def generate_combinations_with_testcase_name(**kwargs):
+ """Generate combinations based on its keyword arguments using combine().
+
+ This function calls combine() and appends a testcase name to the list of
+ dictionaries returned. The 'testcase_name' key is a required for named
+ parameterized tests.
+
+ Args:
+ **kwargs: keyword arguments of form `option=[possibilities, ...]`
+ or `option=the_only_possibility`.
+
+ Returns:
+ a list of dictionaries for each combination. Keys in the dictionaries are
+ the keyword argument names. Each key has one value - one of the
+ corresponding keyword argument values.
+ """
+ combinations = _combine_named_parameters(**kwargs)
+ named_combinations = []
+ for combination in combinations:
+ assert isinstance(combination, OrderedDict)
+ name = "".join([
+ "_{}_{}".format(
+ "".join(filter(str.isalnum, key)),
+ "".join(filter(str.isalnum, str(value))))
+ for key, value in combination.items()
+ ])
+ named_combinations.append(
+ OrderedDict(
+ list(combination.items()) + [("testcase_name",
+ "_test{}".format(name))]))
+
+ return named_combinations
+
+
def run_all_in_graph_and_eager_modes(cls):
"""Execute all test methods in the given class with and without eager."""
base_decorator = run_in_graph_and_eager_modes
@@ -1227,8 +1301,8 @@ class TensorFlowTestCase(googletest.TestCase):
a = a._asdict()
if hasattr(b, "_asdict"):
b = b._asdict()
- a_is_dict = isinstance(a, dict)
- if a_is_dict != isinstance(b, dict):
+ a_is_dict = isinstance(a, collections.Mapping)
+ if a_is_dict != isinstance(b, collections.Mapping):
raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
(path_str, path_str, msg))
if a_is_dict: