aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/nest_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/util/nest_test.py')
-rw-r--r--tensorflow/python/util/nest_test.py68
1 files changed, 51 insertions, 17 deletions
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2f12b25354..26c6ea4b01 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import time
+from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -33,7 +34,22 @@ from tensorflow.python.platform import test
from tensorflow.python.util import nest
-class NestTest(test.TestCase):
+class _CustomMapping(collections.Mapping):
+
+ def __init__(self, *args, **kwargs):
+ self._wrapped = dict(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return self._wrapped[key]
+
+ def __iter__(self):
+ return iter(self._wrapped)
+
+ def __len__(self):
+ return len(self._wrapped)
+
+
+class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
@@ -72,26 +88,32 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
- def testFlattenDictOrder(self):
+ def testFlattenDictOrder(self, mapping_type):
"""`flatten` orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
+ ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
- def testPackDictOrder(self):
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
+ def testPackDictOrder(self, mapping_type):
"""Packing orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
+ custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
- ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
+ custom_reconstruction = nest.pack_sequence_as(custom, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
+ self.assertIsInstance(custom_reconstruction, mapping_type)
+ self.assertIsInstance(plain_reconstruction, dict)
self.assertEqual(
- collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
- ordered_reconstruction)
+ mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
+ custom_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
@@ -101,8 +123,10 @@ class NestTest(test.TestCase):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
mess = [
"z",
- NestTest.Abc(3, 4),
- {
+ NestTest.Abc(3, 4), {
+ "d": _CustomMapping({
+ 41: 4
+ }),
"c": [
1,
collections.OrderedDict([
@@ -111,17 +135,19 @@ class NestTest(test.TestCase):
]),
],
"b": 5
- },
- 17
+ }, 17
]
flattened = nest.flatten(mess)
- self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
+ self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
structure_of_mess = [
14,
NestTest.Abc("a", True),
{
+ "d": _CustomMapping({
+ 41: 42
+ }),
"c": [
0,
collections.OrderedDict([
@@ -142,6 +168,10 @@ class NestTest(test.TestCase):
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
+ unflattened_custom_mapping = unflattened[2]["d"]
+ self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
+ self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
+
def testFlatten_numpyIsNotFlattened(self):
structure = np.array([1, 2, 3])
flattened = nest.flatten(structure)
@@ -179,19 +209,23 @@ class NestTest(test.TestCase):
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))
- def testFlattenDictItems(self):
- dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
+ @parameterized.parameters({"mapping_type": _CustomMapping},
+ {"mapping_type": dict})
+ def testFlattenDictItems(self, mapping_type):
+ dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
flat = {4: "a", 5: "b", 6: "c", 8: "d"}
self.assertEqual(nest.flatten_dict_items(dictionary), flat)
with self.assertRaises(TypeError):
nest.flatten_dict_items(4)
- bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))}
+ bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
with self.assertRaisesRegexp(ValueError, "not unique"):
nest.flatten_dict_items(bad_dictionary)
- another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))}
+ another_bad_dictionary = mapping_type({
+ (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
+ })
with self.assertRaisesRegexp(
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)