aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/util/nest.py16
-rw-r--r--tensorflow/python/util/nest_test.py34
-rw-r--r--tensorflow/python/util/util.cc51
-rw-r--r--tensorflow/python/util/util.h9
-rw-r--r--tensorflow/python/util/util.i12
5 files changed, 121 insertions, 1 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 653ca525dc..758cba7487 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
+attr.s decorated classes (http://www.attrs.org) are also supported, in the
+same way as `namedtuple`.
+
The utilities here assume (and do not check) that the nested structures form a
'tree', i.e., no references in the structure of the input of these functions
should be recursive.
@@ -38,6 +41,12 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, "__attrs_attrs__")
+ return [getattr(obj, a.name) for a in attrs]
+
+
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
@@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_tensorflow.IsMapping
+_is_attrs = _pywrap_tensorflow.IsAttrs
def _sequence_like(instance, args):
@@ -85,7 +95,7 @@ def _sequence_like(instance, args):
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
- elif _is_namedtuple(instance):
+ elif _is_namedtuple(instance) or _is_attrs(instance):
return type(instance)(*args)
else:
# Not a namedtuple
@@ -93,6 +103,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
+ """Yields the next value from the given iterable."""
if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
@@ -101,6 +112,9 @@ def _yield_value(iterable):
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
+ elif _is_attrs(iterable):
+ for value in _get_attrs_values(iterable):
+ yield value
else:
for value in iterable:
yield value
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index bfb4c6f910..e03a8daaa1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
class _CustomMapping(collections.Mapping):
@@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+ if attr:
+ class BadAttr(object):
+ """Class that has a non-iterable __attrs_attrs__."""
+ __attrs_attrs__ = None
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrsFlattenAndPack(self):
+ if attr is None:
+ self.skipTest("attr module is unavailable.")
+
+ field_values = [1, 2]
+ sample_attr = NestTest.SampleAttr(*field_values)
+ self.assertFalse(nest._is_attrs(field_values))
+ self.assertTrue(nest._is_attrs(sample_attr))
+ flat = nest.flatten(sample_attr)
+ self.assertEqual(field_values, flat)
+ restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ self.assertEqual(restructured_from_flat, sample_attr)
+
+ # Check that flatten fails if attributes are not iterable
+ with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ flat = nest.flatten(NestTest.BadAttr())
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 2087957b31..38b8491c66 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -192,6 +192,19 @@ int IsMappingHelper(PyObject* o) {
return check_cache->CachedLookup(o);
}
+// Returns 1 if `o` is an instance of attrs-decorated class.
+// Returns 0 otherwise.
+int IsAttrsHelper(PyObject* o) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ } else {
+ // PyObject_GetAttrString returns null on error
+ PyErr_Clear();
+ return 0;
+ }
+}
+
// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
// Returns 0 otherwise.
// Returns -1 if an error occurred.
@@ -206,6 +219,7 @@ int IsSequenceHelper(PyObject* o) {
});
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
+ if (IsAttrsHelper(o)) return true;
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
LOG(WARNING) << "Sets are not currently considered sequences, "
"but this may change in the future, "
@@ -354,6 +368,38 @@ class SparseTensorValueIterator : public ValueIterator {
Safe_PyObjectPtr tensor_;
};
+class AttrsValueIterator : public ValueIterator {
+ public:
+ explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
+ Py_INCREF(nested);
+ cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
+ if (cls_) {
+ attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
+ if (attrs_) {
+ iter_.reset(PyObject_GetIter(attrs_.get()));
+ }
+ }
+ if (!iter_ || PyErr_Occurred()) invalidate();
+ }
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
+ if (item) {
+ Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
+ result.reset(PyObject_GetAttr(nested_.get(), name.get()));
+ }
+
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr nested_;
+ Safe_PyObjectPtr cls_;
+ Safe_PyObjectPtr attrs_;
+ Safe_PyObjectPtr iter_;
+};
+
bool IsSparseTensorValueType(PyObject* o) {
if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
return false;
@@ -372,6 +418,8 @@ ValueIteratorPtr GetValueIterator(PyObject* nested) {
return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
} else {
return absl::make_unique<SequenceValueIterator>(nested);
}
@@ -383,6 +431,8 @@ ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) {
return absl::make_unique<SparseTensorValueIterator>(nested);
} else {
@@ -639,6 +689,7 @@ void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
+bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 343605285e..01f85ea1dc 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -56,6 +56,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);
+// Returns a true if its input is an instance of an attr.s decorated class.
+//
+// Args:
+// o: the input to be checked.
+//
+// Returns:
+// True if the object is an instance of an attr.s decorated class.
+bool IsAttrs(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 104a615636..32a6e684fa 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -65,6 +65,18 @@ Returns:
%unignore tensorflow::swig::IsMapping;
%noexception tensorflow::swig::IsMapping;
+%feature("docstring") tensorflow::swig::IsAttrs
+"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is an instance of an `attr.s` decorated class.
+"""
+%unignore tensorflow::swig::IsAttrs;
+%noexception tensorflow::swig::IsAttrs;
+
%feature("docstring") tensorflow::swig::SameNamedtuples
"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;