diff options
Diffstat (limited to 'tensorflow/python/util/nest_test.py')
-rw-r--r-- | tensorflow/python/util/nest_test.py | 68 |
1 files changed, 51 insertions, 17 deletions
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index 2f12b25354..26c6ea4b01 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import time +from absl.testing import parameterized import numpy as np from six.moves import xrange # pylint: disable=redefined-builtin @@ -33,7 +34,22 @@ from tensorflow.python.platform import test from tensorflow.python.util import nest -class NestTest(test.TestCase): +class _CustomMapping(collections.Mapping): + + def __init__(self, *args, **kwargs): + self._wrapped = dict(*args, **kwargs) + + def __getitem__(self, key): + return self._wrapped[key] + + def __iter__(self): + return iter(self._wrapped) + + def __len__(self): + return len(self._wrapped) + + +class NestTest(parameterized.TestCase, test.TestCase): PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name @@ -72,26 +88,32 @@ class NestTest(test.TestCase): with self.assertRaises(ValueError): nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) @test_util.assert_no_new_pyobjects_executing_eagerly - def testFlattenDictOrder(self): + def testFlattenDictOrder(self, mapping_type): """`flatten` orders dicts by key, including OrderedDicts.""" - ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) + ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) plain = {"d": 3, "b": 1, "a": 0, "c": 2} ordered_flat = nest.flatten(ordered) plain_flat = nest.flatten(plain) self.assertEqual([0, 1, 2, 3], ordered_flat) self.assertEqual([0, 1, 2, 3], plain_flat) - def testPackDictOrder(self): + @parameterized.parameters({"mapping_type": collections.OrderedDict}, + {"mapping_type": _CustomMapping}) + def testPackDictOrder(self, mapping_type): """Packing orders dicts by key, including OrderedDicts.""" - ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) + custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) plain = {"d": 0, "b": 0, "a": 0, "c": 0} seq = [0, 1, 2, 3] - ordered_reconstruction = nest.pack_sequence_as(ordered, seq) + custom_reconstruction = nest.pack_sequence_as(custom, seq) plain_reconstruction = nest.pack_sequence_as(plain, seq) + self.assertIsInstance(custom_reconstruction, mapping_type) + self.assertIsInstance(plain_reconstruction, dict) self.assertEqual( - collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), - ordered_reconstruction) + mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), + custom_reconstruction) self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name @@ -101,8 +123,10 @@ class NestTest(test.TestCase): # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. mess = [ "z", - NestTest.Abc(3, 4), - { + NestTest.Abc(3, 4), { + "d": _CustomMapping({ + 41: 4 + }), "c": [ 1, collections.OrderedDict([ @@ -111,17 +135,19 @@ class NestTest(test.TestCase): ]), ], "b": 5 - }, - 17 + }, 17 ] flattened = nest.flatten(mess) - self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17]) + self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) structure_of_mess = [ 14, NestTest.Abc("a", True), { + "d": _CustomMapping({ + 41: 42 + }), "c": [ 0, collections.OrderedDict([ @@ -142,6 +168,10 @@ class NestTest(test.TestCase): self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) + unflattened_custom_mapping = unflattened[2]["d"] + self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) + self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) + def testFlatten_numpyIsNotFlattened(self): structure = np.array([1, 2, 3]) flattened = nest.flatten(structure) @@ -179,19 +209,23 @@ class NestTest(test.TestCase): self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) self.assertFalse(nest.is_sequence(np.ones((4, 5)))) - def testFlattenDictItems(self): - dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} + @parameterized.parameters({"mapping_type": _CustomMapping}, + {"mapping_type": dict}) + def testFlattenDictItems(self, mapping_type): + dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) flat = {4: "a", 5: "b", 6: "c", 8: "d"} self.assertEqual(nest.flatten_dict_items(dictionary), flat) with self.assertRaises(TypeError): nest.flatten_dict_items(4) - bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))} + bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) with self.assertRaisesRegexp(ValueError, "not unique"): nest.flatten_dict_items(bad_dictionary) - another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))} + another_bad_dictionary = mapping_type({ + (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) + }) with self.assertRaisesRegexp( ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): nest.flatten_dict_items(another_bad_dictionary) |