aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/nest_test.py68
-rw-r--r--tensorflow/python/util/util.cc299
-rw-r--r--tensorflow/python/util/util.h4
-rw-r--r--tensorflow/python/util/util.i3
6 files changed, 284 insertions, 102 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index b5a0051c28..a7c60f5450 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -3408,6 +3408,7 @@ py_test(
":math_ops",
":util",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index d63f59a8c8..5aac559b9b 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -73,7 +73,7 @@ def _sequence_like(instance, args):
Returns:
`args` with the type of `instance`.
"""
- if isinstance(instance, dict):
+ if isinstance(instance, (dict, _collections.Mapping)):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -89,7 +89,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
- if isinstance(iterable, dict):
+ if isinstance(iterable, (dict, _collections.Mapping)):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by mixing
@@ -215,7 +215,7 @@ def flatten_dict_items(dictionary):
ValueError: If any key and value have not the same structure, or if keys are
not unique.
"""
- if not isinstance(dictionary, dict):
+ if not isinstance(dictionary, (dict, _collections.Mapping)):
raise TypeError("input must be a dictionary")
flat_dictionary = {}
for i, v in _six.iteritems(dictionary):
@@ -455,7 +455,7 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
"structure has length %s, while shallow structure has length %s."
% (len(input_tree), len(shallow_tree)))
- if check_types and isinstance(shallow_tree, dict):
+ if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
if set(input_tree) != set(shallow_tree):
raise ValueError(
"The two structures don't have the same keys. Input "
@@ -716,7 +716,7 @@ def yield_flat_paths(nest):
# The _maybe_add_final_path_element function is used below in order to avoid
# adding trailing slashes when the sub-element recursed into is a leaf.
- if isinstance(nest, dict):
+ if isinstance(nest, (dict, _collections.Mapping)):
for key in _sorted(nest):
value = nest[key]
for sub_path in yield_flat_paths(value):
@@ -760,3 +760,4 @@ def flatten_with_joined_string_paths(structure, separator="/"):
_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
+_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 2f12b25354..26c6ea4b01 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import time
+from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -33,7 +34,22 @@ from tensorflow.python.platform import test
from tensorflow.python.util import nest
-class NestTest(test.TestCase):
+class _CustomMapping(collections.Mapping):
+
+ def __init__(self, *args, **kwargs):
+ self._wrapped = dict(*args, **kwargs)
+
+ def __getitem__(self, key):
+ return self._wrapped[key]
+
+ def __iter__(self):
+ return iter(self._wrapped)
+
+ def __len__(self):
+ return len(self._wrapped)
+
+
+class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
@@ -72,26 +88,32 @@ class NestTest(test.TestCase):
with self.assertRaises(ValueError):
nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"])
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
@test_util.assert_no_new_pyobjects_executing_eagerly
- def testFlattenDictOrder(self):
+ def testFlattenDictOrder(self, mapping_type):
"""`flatten` orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
+ ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)])
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = nest.flatten(ordered)
plain_flat = nest.flatten(plain)
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
- def testPackDictOrder(self):
+ @parameterized.parameters({"mapping_type": collections.OrderedDict},
+ {"mapping_type": _CustomMapping})
+ def testPackDictOrder(self, mapping_type):
"""Packing orders dicts by key, including OrderedDicts."""
- ordered = collections.OrderedDict([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
+ custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)])
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
- ordered_reconstruction = nest.pack_sequence_as(ordered, seq)
+ custom_reconstruction = nest.pack_sequence_as(custom, seq)
plain_reconstruction = nest.pack_sequence_as(plain, seq)
+ self.assertIsInstance(custom_reconstruction, mapping_type)
+ self.assertIsInstance(plain_reconstruction, dict)
self.assertEqual(
- collections.OrderedDict([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
- ordered_reconstruction)
+ mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]),
+ custom_reconstruction)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name
@@ -101,8 +123,10 @@ class NestTest(test.TestCase):
# A nice messy mix of tuples, lists, dicts, and `OrderedDict`s.
mess = [
"z",
- NestTest.Abc(3, 4),
- {
+ NestTest.Abc(3, 4), {
+ "d": _CustomMapping({
+ 41: 4
+ }),
"c": [
1,
collections.OrderedDict([
@@ -111,17 +135,19 @@ class NestTest(test.TestCase):
]),
],
"b": 5
- },
- 17
+ }, 17
]
flattened = nest.flatten(mess)
- self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 17])
+ self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17])
structure_of_mess = [
14,
NestTest.Abc("a", True),
{
+ "d": _CustomMapping({
+ 41: 42
+ }),
"c": [
0,
collections.OrderedDict([
@@ -142,6 +168,10 @@ class NestTest(test.TestCase):
self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict)
self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"])
+ unflattened_custom_mapping = unflattened[2]["d"]
+ self.assertIsInstance(unflattened_custom_mapping, _CustomMapping)
+ self.assertEqual(list(unflattened_custom_mapping.keys()), [41])
+
def testFlatten_numpyIsNotFlattened(self):
structure = np.array([1, 2, 3])
flattened = nest.flatten(structure)
@@ -179,19 +209,23 @@ class NestTest(test.TestCase):
self.assertFalse(nest.is_sequence(math_ops.tanh(ones)))
self.assertFalse(nest.is_sequence(np.ones((4, 5))))
- def testFlattenDictItems(self):
- dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
+ @parameterized.parameters({"mapping_type": _CustomMapping},
+ {"mapping_type": dict})
+ def testFlattenDictItems(self, mapping_type):
+ dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))})
flat = {4: "a", 5: "b", 6: "c", 8: "d"}
self.assertEqual(nest.flatten_dict_items(dictionary), flat)
with self.assertRaises(TypeError):
nest.flatten_dict_items(4)
- bad_dictionary = {(4, 5, (4, 8)): ("a", "b", ("c", "d"))}
+ bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))})
with self.assertRaisesRegexp(ValueError, "not unique"):
nest.flatten_dict_items(bad_dictionary)
- another_bad_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))}
+ another_bad_dictionary = mapping_type({
+ (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e")))
+ })
with self.assertRaisesRegexp(
ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"):
nest.flatten_dict_items(another_bad_dictionary)
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 366f8a0deb..f9e0b7e4d2 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -31,6 +31,8 @@ namespace {
// Type object for collections.Sequence. This is set by RegisterSequenceClass.
PyObject* CollectionsSequenceType = nullptr;
+// Type object for collections.Mapping, set by RegisterMappingClass.
+PyObject* CollectionsMappingType = nullptr;
PyTypeObject* SparseTensorValueType = nullptr;
const int kMaxItemsInCache = 1024;
@@ -45,6 +47,23 @@ bool IsString(PyObject* o) {
PyUnicode_Check(o);
}
+// Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
+// and while we're at it give them consistent behavior by making sure the
+// returned value is a list.
+//
+// As with PyMapping_Keys, returns a new reference.
+PyObject* MappingKeys(PyObject* o) {
+#if PY_MAJOR_VERSION >= 3
+ return PyMapping_Keys(o);
+#else
+ static char key_method_name[] = "keys";
+ Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
+ return PySequence_Fast(
+ raw_result.get(),
+ "The '.keys()' method of a custom mapping returned a non-sequence.");
+#endif
+}
+
// Equivalent to Python's 'o.__class__.__name__'
// Note that '__class__' attribute is set only in new-style classes.
// A lot of tensorflow code uses __class__ without checks, so it seems like
@@ -85,6 +104,119 @@ string PyObjectToString(PyObject* o) {
}
}
+class CachedTypeCheck {
+ public:
+ explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
+ : ternary_predicate_(std::move(ternary_predicate)) {}
+
+ ~CachedTypeCheck() {
+ mutex_lock l(type_to_sequence_map_mu_);
+ for (const auto& pair : type_to_sequence_map_) {
+ Py_DECREF(pair.first);
+ }
+ }
+
+ // Caches successful executions of the one-argument (PyObject*) callable
+ // "ternary_predicate" based on the type of "o". -1 from the callable
+ // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
+ // does not match the predicate, and 1 indicates that it does. Used to avoid
+ // calling back into Python for expensive isinstance checks.
+ int CachedLookup(PyObject* o) {
+ // Try not to return to Python - see if the type has already been seen
+ // before.
+
+ auto* type = Py_TYPE(o);
+
+ {
+ mutex_lock l(type_to_sequence_map_mu_);
+ auto it = type_to_sequence_map_.find(type);
+ if (it != type_to_sequence_map_.end()) {
+ return it->second;
+ }
+ }
+
+ int check_result = ternary_predicate_(o);
+
+ if (check_result == -1) {
+ return -1; // Type check error, not cached.
+ }
+
+ // NOTE: This is never decref'd as long as the object lives, which is likely
+ // forever, but we don't want the type to get deleted as long as it is in
+ // the map. This should not be too much of a leak, as there should only be a
+ // relatively small number of types in the map, and an even smaller number
+ // that are eligible for decref. As a precaution, we limit the size of the
+ // map to 1024.
+ {
+ mutex_lock l(type_to_sequence_map_mu_);
+ if (type_to_sequence_map_.size() < kMaxItemsInCache) {
+ Py_INCREF(type);
+ type_to_sequence_map_.insert({type, check_result});
+ }
+ }
+
+ return check_result;
+ }
+
+ private:
+ std::function<int(PyObject*)> ternary_predicate_;
+ mutex type_to_sequence_map_mu_;
+ std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
+ GUARDED_BY(type_to_sequence_map_mu_);
+};
+
+// Returns 1 if `o` is considered a mapping for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsMappingHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ return PyObject_IsInstance(to_check, CollectionsMappingType);
+ });
+ if (PyDict_Check(o)) return true;
+ if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please call RegisterMappingClass before using this module")
+ .c_str());
+ return -1;
+ }
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
+ // We treat dicts and other mappings as special cases of sequences.
+ if (IsMappingHelper(o)) return true;
+ if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
+ LOG(WARNING) << "Sets are not currently considered sequences, "
+ "but this may change in the future, "
+ "so consider avoiding using them.";
+ WarnedThatSetIsNotSequence = true;
+ }
+ if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+ PyErr_SetString(
+ PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please call RegisterSequenceClass before using this module")
+ .c_str());
+ return -1;
+ }
+ return check_cache->CachedLookup(o);
+}
+
// Implements the same idea as tensorflow.util.nest._yield_value
// During construction we check if the iterable is a dictionary.
// If so, we construct a sequence from its sorted keys that will be used
@@ -96,7 +228,12 @@ string PyObjectToString(PyObject* o) {
// 'iterable' must not be modified while ValIterator is used.
class ValIterator {
public:
- explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+ explicit ValIterator(PyObject* iterable)
+ : dict_(nullptr),
+ mapping_(nullptr),
+ last_mapping_element_(nullptr),
+ seq_(nullptr),
+ index_(0) {
if (PyDict_Check(iterable)) {
dict_ = iterable;
// PyDict_Keys returns a list, which can be used with
@@ -108,6 +245,10 @@ class ValIterator {
// bugs caused by mixing ordered and plain dicts (e.g., flattening
// a dict but using a corresponding `OrderedDict` to pack it back).
PyList_Sort(seq_);
+ } else if (IsMappingHelper(iterable)) {
+ mapping_ = iterable;
+ seq_ = MappingKeys(iterable);
+ PyList_Sort(seq_);
} else {
seq_ = PySequence_Fast(iterable, "");
}
@@ -122,7 +263,9 @@ class ValIterator {
PyObject* element = nullptr;
if (index_ < size_) {
// Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references.
+ // references. For general mappings, ValIterator keeps a reference to the
+ // last retrieved element (and decrefs it before producing the next
+ // element) to abstract away the borrowed/new difference.
element = PySequence_Fast_GET_ITEM(seq_, index_);
++index_;
if (dict_ != nullptr) {
@@ -132,85 +275,32 @@ class ValIterator {
"Dictionary was modified during iteration over it");
return nullptr;
}
+ } else if (mapping_ != nullptr) {
+ element = PyObject_GetItem(mapping_, element);
+ if (element == nullptr) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ return nullptr;
+ }
+ last_mapping_element_.reset(element);
}
}
return element;
}
private:
- PyObject* seq_;
+ // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+
+ // General mappings which have custom Python logic
+ PyObject* mapping_;
+ Safe_PyObjectPtr last_mapping_element_;
+
+ PyObject* seq_;
Py_ssize_t size_;
Py_ssize_t index_;
};
-mutex g_type_to_sequence_map(LINKER_INITIALIZED);
-std::unordered_map<PyTypeObject*, bool>* IsTypeSequenceMap() {
- static auto* const m = new std::unordered_map<PyTypeObject*, bool>;
- return m;
-}
-
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
-// Returns 0 otherwise.
-// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
- if (PyDict_Check(o)) return true;
- if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
- LOG(WARNING) << "Sets are not currently considered sequences, "
- "but this may change in the future, "
- "so consider avoiding using them.";
- WarnedThatSetIsNotSequence = true;
- }
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
-
- // Try not to return to Python - see if the type has already been seen
- // before.
-
- auto* type_to_sequence_map = IsTypeSequenceMap();
- auto* type = Py_TYPE(o);
-
- {
- mutex_lock l(g_type_to_sequence_map);
- auto it = type_to_sequence_map->find(type);
- if (it != type_to_sequence_map->end()) {
- return it->second;
- }
- }
-
- // NOTE: We explicitly release the g_type_to_sequence_map mutex,
- // because PyObject_IsInstance() may release the GIL, allowing another thread
- // concurrent entry to this function.
- int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
-
- bool is_sequence = static_cast<int>(is_instance != 0 && !IsString(o));
-
- // NOTE: This is never decref'd, but we don't want the type to get deleted
- // as long as it is in the map. This should not be too much of a
- // leak, as there should only be a relatively small number of types in the
- // map, and an even smaller number that are eligible for decref. As a
- // precaution, we limit the size of the map to 1024.
- {
- mutex_lock l(g_type_to_sequence_map);
- if (type_to_sequence_map->size() < kMaxItemsInCache) {
- Py_INCREF(type);
- type_to_sequence_map->insert({type, is_sequence});
- }
- }
-
- return is_sequence;
-}
-
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -226,21 +316,35 @@ int IsSequenceForDataHelper(PyObject* o) {
bool GetNextValuesForDict(PyObject* nested,
std::vector<Safe_PyObjectPtr>* next_values) {
- std::vector<PyObject*> result;
-
- PyObject* keys = PyDict_Keys(nested);
- if (PyList_Sort(keys) == -1) return false;
- Py_ssize_t size = PyList_Size(keys);
+ Safe_PyObjectPtr keys(PyDict_Keys(nested));
+ if (PyList_Sort(keys.get()) == -1) return false;
+ Py_ssize_t size = PyList_Size(keys.get());
for (Py_ssize_t i = 0; i < size; ++i) {
// We know that key and item will not be deleted because nested owns
// a reference to them and callers of flatten must not modify nested
// while the method is running.
- PyObject* key = PyList_GET_ITEM(keys, i);
+ PyObject* key = PyList_GET_ITEM(keys.get(), i);
PyObject* item = PyDict_GetItem(nested, key);
Py_INCREF(item);
next_values->emplace_back(item);
}
- Py_DECREF(keys);
+ return true;
+}
+
+bool GetNextValuesForMapping(PyObject* nested,
+ std::vector<Safe_PyObjectPtr>* next_values) {
+ Safe_PyObjectPtr keys(MappingKeys(nested));
+ if (keys.get() == nullptr) {
+ return false;
+ }
+ if (PyList_Sort(keys.get()) == -1) return false;
+ Py_ssize_t size = PyList_Size(keys.get());
+ for (Py_ssize_t i = 0; i < size; ++i) {
+ PyObject* key = PyList_GET_ITEM(keys.get(), i);
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* item = PyObject_GetItem(nested, key);
+ next_values->emplace_back(item);
+ }
return true;
}
@@ -265,6 +369,9 @@ bool GetNextValues(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
}
// iterate and recurse
return GetNextValuesForIterable(nested, next_values);
@@ -276,6 +383,9 @@ bool GetNextValuesForData(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
} else if (IsSparseTensorValueType(nested)) {
// if nested is a SparseTensorValue, just return itself as a single item
Py_INCREF(nested);
@@ -320,8 +430,8 @@ bool FlattenHelper(
// 'dict1' and 'dict2' are assumed to be Python dictionaries.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
bool* is_type_error) {
- PyObject* k1 = PyDict_Keys(dict1);
- PyObject* k2 = PyDict_Keys(dict2);
+ PyObject* k1 = MappingKeys(dict1);
+ PyObject* k2 = MappingKeys(dict2);
*is_type_error = false;
*error_msg = tensorflow::strings::StrCat(
"The two dictionaries don't have the same set of keys. "
@@ -423,6 +533,24 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
}
+ } else if (IsMappingHelper(o1)) {
+ // Fallback for custom mapping types. Instead of using PyDict methods
+ // which stay in C, we call iter(o1).
+ if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+
+ Safe_PyObjectPtr iter(PyObject_GetIter(o1));
+ PyObject* key;
+ while ((key = PyIter_Next(iter.get())) != nullptr) {
+ if (!PyMapping_HasKey(o2, key)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ Py_DECREF(key);
+ return true;
+ }
+ Py_DECREF(key);
+ }
}
}
@@ -470,6 +598,19 @@ void RegisterSequenceClass(PyObject* sequence_class) {
CollectionsSequenceType = sequence_class;
}
+void RegisterMappingClass(PyObject* mapping_class) {
+ if (!PyType_Check(mapping_class)) {
+ PyErr_SetString(
+ PyExc_TypeError,
+ tensorflow::strings::StrCat(
+ "Expecting a class definition for `collections.Mapping`. Got ",
+ Py_TYPE(mapping_class)->tp_name)
+ .c_str());
+ return;
+ }
+ CollectionsMappingType = mapping_class;
+}
+
void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
if (!PyType_Check(sparse_tensor_value_class)) {
PyErr_SetString(
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 70efc10c9a..41dcc969f8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -118,7 +118,9 @@ PyObject* Flatten(PyObject* nested);
// the type from the module. This approach also requires some trigger from
// Python so that we know that Python interpreter had been initialzied.
void RegisterSequenceClass(PyObject* sequence_class);
-// Similar to the above function, except for the
+// Like RegisterSequenceClass, but for collections.Mapping.
+void RegisterMappingClass(PyObject* mapping_class);
+// Similar to the above functions, except for the
// sparse_tensor.SparseTensorValue class.
void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 9f3b11b982..6ad1484295 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -31,6 +31,9 @@ limitations under the License.
%unignore tensorflow::swig::RegisterSequenceClass;
%noexception tensorflow::swig::RegisterSequenceClass;
+%unignore tensorflow::swig::RegisterMappingClass;
+%noexception tensorflow::swig::RegisterMappingClass;
+
%unignore tensorflow::swig::RegisterSparseTensorValueClass;
%noexception tensorflow::swig::RegisterSparseTensorValueClass;