aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/util/util.cc')
-rw-r--r--tensorflow/python/util/util.cc308
1 files changed, 228 insertions, 80 deletions
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index c79d8a8445..ad85a44f8d 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -31,6 +31,8 @@ namespace {
// Type object for collections.Sequence. This is set by RegisterSequenceClass.
PyObject* CollectionsSequenceType = nullptr;
+// Type object for collections.Mapping, set by RegisterMappingClass.
+PyObject* CollectionsMappingType = nullptr;
PyTypeObject* SparseTensorValueType = nullptr;
const int kMaxItemsInCache = 1024;
@@ -45,6 +47,23 @@ bool IsString(PyObject* o) {
PyUnicode_Check(o);
}
+// Returns the keys of mapping `o` as a Python list.
+//
+// On Python 2 this calls the "keys" method directly to work around a
+// writable-strings warning with the PyMapping_Keys macro. On every Python
+// version the result is guaranteed to be a real list (a copy is made if the
+// mapping handed back a tuple or another iterable), so callers may pass it
+// straight to PyList_Sort / PyList_Size.
+//
+// As with PyMapping_Keys, returns a new reference. Returns nullptr with a
+// Python exception set if the keys could not be retrieved or converted.
+PyObject* MappingKeys(PyObject* o) {
+#if PY_MAJOR_VERSION >= 3
+  Safe_PyObjectPtr raw_result(PyMapping_Keys(o));
+#else
+  static char key_method_name[] = "keys";
+  Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
+#endif
+  // Retrieving the keys failed; leave the exception raised by that call.
+  if (raw_result.get() == nullptr) return nullptr;
+
+  // Already an exact list: hand out an extra reference instead of copying.
+  // raw_result drops its own reference when it goes out of scope.
+  if (PyList_CheckExact(raw_result.get())) {
+    Py_INCREF(raw_result.get());
+    return raw_result.get();
+  }
+
+  // Tuples and other iterables are copied into a new list.
+  PyObject* key_list = PySequence_List(raw_result.get());
+  if (key_list == nullptr && PyErr_ExceptionMatches(PyExc_TypeError)) {
+    // Not iterable at all: report it in terms of the mapping protocol.
+    PyErr_SetString(
+        PyExc_TypeError,
+        "The '.keys()' method of a custom mapping returned a non-sequence.");
+  }
+  return key_list;
+}
+
// Equivalent to Python's 'o.__class__.__name__'
// Note that '__class__' attribute is set only in new-style classes.
// A lot of tensorflow code uses __class__ without checks, so it seems like
@@ -85,6 +104,119 @@ string PyObjectToString(PyObject* o) {
}
}
+// Memoizes, per Python type object, the result of a ternary predicate so that
+// repeated checks on objects of an already-seen type do not have to run the
+// predicate again.
+class CachedTypeCheck {
+ public:
+  explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
+      : ternary_predicate_(std::move(ternary_predicate)) {}
+
+  ~CachedTypeCheck() {
+    mutex_lock l(type_to_result_map_mu_);
+    // Release the reference taken for each key when it was inserted.
+    for (const auto& pair : type_to_result_map_) {
+      Py_DECREF(pair.first);
+    }
+  }
+
+  // Caches successful executions of the one-argument (PyObject*) callable
+  // "ternary_predicate" based on the type of "o". -1 from the callable
+  // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
+  // does not match the predicate, and 1 indicates that it does. Used to avoid
+  // calling back into Python for expensive isinstance checks.
+  int CachedLookup(PyObject* o) {
+    // Try not to return to Python - see if the type has already been seen
+    // before.
+    auto* type = Py_TYPE(o);
+
+    {
+      mutex_lock l(type_to_result_map_mu_);
+      auto it = type_to_result_map_.find(type);
+      if (it != type_to_result_map_.end()) {
+        return it->second;
+      }
+    }
+
+    // The mutex is deliberately not held while the predicate runs, so another
+    // thread may enter this function (and cache the same type) in the
+    // meantime.
+    int check_result = ternary_predicate_(o);
+
+    if (check_result == -1) {
+      return -1;  // Type check error, not cached.
+    }
+
+    // NOTE: A cached type is never decref'd as long as this object lives,
+    // which is likely forever, but we don't want the type to get deleted as
+    // long as it is in the map. This should not be too much of a leak, as
+    // there should only be a relatively small number of types in the map, and
+    // an even smaller number that are eligible for decref. As a precaution, we
+    // limit the size of the map to kMaxItemsInCache.
+    {
+      mutex_lock l(type_to_result_map_mu_);
+      if (type_to_result_map_.size() < kMaxItemsInCache) {
+        // Take the reference only if the entry was really added. If a
+        // concurrent caller inserted this type first, insert() is a no-op and
+        // an unconditional incref would never be balanced by the destructor.
+        if (type_to_result_map_.insert({type, check_result}).second) {
+          Py_INCREF(type);
+        }
+      }
+    }
+
+    return check_result;
+  }
+
+ private:
+  std::function<int(PyObject*)> ternary_predicate_;
+  mutex type_to_result_map_mu_;
+  // Maps a type to the cached (0 or 1) predicate result. Each key holds one
+  // strong reference that is released in the destructor.
+  std::unordered_map<PyTypeObject*, int> type_to_result_map_
+      GUARDED_BY(type_to_result_map_mu_);
+};
+
+// Decides whether `o` counts as a mapping for the purposes of Flatten().
+//   1: `o` is a dict or an instance of the registered collections.Mapping
+//   0: it is neither
+//  -1: an error occurred (for example, no Mapping class was registered)
+int IsMappingHelper(PyObject* o) {
+  static auto* const mapping_check_cache =
+      new CachedTypeCheck([](PyObject* candidate) {
+        return PyObject_IsInstance(candidate, CollectionsMappingType);
+      });
+
+  // Genuine dicts are recognized without consulting the cache.
+  if (PyDict_Check(o)) {
+    return 1;
+  }
+
+  if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
+    const auto message = tensorflow::strings::StrCat(
+        "collections.Mapping type has not been set. "
+        "Please call RegisterMappingClass before using this module");
+    PyErr_SetString(PyExc_RuntimeError, message.c_str());
+    return -1;
+  }
+
+  return mapping_check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Dicts and other mappings count as sequences; strings do not.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred, including an error while checking whether
+// `o` is a mapping.
+int IsSequenceHelper(PyObject* o) {
+  static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+    int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
+
+    // Don't cache a failed is_instance check.
+    if (is_instance == -1) return -1;
+
+    return static_cast<int>(is_instance != 0 && !IsString(to_check));
+  });
+  // We treat dicts and other mappings as special cases of sequences.
+  // IsMappingHelper is ternary: testing it for plain truthiness would turn
+  // its -1 error result into "is a sequence", so handle the error explicitly.
+  const int is_mapping = IsMappingHelper(o);
+  if (is_mapping == -1) return -1;
+  if (is_mapping != 0) return 1;
+  if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
+    LOG(WARNING) << "Sets are not currently considered sequences, "
+                    "but this may change in the future, "
+                    "so consider avoiding using them.";
+    WarnedThatSetIsNotSequence = true;
+  }
+  if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
+    PyErr_SetString(
+        PyExc_RuntimeError,
+        tensorflow::strings::StrCat(
+            "collections.Sequence type has not been set. "
+            "Please call RegisterSequenceClass before using this module")
+            .c_str());
+    return -1;
+  }
+  return check_cache->CachedLookup(o);
+}
+
// Implements the same idea as tensorflow.util.nest._yield_value
// During construction we check if the iterable is a dictionary.
// If so, we construct a sequence from its sorted keys that will be used
@@ -96,7 +228,12 @@ string PyObjectToString(PyObject* o) {
// 'iterable' must not be modified while ValIterator is used.
class ValIterator {
public:
- explicit ValIterator(PyObject* iterable) : dict_(nullptr), index_(0) {
+ explicit ValIterator(PyObject* iterable)
+ : dict_(nullptr),
+ mapping_(nullptr),
+ last_mapping_element_(nullptr),
+ seq_(nullptr),
+ index_(0) {
if (PyDict_Check(iterable)) {
dict_ = iterable;
// PyDict_Keys returns a list, which can be used with
@@ -108,6 +245,10 @@ class ValIterator {
// bugs caused by mixing ordered and plain dicts (e.g., flattening
// a dict but using a corresponding `OrderedDict` to pack it back).
PyList_Sort(seq_);
+ } else if (IsMappingHelper(iterable)) {
+ mapping_ = iterable;
+ seq_ = MappingKeys(iterable);
+ PyList_Sort(seq_);
} else {
seq_ = PySequence_Fast(iterable, "");
}
@@ -122,7 +263,9 @@ class ValIterator {
PyObject* element = nullptr;
if (index_ < size_) {
// Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references.
+ // references. For general mappings, ValIterator keeps a reference to the
+ // last retrieved element (and decrefs it before producing the next
+ // element) to abstract away the borrowed/new difference.
element = PySequence_Fast_GET_ITEM(seq_, index_);
++index_;
if (dict_ != nullptr) {
@@ -132,85 +275,32 @@ class ValIterator {
"Dictionary was modified during iteration over it");
return nullptr;
}
+ } else if (mapping_ != nullptr) {
+ element = PyObject_GetItem(mapping_, element);
+ if (element == nullptr) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ return nullptr;
+ }
+ last_mapping_element_.reset(element);
}
}
return element;
}
private:
- PyObject* seq_;
+ // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+
+ // General mappings which have custom Python logic
+ PyObject* mapping_;
+ Safe_PyObjectPtr last_mapping_element_;
+
+ PyObject* seq_;
Py_ssize_t size_;
Py_ssize_t index_;
};
-mutex g_type_to_sequence_map(LINKER_INITIALIZED);
-std::unordered_map<PyTypeObject*, bool>* IsTypeSequenceMap() {
- static auto* const m = new std::unordered_map<PyTypeObject*, bool>;
- return m;
-}
-
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
-// Returns 0 otherwise.
-// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
- if (PyDict_Check(o)) return true;
- if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
- LOG(WARNING) << "Sets are not currently considered sequences, "
- "but this may change in the future, "
- "so consider avoiding using them.";
- WarnedThatSetIsNotSequence = true;
- }
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
-
- // Try not to return to Python - see if the type has already been seen
- // before.
-
- auto* type_to_sequence_map = IsTypeSequenceMap();
- auto* type = Py_TYPE(o);
-
- {
- mutex_lock l(g_type_to_sequence_map);
- auto it = type_to_sequence_map->find(type);
- if (it != type_to_sequence_map->end()) {
- return it->second;
- }
- }
-
- // NOTE: We explicitly release the g_type_to_sequence_map mutex,
- // because PyObject_IsInstance() may release the GIL, allowing another thread
- // concurrent entry to this function.
- int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
-
- bool is_sequence = static_cast<int>(is_instance != 0 && !IsString(o));
-
- // NOTE: This is never decref'd, but we don't want the type to get deleted
- // as long as it is in the map. This should not be too much of a
- // leak, as there should only be a relatively small number of types in the
- // map, and an even smaller number that are eligible for decref. As a
- // precaution, we limit the size of the map to 1024.
- {
- mutex_lock l(g_type_to_sequence_map);
- if (type_to_sequence_map->size() < kMaxItemsInCache) {
- Py_INCREF(type);
- type_to_sequence_map->insert({type, is_sequence});
- }
- }
-
- return is_sequence;
-}
-
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -226,21 +316,35 @@ int IsSequenceForDataHelper(PyObject* o) {
bool GetNextValuesForDict(PyObject* nested,
std::vector<Safe_PyObjectPtr>* next_values) {
- std::vector<PyObject*> result;
-
- PyObject* keys = PyDict_Keys(nested);
- if (PyList_Sort(keys) == -1) return false;
- Py_ssize_t size = PyList_Size(keys);
+ Safe_PyObjectPtr keys(PyDict_Keys(nested));
+ if (PyList_Sort(keys.get()) == -1) return false;
+ Py_ssize_t size = PyList_Size(keys.get());
for (Py_ssize_t i = 0; i < size; ++i) {
// We know that key and item will not be deleted because nested owns
// a reference to them and callers of flatten must not modify nested
// while the method is running.
- PyObject* key = PyList_GET_ITEM(keys, i);
+ PyObject* key = PyList_GET_ITEM(keys.get(), i);
PyObject* item = PyDict_GetItem(nested, key);
Py_INCREF(item);
next_values->emplace_back(item);
}
- Py_DECREF(keys);
+ return true;
+}
+
+// Appends the values of the custom mapping `nested` to `next_values`, visiting
+// the keys in sorted order.
+//
+// Returns true on success. Returns false if the keys cannot be fetched or
+// sorted, or if looking up any key fails; values appended before the failure
+// are left in `next_values`.
+bool GetNextValuesForMapping(PyObject* nested,
+                             std::vector<Safe_PyObjectPtr>* next_values) {
+  Safe_PyObjectPtr keys(MappingKeys(nested));
+  if (keys.get() == nullptr) {
+    return false;
+  }
+  if (PyList_Sort(keys.get()) == -1) return false;
+  Py_ssize_t size = PyList_Size(keys.get());
+  for (Py_ssize_t i = 0; i < size; ++i) {
+    PyObject* key = PyList_GET_ITEM(keys.get(), i);
+    // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+    PyObject* item = PyObject_GetItem(nested, key);
+    // A failed lookup yields nullptr; never store that as a value.
+    if (item == nullptr) return false;
+    next_values->emplace_back(item);
+  }
return true;
}
@@ -265,6 +369,9 @@ bool GetNextValues(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
}
// iterate and recurse
return GetNextValuesForIterable(nested, next_values);
@@ -276,6 +383,9 @@ bool GetNextValuesForData(PyObject* nested,
if (PyDict_Check(nested)) {
// if nested is dictionary, sort it by key and recurse on each value
return GetNextValuesForDict(nested, next_values);
+ } else if (IsMappingHelper(nested)) {
+ // same treatment as dictionaries, but for custom mapping types
+ return GetNextValuesForMapping(nested, next_values);
} else if (IsSparseTensorValueType(nested)) {
// if nested is a SparseTensorValue, just return itself as a single item
Py_INCREF(nested);
@@ -320,8 +430,8 @@ bool FlattenHelper(
// 'dict1' and 'dict2' are assumed to be Python dictionaries.
void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
bool* is_type_error) {
- PyObject* k1 = PyDict_Keys(dict1);
- PyObject* k2 = PyDict_Keys(dict2);
+ PyObject* k1 = MappingKeys(dict1);
+ PyObject* k2 = MappingKeys(dict2);
*is_type_error = false;
*error_msg = tensorflow::strings::StrCat(
"The two dictionaries don't have the same set of keys. "
@@ -394,7 +504,14 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
type2->tp_name);
return true;
}
- } else if (type1 != type2) {
+ } else if (type1 != type2
+ /* If both sequences are list types, don't complain. This allows
+ one to be a list subclass (e.g. _ListWrapper used for
+ automatic dependency tracking.) */
+ && !(PyList_Check(o1) && PyList_Check(o2))
+ /* Two mapping types will also compare equal, making _DictWrapper
+ and dict compare equal. */
+ && !(IsMappingHelper(o1) && IsMappingHelper(o2))) {
*is_type_error = true;
*error_msg = tensorflow::strings::StrCat(
"The two namedtuples don't have the same sequence type. "
@@ -419,6 +536,24 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
return true;
}
}
+ } else if (IsMappingHelper(o1)) {
+ // Fallback for custom mapping types. Instead of using PyDict methods
+ // which stay in C, we call iter(o1).
+ if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ return true;
+ }
+
+ Safe_PyObjectPtr iter(PyObject_GetIter(o1));
+ PyObject* key;
+ while ((key = PyIter_Next(iter.get())) != nullptr) {
+ if (!PyMapping_HasKey(o2, key)) {
+ SetDifferentKeysError(o1, o2, error_msg, is_type_error);
+ Py_DECREF(key);
+ return true;
+ }
+ Py_DECREF(key);
+ }
}
}
@@ -466,6 +601,19 @@ void RegisterSequenceClass(PyObject* sequence_class) {
CollectionsSequenceType = sequence_class;
}
+// Records `mapping_class` as the type used for collections.Mapping checks.
+// If the argument is not a type object, sets a Python TypeError and leaves
+// the previously registered class untouched.
+void RegisterMappingClass(PyObject* mapping_class) {
+  if (PyType_Check(mapping_class)) {
+    CollectionsMappingType = mapping_class;
+    return;
+  }
+
+  const auto message = tensorflow::strings::StrCat(
+      "Expecting a class definition for `collections.Mapping`. Got ",
+      Py_TYPE(mapping_class)->tp_name);
+  PyErr_SetString(PyExc_TypeError, message.c_str());
+}
+
void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
if (!PyType_Check(sparse_tensor_value_class)) {
PyErr_SetString(