aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util/util.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/util/util.cc')
-rw-r--r--tensorflow/python/util/util.cc305
1 files changed, 146 insertions, 159 deletions
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 562bbdcfeb..2087957b31 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/python/util/util.h"
#include <functional>
+#include <memory>
#include <unordered_map>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -222,93 +224,136 @@ int IsSequenceHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
-// Implements the same idea as tensorflow.util.nest._yield_value
-// During construction we check if the iterable is a dictionary.
-// If so, we construct a sequence from its sorted keys that will be used
-// for iteration.
-// If not, we construct a sequence directly from the iterable.
-// At each step, we get the next element from the sequence and use it
-// either as a key or return it directly.
-//
-// 'iterable' must not be modified while ValIterator is used.
-class ValIterator {
+// ValueIterator interface
+class ValueIterator {
+ public:
+ virtual ~ValueIterator() {}
+ virtual Safe_PyObjectPtr next() = 0;
+
+ bool valid() const { return is_valid_; }
+
+ protected:
+ void invalidate() { is_valid_ = false; }
+
+ private:
+ bool is_valid_ = true;
+};
+
+using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
+
+// Iterate through dictionaries in a deterministic order by sorting the
+// keys. Notice this means that we ignore the original order of
+// `OrderedDict` instances. This is intentional, to avoid potential
+// bugs caused by mixing ordered and plain dicts (e.g., flattening
+// a dict but using a corresponding `OrderedDict` to pack it back).
+class DictValueIterator : public ValueIterator {
public:
- explicit ValIterator(PyObject* iterable)
- : dict_(nullptr),
- mapping_(nullptr),
- last_mapping_element_(nullptr),
- seq_(nullptr),
- index_(0) {
- if (PyDict_Check(iterable)) {
- dict_ = iterable;
- // PyDict_Keys returns a list, which can be used with
- // PySequence_Fast_GET_ITEM.
- seq_ = PyDict_Keys(iterable);
- // Iterate through dictionaries in a deterministic order by sorting the
- // keys. Notice this means that we ignore the original order of
- // `OrderedDict` instances. This is intentional, to avoid potential
- // bugs caused by mixing ordered and plain dicts (e.g., flattening
- // a dict but using a corresponding `OrderedDict` to pack it back).
- PyList_Sort(seq_);
- } else if (IsMappingHelper(iterable)) {
- mapping_ = iterable;
- seq_ = MappingKeys(iterable);
- PyList_Sort(seq_);
+ explicit DictValueIterator(PyObject* dict)
+ : dict_(dict), keys_(PyDict_Keys(dict)) {
+ if (PyList_Sort(keys_.get()) == -1) {
+ invalidate();
} else {
- seq_ = PySequence_Fast(iterable, "");
+ iter_.reset(PyObject_GetIter(keys_.get()));
}
- size_ = PySequence_Fast_GET_SIZE(seq_);
}
- ~ValIterator() { Py_DECREF(seq_); }
-
- // Return a borrowed reference to the next element from iterable.
- // Return nullptr when iteration is over.
- PyObject* next() {
- if (TF_PREDICT_FALSE(seq_ == nullptr)) {
- return nullptr;
- }
- PyObject* element = nullptr;
- if (index_ < size_) {
- // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references. For general mappings, ValIterator keeps a reference to the
- // last retrieved element (and decrefs it before producing the next
- // element) to abstract away the borrowed/new difference.
- element = PySequence_Fast_GET_ITEM(seq_, index_);
- ++index_;
- if (dict_ != nullptr) {
- element = PyDict_GetItem(dict_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Dictionary was modified during iteration over it");
- return nullptr;
- }
- } else if (mapping_ != nullptr) {
- element = PyObject_GetItem(mapping_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Mapping was modified during iteration over it");
- return nullptr;
- }
- last_mapping_element_.reset(element);
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // PyDict_GetItem returns a borrowed reference.
+ PyObject* elem = PyDict_GetItem(dict_, key.get());
+ if (elem) {
+ Py_INCREF(elem);
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
}
}
- return element;
+ return result;
}
private:
- // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
- // General mappings which have custom Python logic
+// Iterate over mapping objects by sorting the keys first
+class MappingValueIterator : public ValueIterator {
+ public:
+ explicit MappingValueIterator(PyObject* mapping)
+ : mapping_(mapping), keys_(MappingKeys(mapping)) {
+ if (!keys_ || PyList_Sort(keys_.get()) == -1) {
+ invalidate();
+ } else {
+ iter_.reset(PyObject_GetIter(keys_.get()));
+ }
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* elem = PyObject_GetItem(mapping_, key.get());
+ if (elem) {
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ }
+ }
+ return result;
+ }
+
+ private:
PyObject* mapping_;
- Safe_PyObjectPtr last_mapping_element_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
+
+// Iterate over a sequence, by index.
+class SequenceValueIterator : public ValueIterator {
+ public:
+ explicit SequenceValueIterator(PyObject* iterable)
+ : seq_(PySequence_Fast(iterable, "")),
+ size_(PySequence_Fast_GET_SIZE(seq_.get())),
+ index_(0) {}
- PyObject* seq_;
- Py_ssize_t size_;
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ if (index_ < size_) {
+ // PySequence_Fast_GET_ITEM returns a borrowed reference.
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
+ ++index_;
+ Py_INCREF(elem);
+ result.reset(elem);
+ }
+
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr seq_;
+ const Py_ssize_t size_;
Py_ssize_t index_;
};
+// Just return itself as a single item.
+class SparseTensorValueIterator : public ValueIterator {
+ public:
+ explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
+ Py_INCREF(tensor);
+ }
+
+ Safe_PyObjectPtr next() override { return std::move(tensor_); }
+
+ private:
+ Safe_PyObjectPtr tensor_;
+};
+
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -322,93 +367,33 @@ int IsSequenceForDataHelper(PyObject* o) {
!IsSparseTensorValueType(o);
}
-bool GetNextValuesForDict(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(PyDict_Keys(nested));
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- // We know that key and item will not be deleted because nested owns
- // a reference to them and callers of flatten must not modify nested
- // while the method is running.
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- PyObject* item = PyDict_GetItem(nested, key);
- Py_INCREF(item);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForMapping(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(MappingKeys(nested));
- if (keys.get() == nullptr) {
- return false;
- }
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
- PyObject* item = PyObject_GetItem(nested, key);
- next_values->emplace_back(item);
- }
- return true;
-}
-
-bool GetNextValuesForIterable(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- PyObject* item;
- PyObject* iterator = PyObject_GetIter(nested);
- if (iterator == nullptr || PyErr_Occurred()) {
- return false;
- }
- while ((item = PyIter_Next(iterator)) != nullptr) {
- next_values->emplace_back(item);
- }
- Py_DECREF(iterator);
- return true;
-}
-
-// GetNextValues returns the values that the FlattenHelper function will recurse
-// over next.
-bool GetNextValues(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+ValueIteratorPtr GetValueIterator(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
-// Similar to above, just specialized for the functions in the data pacakage.
-bool GetNextValuesForData(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+// Similar to above, just specialized for the functions in the data package.
+ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) {
- // if nested is a SparseTensorValue, just return itself as a single item
- Py_INCREF(nested);
- next_values->emplace_back(nested);
- return true;
+ return absl::make_unique<SparseTensorValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
bool FlattenHelper(
PyObject* nested, PyObject* list,
const std::function<int(PyObject*)>& is_sequence_helper,
- const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>&
- next_values_getter) {
+ const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
// if nested is not a sequence, append itself and exit
int is_seq = is_sequence_helper(nested);
if (is_seq == -1) return false;
@@ -416,16 +401,15 @@ bool FlattenHelper(
return PyList_Append(list, nested) != -1;
}
- std::vector<Safe_PyObjectPtr> next_values;
- // Get the next values to recurse over.
- if (!next_values_getter(nested, &next_values)) return false;
+ ValueIteratorPtr iter = value_iterator_getter(nested);
+ if (!iter->valid()) return false;
- for (const auto& item : next_values) {
+ for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
if (Py_EnterRecursiveCall(" in flatten")) {
return false;
}
- const bool success =
- FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter);
+ const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
+ value_iterator_getter);
Py_LeaveRecursiveCall();
if (!success) {
return false;
@@ -579,22 +563,25 @@ bool AssertSameStructureHelper(
}
}
- ValIterator iter1(o1);
- ValIterator iter2(o2);
+ ValueIteratorPtr iter1 = GetValueIterator(o1);
+ ValueIteratorPtr iter2 = GetValueIterator(o2);
+
+ if (!iter1->valid() || !iter2->valid()) return false;
while (true) {
- PyObject* v1 = iter1.next();
- PyObject* v2 = iter2.next();
- if (v1 != nullptr && v2 != nullptr) {
+ Safe_PyObjectPtr v1 = iter1->next();
+ Safe_PyObjectPtr v2 = iter2->next();
+ if (v1 && v2) {
if (Py_EnterRecursiveCall(" in assert_same_structure")) {
return false;
}
- bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
+ bool no_internal_errors =
+ AssertSameStructureHelper(v1.get(), v2.get(), check_types, error_msg,
+ is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
- } else if (v1 == nullptr && v2 == nullptr) {
+ } else if (!v1 && !v2) {
// Done with all recursive calls. Structure matched.
return true;
} else {
@@ -655,7 +642,7 @@ bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
- if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) {
+ if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) {
return list;
} else {
Py_DECREF(list);
@@ -668,7 +655,7 @@ bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
PyObject* FlattenForData(PyObject* nested) {
PyObject* list = PyList_New(0);
if (FlattenHelper(nested, list, IsSequenceForDataHelper,
- GetNextValuesForData)) {
+ GetValueIteratorForData)) {
return list;
} else {
Py_DECREF(list);