aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-08-01 11:05:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-01 11:16:33 -0700
commit72023ba74dd6f18a215925c4d1bc86c9f25f6a7a (patch)
tree29ea5e6b0da86da4544e4bd467db7e4531310322 /tensorflow/python/util
parent24331e4fc0166d2adf3cf3b155844b5c77500a0c (diff)
Fix custom mapping comparisons in the nest utility
Previously it checked that one object was a mapping before using PyDict methods on the second one being compared. Now checks that both are dicts and falls back to PyMapping methods otherwise. Also fixes exception propagation for this section of code. PiperOrigin-RevId: 206957412
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/nest_test.py4
-rw-r--r--tensorflow/python/util/util.cc37
2 files changed, 34 insertions, 7 deletions
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index fd75c6885a..2369eb610e 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -354,6 +354,10 @@ class NestTest(parameterized.TestCase, test.TestCase):
EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name
+ def testHeterogeneousComparison(self):
+ nest.assert_same_structure({"a": 4}, _CustomMapping(a=3))
+ nest.assert_same_structure(_CustomMapping(b=3), {"b": 4})
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testMapStructure(self):
structure1 = (((1, 2), 3), 4, (5, 6))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index ad85a44f8d..ebb72079ef 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -52,12 +52,17 @@ bool IsString(PyObject* o) {
// returned value is a list.
//
// As with PyMapping_Keys, returns a new reference.
+//
+// On failure, returns nullptr.
PyObject* MappingKeys(PyObject* o) {
#if PY_MAJOR_VERSION >= 3
return PyMapping_Keys(o);
#else
static char key_method_name[] = "keys";
Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
+ if (PyErr_Occurred() || raw_result.get() == nullptr) {
+ return nullptr;
+ }
return PySequence_Fast(
raw_result.get(),
"The '.keys()' method of a custom mapping returned a non-sequence.");
@@ -260,6 +265,9 @@ class ValIterator {
// Return a borrowed reference to the next element from iterable.
// Return nullptr when iteration is over.
PyObject* next() {
+ if (TF_PREDICT_FALSE(seq_ == nullptr)) {
+ return nullptr;
+ }
PyObject* element = nullptr;
if (index_ < size_) {
// Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
@@ -430,16 +438,26 @@ bool FlattenHelper(
// 'dict1' and 'dict2' are assumed to be Python dictionaries.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
bool* is_type_error) {
- PyObject* k1 = MappingKeys(dict1);
- PyObject* k2 = MappingKeys(dict2);
+ Safe_PyObjectPtr k1(MappingKeys(dict1));
+ if (PyErr_Occurred() || k1.get() == nullptr) {
+ *error_msg =
+ ("The two dictionaries don't have the same set of keys. Failed to "
+ "fetch keys.");
+ return;
+ }
+ Safe_PyObjectPtr k2(MappingKeys(dict2));
+ if (PyErr_Occurred() || k2.get() == nullptr) {
+ *error_msg =
+ ("The two dictionaries don't have the same set of keys. Failed to "
+ "fetch keys.");
+ return;
+ }
*is_type_error = false;
*error_msg = tensorflow::strings::StrCat(
"The two dictionaries don't have the same set of keys. "
"First structure has keys ",
- PyObjectToString(k1), ", while second structure has keys ",
- PyObjectToString(k2));
- Py_DECREF(k1);
- Py_DECREF(k2);
+ PyObjectToString(k1.get()), ", while second structure has keys ",
+ PyObjectToString(k2.get()));
}
// Returns true iff there were no "internal" errors. In other words,
@@ -522,7 +540,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
- if (PyDict_Check(o1)) {
+ if (PyDict_Check(o1) && PyDict_Check(o2)) {
if (PyDict_Size(o1) != PyDict_Size(o2)) {
SetDifferentKeysError(o1, o2, error_msg, is_type_error);
return true;
@@ -741,6 +759,11 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types) {
string error_msg;
bool is_type_error = false;
AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error);
+ if (PyErr_Occurred()) {
+ // Don't hide Python exceptions while checking (e.g. errors fetching keys
+ // from custom mappings).
+ return nullptr;
+ }
if (!error_msg.empty()) {
PyErr_SetString(
is_type_error ? PyExc_TypeError : PyExc_ValueError,