diff options
author | Allen Lavoie <allenl@google.com> | 2018-08-01 11:05:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-01 11:16:33 -0700 |
commit | 72023ba74dd6f18a215925c4d1bc86c9f25f6a7a (patch) | |
tree | 29ea5e6b0da86da4544e4bd467db7e4531310322 /tensorflow/python/util | |
parent | 24331e4fc0166d2adf3cf3b155844b5c77500a0c (diff) |
Fix custom mapping comparisons in the nest utility
Previously it checked that one object was a mapping before using PyDict methods on the second one being compared. Now checks that both are dicts and falls back to PyMapping methods otherwise.
Also fixes exception propagation for this section of code.
PiperOrigin-RevId: 206957412
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/nest_test.py | 4 | ||||
-rw-r--r-- | tensorflow/python/util/util.cc | 37 |
2 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py index fd75c6885a..2369eb610e 100644 --- a/tensorflow/python/util/nest_test.py +++ b/tensorflow/python/util/nest_test.py @@ -354,6 +354,10 @@ class NestTest(parameterized.TestCase, test.TestCase): EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name + def testHeterogeneousComparison(self): + nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) + nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) + @test_util.assert_no_new_pyobjects_executing_eagerly def testMapStructure(self): structure1 = (((1, 2), 3), 4, (5, 6)) diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index ad85a44f8d..ebb72079ef 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -52,12 +52,17 @@ bool IsString(PyObject* o) { // returned value is a list. // // As with PyMapping_Keys, returns a new reference. +// +// On failure, returns nullptr. PyObject* MappingKeys(PyObject* o) { #if PY_MAJOR_VERSION >= 3 return PyMapping_Keys(o); #else static char key_method_name[] = "keys"; Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr)); + if (PyErr_Occurred() || raw_result.get() == nullptr) { + return nullptr; + } return PySequence_Fast( raw_result.get(), "The '.keys()' method of a custom mapping returned a non-sequence."); @@ -260,6 +265,9 @@ class ValIterator { // Return a borrowed reference to the next element from iterable. // Return nullptr when iteration is over. PyObject* next() { + if (TF_PREDICT_FALSE(seq_ == nullptr)) { + return nullptr; + } PyObject* element = nullptr; if (index_ < size_) { // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed @@ -430,16 +438,26 @@ bool FlattenHelper( // 'dict1' and 'dict2' are assumed to be Python dictionaries. void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg, bool* is_type_error) { - PyObject* k1 = MappingKeys(dict1); - PyObject* k2 = MappingKeys(dict2); + Safe_PyObjectPtr k1(MappingKeys(dict1)); + if (PyErr_Occurred() || k1.get() == nullptr) { + *error_msg = + ("The two dictionaries don't have the same set of keys. Failed to " + "fetch keys."); + return; + } + Safe_PyObjectPtr k2(MappingKeys(dict2)); + if (PyErr_Occurred() || k2.get() == nullptr) { + *error_msg = + ("The two dictionaries don't have the same set of keys. Failed to " + "fetch keys."); + return; + } *is_type_error = false; *error_msg = tensorflow::strings::StrCat( "The two dictionaries don't have the same set of keys. " "First structure has keys ", - PyObjectToString(k1), ", while second structure has keys ", - PyObjectToString(k2)); - Py_DECREF(k1); - Py_DECREF(k2); + PyObjectToString(k1.get()), ", while second structure has keys ", + PyObjectToString(k2.get())); } // Returns true iff there were no "internal" errors. In other words, @@ -522,7 +540,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, return true; } - if (PyDict_Check(o1)) { + if (PyDict_Check(o1) && PyDict_Check(o2)) { if (PyDict_Size(o1) != PyDict_Size(o2)) { SetDifferentKeysError(o1, o2, error_msg, is_type_error); return true; @@ -741,6 +759,11 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) { string error_msg; bool is_type_error = false; AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error); + if (PyErr_Occurred()) { + // Don't hide Python exceptions while checking (e.g. errors fetching keys + // from custom mappings). + return nullptr; + } if (!error_msg.empty()) { PyErr_SetString( is_type_error ? PyExc_TypeError : PyExc_ValueError, |