aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-07-24 11:09:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-24 11:13:21 -0700
commit4883a912780ab5783db654bebc6fc3ac25b63d74 (patch)
treefbae578e9f7e952971b027cf05b2ec1a49230785 /tensorflow/python/util
parenta45ffbd9b5c7d8fdaae6e41432f916639bdbe305 (diff)
Add support for custom mapping types to util.nest
Analagous to the existing support for custom collections.Sequence types. They need to be constructable with the same arguments as the base type for pack_sequence_as to work. Leaves PyDict_* calls for dict subclasses, but adds more general (and likely much slower) fallbacks for instances of collections.Mapping which are not dict subclasses. My hope is that this support will be enough so I can use a wrapper around dicts which does not inherit from dict in __setattr__ tracking (some tests failed without it). Inheriting from dict and properly shadowing a real dict seems impossible with CPython (since to shadow without synchronization issues, the wrapper needs to respond to updates to the original dict, but to work with e.g. {}.update(dict_subclass) the wrapper's C storage needs to also be updated). PiperOrigin-RevId: 205858082
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/nest_test.py68
-rw-r--r--tensorflow/python/util/util.cc299
-rw-r--r--tensorflow/python/util/util.h4
-rw-r--r--tensorflow/python/util/util.i3
5 files changed, 283 insertions, 102 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index d63f59a8c8..5aac559b9b 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -73,7 +73,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, dict):
+ if isinstance(instance, (dict, _collections.Mapping)):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +89,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, dict):
+ if isinstance(iterable, (dict, _collections.Mapping)):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -215,7 +215,7 @@ def flatten_dict_items(dictionary):
ValueError: If any key and value have not the same structure, or if keys are
not unique.
"""
- if not isinstance(dictionary, dict):
+ if not isinstance(dictionary, (dict, _collections.Mapping)):
raise TypeError("input must be a dictionary")
flat_dictionary = {}
for i, v in _six.iteritems(dictionary):
@@ -455,7 +455,7 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
"structure has length %s, while shallow structure has length %s."
% (len(input_tree), len(shallow_tree)))
- if check_types and isinstance(shallow_tree, dict):
+ if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
if set(input_tree) != set(shallow_tree):
raise ValueError(
"The two structures don't have the same keys. Input "
@@ -716,7 +716,7 @@ def yield_flat_paths(nest):
# The _maybe_add_final_path_element function is used below in order to avoid
# adding trailing slashes when the sub-element recursed into is a leaf.
- if isinstance(nest, dict):
+ if isinstance(nest, (dict, _collections.Mapping)):
for key in _sorted(nest):
value = nest[key]
for sub_path in yield_flat_paths(value):
@@ -760,3 +760,4 @@ def flatten_with_joined_string_paths(structure, separator="/"):
_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
+_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2f12b25354..26c6ea4b01 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import time
+from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -33,7 +34,22 @@ from tensorflow.python.platform import test
from tensorflow.python.util import nest
-class NestTest(test.TestCase):
+class _CustomMapping(collections.Mapping):
+
+ def __init__(self, *args, **kwargs):
+ self._wrapped = dict(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return self._wrapped[key]
+
+ def __iter__(self):
+ return iter(self._wrapped)
+
+ def __len__(self):
+ return len(self._wrapped)
+
+
+class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
@@ -72,26 +88,32 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
- def testFlattenDictOrder(self):
+ def testFlattenDictOrder(self, mapping_type):
"""`flatten` orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
+ ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
- def testPackDictOrder(self):
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
+ def testPackDictOrder(self, mapping_type):
"""Packing orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
+ custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
- ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
+ custom_reconstruction = nest.pack_sequence_as(custom, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
+ self.assertIsInstance(custom_reconstruction, mapping_type)
+ self.assertIsInstance(plain_reconstruction, dict)
self.assertEqual(
- collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
- ordered_reconstruction)
+ mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
+ custom_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
@@ -101,8 +123,10 @@ class NestTest(test.TestCase):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
mess = [
"z",
- NestTest.Abc(3, 4),
- {
+ NestTest.Abc(3, 4), {
+ "d": _CustomMapping({
+ 41: 4
+ }),
"c": [
1,
collections.OrderedDict([
@@ -111,17 +135,19 @@ class NestTest(test.TestCase):
]),
],
"b": 5
- },
- 17
+ }, 17
]
flattened = nest.flatten(mess)
- self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
+ self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
structure_of_mess = [
14,
NestTest.Abc("a", True),
{
+ "d": _CustomMapping({
+ 41: 42
+ }),
"c": [
0,
collections.OrderedDict([
@@ -142,6 +168,10 @@ class NestTest(test.TestCase):
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
+ unflattened_custom_mapping = unflattened[2]["d"]
+ self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
+ self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
+
def testFlatten_numpyIsNotFlattened(self):
structure = np.array([1, 2, 3])
flattened = nest.flatten(structure)
@@ -179,19 +209,23 @@ class NestTest(test.TestCase):
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))
- def testFlattenDictItems(self):
- dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
+ @parameterized.parameters({"mapping_type": _CustomMapping},
+ {"mapping_type": dict})
+ def testFlattenDictItems(self, mapping_type):
+ dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
flat = {4: "a", 5: "b", 6: "c", 8: "d"}
self.assertEqual(nest.flatten_dict_items(dictionary), flat)
with self.assertRaises(TypeError):
nest.flatten_dict_items(4)
- bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))}
+ bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
with self.assertRaisesRegexp(ValueError, "not unique"):
nest.flatten_dict_items(bad_dictionary)
- another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))}
+ another_bad_dictionary = mapping_type({
+ (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
+ })
with self.assertRaisesRegexp(
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 366f8a0deb..f9e0b7e4d2 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -31,6 +31,8 @@ namespace {
// Type object for collections.Sequence. This is set by RegisterSequenceClass.
PyObject* CollectionsSequenceType = nullptr;
+// Type object for collections.Mapping, set by RegisterMappingClass.
+PyObject* CollectionsMappingType = nullptr;
PyTypeObject* SparseTensorValueType = nullptr;
const int kMaxItemsInCache = 1024;
@@ -45,6 +47,23 @@ bool IsString(PyObject* o) {
PyUnicode_Check(o);
}
+// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
+// and while we're at it give them consistent behavior by making sure the
+// returned value is a list.
+//
+// As with PyMapping_Keys, returns a new reference.
+PyObject* MappingKeys(PyObject* o) {
+#if PY_MAJOR_VERSION >= 3
+ return PyMapping_Keys(o);
+#else
+ static char key_method_name[] = "keys";
+ Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
+ return PySequence_Fast(
+ raw_result.get(),
+ "The '.keys()' method of a custom mapping returned a non-sequence.");
+#endif
+}
+
// Equivalent to Python's 'o.__class__.__name__'
// Note that '__class__' attribute is set only in new-style classes.
// A lot of tensorflow code uses __class__ without checks, so it seems like
@@ -85,6 +104,119 @@ string PyObjectToString(PyObject* o) {
}
}
+class CachedTypeCheck {
+ public:
+ explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
+ : ternary_predicate_(std::move(ternary_predicate)) {}
+
+ ~CachedTypeCheck() {
+ mutex_lock l(type_to_sequence_map_mu_);
+ for (const auto& pair : type_to_sequence_map_) {
+ Py_DECREF(pair.first);
+ }
+ }
+
+ // Caches successful executions of the one-argument (PyObject*) callable
+ // "ternary_predicate" based on the type of "o". -1 from the callable
+ // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
+ // does not match the predicate, and 1 indicates that it does. Used to avoid
+ // calling back into Python for expensive isinstance checks.
+ int CachedLookup(PyObject* o) {
+ // Try not to return to Python - see if the type has already been seen
+ // before.
+
+ auto* type = Py_TYPE(o);
+
+ {
+ mutex_lock l(type_to_sequence_map_mu_);
+ auto it = type_to_sequence_map_.find(type);
+ if (it != type_to_sequence_map_.end()) {
+ return it->second;
+ }
+ }
+
+ int check_result = ternary_predicate_(o);
+
+ if (check_result == -1) {
+ return -1; // Type check error, not cached.
+ }
+
+ // NOTE: This is never decref'd as long as the object lives, which is likely
+ // forever, but we don't want the type to get deleted as long as it is in
+ // the map. This should not be too much of a leak, as there should only be a
+ // relatively small number of types in the map, and an even smaller number
+ // that are eligible for decref. As a precaution, we limit the size of the
+ // map to 1024.
+ {
+ mutex_lock l(type_to_sequence_map_mu_);
+ if (type_to_sequence_map_.size() < kMaxItemsInCache) {
+ Py_INCREF(type);
+ type_to_sequence_map_.insert({type, check_result});
+ }
+ }
+
+ return check_result;
+ }
+
+ private:
+ std::function<int(PyObject*)> ternary_predicate_;
+ mutex type_to_sequence_map_mu_;
+ std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
+ GUARDED_BY(type_to_sequence_map_mu_);
+};
+
+// Returns 1 if `o` is considered a mapping for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsMappingHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ return PyObject_IsInstance(to_check, CollectionsMappingType);
+ });
+ if (PyDict_Check(o)) return true;
+ if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please call RegisterMappingClass before using this module")
+ .c_str());
+ return -1;
+ }
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
+ // We treat dicts and other mappings as special cases of sequences.
+ if (IsMappingHelper(o)) return true;
+ if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
+ LOG(WARNING) << "Sets are not currently considered sequences, "
+ "but this may change in the future, "
+ "so consider avoiding using them.";
+ WarnedThatSetIsNotSequence = true;
+ }
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please call RegisterSequenceClass before using this module")
+ .c_str());
+ return -1;
+ }
+ return check_cache->CachedLookup(o);
+}
+
// Implements the same idea as tensorflow.util.nest._yield_value
// During construction we check if the iterable is a dictionary.
// If so, we construct a sequence from its sorted keys that will be used
@@ -96,7 +228,12 @@ string PyObjectToString(PyObject* o) {
// 'iterable' must not be modified while ValIterator is used.
class ValIterator {
public:
- explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+ explicit ValIterator(PyObject* iterable)
+ : dict_(nullptr),
+ mapping_(nullptr),
+ last_mapping_element_(nullptr),
+ seq_(nullptr),
+ index_(0) {
if (PyDict_Check(iterable)) {
dict_ = iterable;
// PyDict_Keys returns a list, which can be used with
@@ -108,6 +245,10 @@ class ValIterator {
// bugs caused by mixing ordered and plain dicts (e.g., flattening
// a dict but using a corresponding `OrderedDict` to pack it back).
PyList_Sort(seq_);
+ } else if (IsMappingHelper(iterable)) {
+ mapping_ = iterable;
+ seq_ = MappingKeys(iterable);
+ PyList_Sort(seq_);
} else {
seq_ = PySequence_Fast(iterable, "");
}
@@ -122,7 +263,9 @@ class ValIterator {
PyObject* element = nullptr;
if (index_ < size_) {
// Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references.
+ // references. For general mappings, ValIterator keeps a reference to the
+ // last retrieved element (and decrefs it before producing the next
+ // element) to abstract away the borrowed/new difference.
element = PySequence_Fast_GET_ITEM(seq_, index_);
++index_;
if (dict_ != nullptr) {
@@ -132,85 +275,32 @@ class ValIterator {
"Dictionary was modified during iteration over it");
return nullptr;
}
+ } else if (mapping_ != nullptr) {
+ element = PyObject_GetItem(mapping_, element);
+ if (element == nullptr) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ return nullptr;
+ }
+ last_mapping_element_.reset(element);
}
}
return element;
}
private:
- PyObject* seq_;
+ // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+
+ // General mappings which have custom Python logic
+ PyObject* mapping_;
+ Safe_PyObjectPtr last_mapping_element_;
+
+ PyObject* seq_;
Py_ssize_t size_;
Py_ssize_t index_;
};
-mutex g_type_to_sequence_map(LINKER_INITIALIZED);
-std::unordered_map<PyTypeObject*, bool>* IsTypeSequenceMap() {
- static auto* const m = new std::unordered_map<PyTypeObject*, bool>;
- return m;
-}
-
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
-// Returns 0 otherwise.
-// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
- if (PyDict_Check(o)) return true;
- if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
- LOG(WARNING) << "Sets are not currently considered sequences, "
- "but this may change in the future, "
- "so consider avoiding using them.";
- WarnedThatSetIsNotSequence = true;
- }
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
-
- // Try not to return to Python - see if the type has already been seen
- // before.
-
- auto* type_to_sequence_map = IsTypeSequenceMap();
- auto* type = Py_TYPE(o);
-
- {
- mutex_lock l(g_type_to_sequence_map);
- auto it = type_to_sequence_map->find(type);
- if (it != type_to_sequence_map->end()) {
- return it->second;
- }
- }
-
- // NOTE: We explicitly release the g_type_to_sequence_map mutex,
- // because PyObject_IsInstance() may release the GIL, allowing another thread
- // concurrent entry to this function.
- int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
-
- bool is_sequence = static_cast<int>(is_instance != 0 && !IsString(o));
-
- // NOTE: This is never decref'd, but we don't want the type to get deleted
- // as long as it is in the map. This should not be too much of a
- // leak, as there should only be a relatively small number of types in the
- // map, and an even smaller number that are eligible for decref. As a
- // precaution, we limit the size of the map to 1024.
- {
- mutex_lock l(g_type_to_sequence_map);
- if (type_to_sequence_map->size() < kMaxItemsInCache) {
- Py_INCREF(type);
- type_to_sequence_map->insert({type, is_sequence});
- }
- }
-
- return is_sequence;
-}
-
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -226,21 +316,35 @@ int IsSequenceForDataHelper(PyObject* o) {
bool GetNextValuesForDict(PyObject* nested,
std::vector<Safe_PyObjectPtr>* next_values) {
- std::vector<PyObject*> result;
-
- PyObject* keys = PyDict_Keys(nested);
- if (PyList_Sort(keys) == -1) return false;
- Py_ssize_t size = PyList_Size(keys);
+ Safe_PyObjectPtr keys(PyDict_Keys(nested));
+ if (PyList_Sort(keys.get()) == -1) return false;
+ Py_ssize_t size = PyList_Size(keys.get());
for (Py_ssize_t i = 0; i < size; ++i) {
// We know that key and item will not be deleted because nested owns
// a reference to them and callers of flatten must not modify nested
// while the method is running.
- PyObject* key = PyList_GET_ITEM(keys, i);
+ PyObject* key = PyList_GET_ITEM(keys.get(), i);
PyObject* item = PyDict_GetItem(nested, key);
Py_INCREF(item);
next_values->emplace_back(item);
}
- Py_DECREF(keys);
+ return true;
+}
+
+bool GetNextValuesForMapping(PyObject* nested,
+ std::vector<Safe_PyObjectPtr>* next_values) {
+ Safe_PyObjectPtr keys(MappingKeys(nested));
+ if (keys.get() == nullptr) {
+ return false;
+ }
+ if (PyList_Sort(keys.get()) == -1) return false;
+ Py_ssize_t size = PyList_Size(keys.get());
+ for (Py_ssize_t i = 0; i < size; ++i) {
+ PyObject* key = PyList_GET_ITEM(keys.get(), i);
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* item = PyObject_GetItem(nested, key);
+ next_values->emplace_back(item);
+ }
return true;
}
@@ -265,6 +369,9 @@ bool GetNextValues(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
}
// iterate and recurse
return GetNextValuesForIterable(nested, next_values);
@@ -276,6 +383,9 @@ bool GetNextValuesForData(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
} else if (IsSparseTensorValueType(nested)) {
// if nested is a SparseTensorValue, just return itself as a single item
Py_INCREF(nested);
@@ -320,8 +430,8 @@ bool FlattenHelper(
// 'dict1' and 'dict2' are assumed to be Python dictionaries.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
bool* is_type_error) {
- PyObject* k1 = PyDict_Keys(dict1);
- PyObject* k2 = PyDict_Keys(dict2);
+ PyObject* k1 = MappingKeys(dict1);
+ PyObject* k2 = MappingKeys(dict2);
*is_type_error = false;
*error_msg = tensorflow::strings::StrCat(
"The two dictionaries don't have the same set of keys. "
@@ -423,6 +533,24 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
}
+ } else if (IsMappingHelper(o1)) {
+ // Fallback for custom mapping types. Instead of using PyDict methods
+ // which stay in C, we call iter(o1).
+ if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+
+ Safe_PyObjectPtr iter(PyObject_GetIter(o1));
+ PyObject* key;
+ while ((key = PyIter_Next(iter.get())) != nullptr) {
+ if (!PyMapping_HasKey(o2, key)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ Py_DECREF(key);
+ return true;
+ }
+ Py_DECREF(key);
+ }
}
}
@@ -470,6 +598,19 @@ void RegisterSequenceClass(PyObject* sequence_class) {
CollectionsSequenceType = sequence_class;
}
+void RegisterMappingClass(PyObject* mapping_class) {
+ if (!PyType_Check(mapping_class)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting a class definition for `collections.Mapping`. Got ",
+ Py_TYPE(mapping_class)->tp_name)
+ .c_str());
+ return;
+ }
+ CollectionsMappingType = mapping_class;
+}
+
void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
if (!PyType_Check(sparse_tensor_value_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 70efc10c9a..41dcc969f8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -118,7 +118,9 @@ PyObject* Flatten(PyObject* nested);
// the type from the module. This approach also requires some trigger from
// Python so that we know that Python interpreter had been initialzied.
void RegisterSequenceClass(PyObject* sequence_class);
-// Similar to the above function, except for the
+// Like RegisterSequenceClass, but for collections.Mapping.
+void RegisterMappingClass(PyObject* mapping_class);
+// Similar to the above functions, except for the
// sparse_tensor.SparseTensorValue class.
void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 9f3b11b982..6ad1484295 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -31,6 +31,9 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSequenceClass;
%noexception tensorflow::swig::RegisterSequenceClass;
+%unignore tensorflow::swig::RegisterMappingClass;
+%noexception tensorflow::swig::RegisterMappingClass;
+
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;