diff options
Diffstat (limited to 'tensorflow/python/util/util.cc')
-rw-r--r-- | tensorflow/python/util/util.cc | 305 |
1 files changed, 146 insertions, 159 deletions
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 562bbdcfeb..2087957b31 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -15,9 +15,11 @@ limitations under the License. #include "tensorflow/python/util/util.h" #include <functional> +#include <memory> #include <unordered_map> #include <vector> +#include "absl/memory/memory.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -222,93 +224,136 @@ int IsSequenceHelper(PyObject* o) { return check_cache->CachedLookup(o); } -// Implements the same idea as tensorflow.util.nest._yield_value -// During construction we check if the iterable is a dictionary. -// If so, we construct a sequence from its sorted keys that will be used -// for iteration. -// If not, we construct a sequence directly from the iterable. -// At each step, we get the next element from the sequence and use it -// either as a key or return it directly. -// -// 'iterable' must not be modified while ValIterator is used. -class ValIterator { +// ValueIterator interface +class ValueIterator { + public: + virtual ~ValueIterator() {} + virtual Safe_PyObjectPtr next() = 0; + + bool valid() const { return is_valid_; } + + protected: + void invalidate() { is_valid_ = false; } + + private: + bool is_valid_ = true; +}; + +using ValueIteratorPtr = std::unique_ptr<ValueIterator>; + +// Iterate through dictionaries in a deterministic order by sorting the +// keys. Notice this means that we ignore the original order of +// `OrderedDict` instances. This is intentional, to avoid potential +// bugs caused by mixing ordered and plain dicts (e.g., flattening +// a dict but using a corresponding `OrderedDict` to pack it back). +class DictValueIterator : public ValueIterator { public: - explicit ValIterator(PyObject* iterable) - : dict_(nullptr), - mapping_(nullptr), - last_mapping_element_(nullptr), - seq_(nullptr), - index_(0) { - if (PyDict_Check(iterable)) { - dict_ = iterable; - // PyDict_Keys returns a list, which can be used with - // PySequence_Fast_GET_ITEM. - seq_ = PyDict_Keys(iterable); - // Iterate through dictionaries in a deterministic order by sorting the - // keys. Notice this means that we ignore the original order of - // `OrderedDict` instances. This is intentional, to avoid potential - // bugs caused by mixing ordered and plain dicts (e.g., flattening - // a dict but using a corresponding `OrderedDict` to pack it back). - PyList_Sort(seq_); - } else if (IsMappingHelper(iterable)) { - mapping_ = iterable; - seq_ = MappingKeys(iterable); - PyList_Sort(seq_); + explicit DictValueIterator(PyObject* dict) + : dict_(dict), keys_(PyDict_Keys(dict)) { + if (PyList_Sort(keys_.get()) == -1) { + invalidate(); } else { - seq_ = PySequence_Fast(iterable, ""); + iter_.reset(PyObject_GetIter(keys_.get())); } - size_ = PySequence_Fast_GET_SIZE(seq_); } - ~ValIterator() { Py_DECREF(seq_); } - - // Return a borrowed reference to the next element from iterable. - // Return nullptr when iteration is over. - PyObject* next() { - if (TF_PREDICT_FALSE(seq_ == nullptr)) { - return nullptr; - } - PyObject* element = nullptr; - if (index_ < size_) { - // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed - // references. For general mappings, ValIterator keeps a reference to the - // last retrieved element (and decrefs it before producing the next - // element) to abstract away the borrowed/new difference. - element = PySequence_Fast_GET_ITEM(seq_, index_); - ++index_; - if (dict_ != nullptr) { - element = PyDict_GetItem(dict_, element); - if (element == nullptr) { - PyErr_SetString(PyExc_RuntimeError, - "Dictionary was modified during iteration over it"); - return nullptr; - } - } else if (mapping_ != nullptr) { - element = PyObject_GetItem(mapping_, element); - if (element == nullptr) { - PyErr_SetString(PyExc_RuntimeError, - "Mapping was modified during iteration over it"); - return nullptr; - } - last_mapping_element_.reset(element); + Safe_PyObjectPtr next() override { + Safe_PyObjectPtr result; + Safe_PyObjectPtr key(PyIter_Next(iter_.get())); + if (key) { + // PyDict_GetItem returns a borrowed reference. + PyObject* elem = PyDict_GetItem(dict_, key.get()); + if (elem) { + Py_INCREF(elem); + result.reset(elem); + } else { + PyErr_SetString(PyExc_RuntimeError, + "Dictionary was modified during iteration over it"); } } - return element; + return result; } private: - // Special casing for things that pass PyDict_Check (faster, no Python calls) PyObject* dict_; + Safe_PyObjectPtr keys_; + Safe_PyObjectPtr iter_; +}; - // General mappings which have custom Python logic +// Iterate over mapping objects by sorting the keys first +class MappingValueIterator : public ValueIterator { + public: + explicit MappingValueIterator(PyObject* mapping) + : mapping_(mapping), keys_(MappingKeys(mapping)) { + if (!keys_ || PyList_Sort(keys_.get()) == -1) { + invalidate(); + } else { + iter_.reset(PyObject_GetIter(keys_.get())); + } + } + + Safe_PyObjectPtr next() override { + Safe_PyObjectPtr result; + Safe_PyObjectPtr key(PyIter_Next(iter_.get())); + if (key) { + // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. + PyObject* elem = PyObject_GetItem(mapping_, key.get()); + if (elem) { + result.reset(elem); + } else { + PyErr_SetString(PyExc_RuntimeError, + "Mapping was modified during iteration over it"); + } + } + return result; + } + + private: PyObject* mapping_; - Safe_PyObjectPtr last_mapping_element_; + Safe_PyObjectPtr keys_; + Safe_PyObjectPtr iter_; +}; + +// Iterate over a sequence, by index. +class SequenceValueIterator : public ValueIterator { + public: + explicit SequenceValueIterator(PyObject* iterable) + : seq_(PySequence_Fast(iterable, "")), + size_(PySequence_Fast_GET_SIZE(seq_.get())), + index_(0) {} - PyObject* seq_; - Py_ssize_t size_; + Safe_PyObjectPtr next() override { + Safe_PyObjectPtr result; + if (index_ < size_) { + // PySequence_Fast_GET_ITEM returns a borrowed reference. + PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_); + ++index_; + Py_INCREF(elem); + result.reset(elem); + } + + return result; + } + + private: + Safe_PyObjectPtr seq_; + const Py_ssize_t size_; Py_ssize_t index_; }; +// Just return itself as a single item. +class SparseTensorValueIterator : public ValueIterator { + public: + explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) { + Py_INCREF(tensor); + } + + Safe_PyObjectPtr next() override { return std::move(tensor_); } + + private: + Safe_PyObjectPtr tensor_; +}; + bool IsSparseTensorValueType(PyObject* o) { if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { return false; @@ -322,93 +367,33 @@ int IsSequenceForDataHelper(PyObject* o) { !IsSparseTensorValueType(o); } -bool GetNextValuesForDict(PyObject* nested, - std::vector<Safe_PyObjectPtr>* next_values) { - Safe_PyObjectPtr keys(PyDict_Keys(nested)); - if (PyList_Sort(keys.get()) == -1) return false; - Py_ssize_t size = PyList_Size(keys.get()); - for (Py_ssize_t i = 0; i < size; ++i) { - // We know that key and item will not be deleted because nested owns - // a reference to them and callers of flatten must not modify nested - // while the method is running. - PyObject* key = PyList_GET_ITEM(keys.get(), i); - PyObject* item = PyDict_GetItem(nested, key); - Py_INCREF(item); - next_values->emplace_back(item); - } - return true; -} - -bool GetNextValuesForMapping(PyObject* nested, - std::vector<Safe_PyObjectPtr>* next_values) { - Safe_PyObjectPtr keys(MappingKeys(nested)); - if (keys.get() == nullptr) { - return false; - } - if (PyList_Sort(keys.get()) == -1) return false; - Py_ssize_t size = PyList_Size(keys.get()); - for (Py_ssize_t i = 0; i < size; ++i) { - PyObject* key = PyList_GET_ITEM(keys.get(), i); - // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference. - PyObject* item = PyObject_GetItem(nested, key); - next_values->emplace_back(item); - } - return true; -} - -bool GetNextValuesForIterable(PyObject* nested, - std::vector<Safe_PyObjectPtr>* next_values) { - PyObject* item; - PyObject* iterator = PyObject_GetIter(nested); - if (iterator == nullptr || PyErr_Occurred()) { - return false; - } - while ((item = PyIter_Next(iterator)) != nullptr) { - next_values->emplace_back(item); - } - Py_DECREF(iterator); - return true; -} - -// GetNextValues returns the values that the FlattenHelper function will recurse -// over next. -bool GetNextValues(PyObject* nested, - std::vector<Safe_PyObjectPtr>* next_values) { +ValueIteratorPtr GetValueIterator(PyObject* nested) { if (PyDict_Check(nested)) { - // if nested is dictionary, sort it by key and recurse on each value - return GetNextValuesForDict(nested, next_values); + return absl::make_unique<DictValueIterator>(nested); } else if (IsMappingHelper(nested)) { - // same treatment as dictionaries, but for custom mapping types - return GetNextValuesForMapping(nested, next_values); + return absl::make_unique<MappingValueIterator>(nested); + } else { + return absl::make_unique<SequenceValueIterator>(nested); } - // iterate and recurse - return GetNextValuesForIterable(nested, next_values); } -// Similar to above, just specialized for the functions in the data pacakage. -bool GetNextValuesForData(PyObject* nested, - std::vector<Safe_PyObjectPtr>* next_values) { +// Similar to above, just specialized for the functions in the data package. +ValueIteratorPtr GetValueIteratorForData(PyObject* nested) { if (PyDict_Check(nested)) { - // if nested is dictionary, sort it by key and recurse on each value - return GetNextValuesForDict(nested, next_values); + return absl::make_unique<DictValueIterator>(nested); } else if (IsMappingHelper(nested)) { - // same treatment as dictionaries, but for custom mapping types - return GetNextValuesForMapping(nested, next_values); + return absl::make_unique<MappingValueIterator>(nested); } else if (IsSparseTensorValueType(nested)) { - // if nested is a SparseTensorValue, just return itself as a single item - Py_INCREF(nested); - next_values->emplace_back(nested); - return true; + return absl::make_unique<SparseTensorValueIterator>(nested); + } else { + return absl::make_unique<SequenceValueIterator>(nested); } - // iterate and recurse - return GetNextValuesForIterable(nested, next_values); } bool FlattenHelper( PyObject* nested, PyObject* list, const std::function<int(PyObject*)>& is_sequence_helper, - const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>& - next_values_getter) { + const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) { // if nested is not a sequence, append itself and exit int is_seq = is_sequence_helper(nested); if (is_seq == -1) return false; @@ -416,16 +401,15 @@ bool FlattenHelper( return PyList_Append(list, nested) != -1; } - std::vector<Safe_PyObjectPtr> next_values; - // Get the next values to recurse over. - if (!next_values_getter(nested, &next_values)) return false; + ValueIteratorPtr iter = value_iterator_getter(nested); + if (!iter->valid()) return false; - for (const auto& item : next_values) { + for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) { if (Py_EnterRecursiveCall(" in flatten")) { return false; } - const bool success = - FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter); + const bool success = FlattenHelper(item.get(), list, is_sequence_helper, + value_iterator_getter); Py_LeaveRecursiveCall(); if (!success) { return false; @@ -579,22 +563,25 @@ bool AssertSameStructureHelper( } } - ValIterator iter1(o1); - ValIterator iter2(o2); + ValueIteratorPtr iter1 = GetValueIterator(o1); + ValueIteratorPtr iter2 = GetValueIterator(o2); + + if (!iter1->valid() || !iter2->valid()) return false; while (true) { - PyObject* v1 = iter1.next(); - PyObject* v2 = iter2.next(); - if (v1 != nullptr && v2 != nullptr) { + Safe_PyObjectPtr v1 = iter1->next(); + Safe_PyObjectPtr v2 = iter2->next(); + if (v1 && v2) { if (Py_EnterRecursiveCall(" in assert_same_structure")) { return false; } - bool no_internal_errors = AssertSameStructureHelper( - v1, v2, check_types, error_msg, is_type_error, is_sequence_helper); + bool no_internal_errors = + AssertSameStructureHelper(v1.get(), v2.get(), check_types, error_msg, + is_type_error, is_sequence_helper); Py_LeaveRecursiveCall(); if (!no_internal_errors) return false; if (!error_msg->empty()) return true; - } else if (v1 == nullptr && v2 == nullptr) { + } else if (!v1 && !v2) { // Done with all recursive calls. Structure matched. return true; } else { @@ -655,7 +642,7 @@ bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); - if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) { + if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) { return list; } else { Py_DECREF(list); @@ -668,7 +655,7 @@ bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; } PyObject* FlattenForData(PyObject* nested) { PyObject* list = PyList_New(0); if (FlattenHelper(nested, list, IsSequenceForDataHelper, - GetNextValuesForData)) { + GetValueIteratorForData)) { return list; } else { Py_DECREF(list); |