aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/util
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-06-29 14:02:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-29 14:04:49 -0700
commitdcaa037571ab0933977f70574f4f78875155ae20 (patch)
tree4968e1966ca334f42296beae6cb1ecd8d483215e /tensorflow/python/util
parentb3c163a754574faed4337f869c2d650a9f45c09c (diff)
Auto tracking for Python lists assigned to attributes of Model/Checkpointable
Conceptually lists just get replaced with a list-like wrapper. A shallow copy is maintained for error checking (since appends to it aren't monitored, we can't do restore-on-create for variables unless it's being modified through the wrapper). There are lots of other details. I gave up on generalizing our isinstance(obj, (list, tuple)) checks and just subclassed list. Behaving like a list means the type should be unhashable, which requires some workarounds when we're collecting objects (object-identity collections, and object-identity versions of weak reference containers). Adds a decorator for exempting whole methods from automatic dependency tracking so we don't need to track down every last self.inputs = [] statement to avoid polluting dependencies. There's a TODO for tuples and dictionaries. PiperOrigin-RevId: 202703271
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r--tensorflow/python/util/nest.py11
-rw-r--r--tensorflow/python/util/util.cc6
2 files changed, 12 insertions, 5 deletions
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 1104768ae8..d63f59a8c8 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -167,11 +167,14 @@ def assert_same_structure(nest1, nest2, check_types=True):
Args:
nest1: an arbitrarily nested structure.
nest2: an arbitrarily nested structure.
- check_types: if `True` (default) types of sequences are checked as
- well, including the keys of dictionaries. If set to `False`, for example
- a list and a tuple of objects will look the same if they have the same
+ check_types: if `True` (default) types of sequences are checked as well,
+ including the keys of dictionaries. If set to `False`, for example a
+ list and a tuple of objects will look the same if they have the same
size. Note that namedtuples with identical name and fields are always
- considered to have the same shallow structure.
+ considered to have the same shallow structure. Two types will also be
+ considered the same if they are both list subtypes (which allows "list"
+ and "_ListWrapper" from checkpointable dependency tracking to compare
+ equal).
Raises:
ValueError: If the two structures do not have the same number of elements or
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index c79d8a8445..366f8a0deb 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -394,7 +394,11 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
type2->tp_name);
return true;
}
- } else if (type1 != type2) {
+ } else if (type1 != type2
+ /* If both sequences are list types, don't complain. This allows
+ one to be a list subclass (e.g. _ListWrapper used for automatic
+ dependency tracking.) */
+ && !(PyList_Check(o1) && PyList_Check(o2))) {
*is_type_error = true;
*error_msg = tensorflow::strings::StrCat(
"The two namedtuples don't have the same sequence type. "