author    Allen Lavoie <allenl@google.com>  2018-02-08 13:38:50 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-02-08 13:45:42 -0800
commit    597377fca28c76306e749f78f8073f55726d54c9 (patch)
tree      ae29dcb3e422a127a403647fc0731d03c338b760
parent    166d6e869e4c559829c8ad4d7cc19a792c2bf444 (diff)
Prototype object-based save/restore syntax sugar
- Overrides __setattr__ to allow implicit dependencies
- Supports any valid Python 2 identifier as a Checkpointable dependency name (previously, underscore-prefixed names were reserved)

PiperOrigin-RevId: 185044022
 tensorflow/contrib/eager/python/checkpointable.py      | 50
 tensorflow/contrib/eager/python/checkpointable_test.py | 94
 2 files changed, 98 insertions(+), 46 deletions(-)
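
As a quick orientation before the diff: the sketch below (standalone, with illustrative names like Trackable and track; not the TensorFlow implementation) shows the shape of the implicit-tracking mechanism this change introduces, where assigning a checkpointable object to an attribute registers it as a named dependency via __setattr__.

    class Trackable(object):
      """Toy stand-in for Checkpointable, illustrating the __setattr__ sugar."""

      def __init__(self):
        self._dependency_names = {}    # dependency name -> tracked object
        self._already_tracked = set()  # objects already tracked by this one

      def __setattr__(self, name, value):
        # Set the attribute first, then track new Trackable values under the
        # attribute name.
        super(Trackable, self).__setattr__(name, value)
        if isinstance(value, Trackable) and value not in self._already_tracked:
          self.track(value, name=name)

      def track(self, obj, name):
        # Mirrors track_checkpointable(): re-assigning the same object to the
        # same name is fine, but two different objects may not share a name.
        if name in self._dependency_names and self._dependency_names[name] is not obj:
          raise ValueError("Duplicate dependency name: %r" % name)
        self._dependency_names[name] = obj
        self._already_tracked.add(obj)
        return obj

    class Root(Trackable):

      def __init__(self):
        super(Root, self).__init__()
        # Equivalent to self.track(Trackable(), name="child"); the assignment
        # alone registers the dependency.
        self.child = Trackable()

    root = Root()
    assert "child" in root._dependency_names

The real change additionally gives subclasses such as Network first crack at the value, as the diff below shows.
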
diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py
index 47ce5897c0..ce4e07874e 100644
--- a/tensorflow/contrib/eager/python/checkpointable.py
+++ b/tensorflow/contrib/eager/python/checkpointable.py
@@ -46,9 +46,10 @@ _CheckpointableReference = collections.namedtuple(
])
# Validation regular expression for the local names of Checkpointable
-# objects. In particular, disallows "/" in names, and reserves
-# underscore-prefixed names.
-_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$")
+# objects. In particular, disallows "/" in names, and reserves dash-prefixed
+# names (which are not valid Python identifiers, so we're not restricting the
+# __setattr__ syntax that way).
+_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$")
# Keyword for identifying that the next bit of a checkpoint variable name is a
# slot name. May not be the local name of a checkpointable. Checkpoint names for
@@ -58,7 +59,7 @@ _VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$")
#
# Where <path to variable> is a full path from the checkpoint root to the
# variable being slotted for.
-_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT"
+_OPTIMIZER_SLOTS_NAME = "-OPTIMIZER_SLOT"
def _assign_existing_variable(variable_to_restore, value_pointer):
@@ -89,12 +90,12 @@ class Checkpointable(object):
"""
def __init__(self):
- # Basically a less useful OrderedDict but without the reference cycles.
- # TODO(allenl): Switch this to OrderedDict once TensorFlow supports only
- # Python 3.6+.
# A list of _CheckpointableReference objects.
self._checkpoint_dependencies = []
- self._dependency_names = set()
+ # Maps names -> Checkpointable objects for named dependencies
+ self._dependency_names = {}
+ # Set of all tracked Checkpointables
+ self._already_tracked = set()
# Start numbering at 1, since an un-set protocol buffer integer is
# indistinguishable from 0.
self._next_unnamed_checkpoint_dependency_uid = 1
@@ -102,6 +103,27 @@ class Checkpointable(object):
self._deferred_restorations = {} # local name -> _VariableRestoration
# object
+ def __setattr__(self, name, value):
+ """Support self.foo = checkpointable syntax.
+
+ `self.foo = checkpointable` is equivalent to
+ `self.foo = self.track_checkpointable(checkpointable, name='foo')`.
+
+ No new tracking if `value` is not a `Checkpointable`, or if `value` is
+ already being tracked (either because of an explicit `track_checkpointable`
+ or a previous `__setattr__`).
+
+ Args:
+ name: The name of the property being set.
+ value: The new value for the property.
+ """
+ # Give child classes (e.g. Network) priority, then track only if the object
+ # hasn't been added to _already_tracked.
+ super(Checkpointable, self).__setattr__(name, value)
+ if (isinstance(value, Checkpointable)
+ and value not in self._already_tracked):
+ self.track_checkpointable(value, name=name)
+
def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs):
"""Create a new variable object to be saved with this `Checkpointable`.
@@ -217,12 +239,13 @@ class Checkpointable(object):
("Checkpointable names must match the regular expression '%s', but "
"got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern,
name))
- if name in self._dependency_names:
+ if (name in self._dependency_names
+ and self._dependency_names[name] is not checkpointable):
raise ValueError(
("Called Checkpointable.track_checkpointable() with name='%s', but "
"a Checkpointable with this name is already declared as a "
"dependency. If provided, names must be unique.") % (name,))
- self._dependency_names.add(name)
+ self._dependency_names[name] = checkpointable
local_uid = None
else:
# TODO(allenl): Should this be exposed to allow users to stop depending on
@@ -232,6 +255,7 @@ class Checkpointable(object):
self._checkpoint_dependencies.append(
_CheckpointableReference(
name=name, ref=checkpointable, local_uid=local_uid))
+ self._already_tracked.add(checkpointable)
return checkpointable
def _process_restoration(self, restoration):
@@ -305,8 +329,10 @@ def _breadth_first_checkpointable_traversal(root_checkpointable):
def _object_prefix_from_path(path_to_root):
- return "/".join((checkpointable.name if checkpointable.name else "_%d" % (
- checkpointable.local_uid,)) for checkpointable in path_to_root)
+ return "/".join(
+ (checkpointable.name if checkpointable.name
+ else "-unnamed_%d" % (checkpointable.local_uid,))
+ for checkpointable in path_to_root)
def _escape_variable_name(variable_name):
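
To make the naming rule above concrete, here is a small standalone check (it assumes nothing beyond the regex copied from the hunk above). Underscore-prefixed attribute names such as _named_dense are now valid dependency names, while dash-prefixed segments like -unnamed_1 and -OPTIMIZER_SLOT remain reserved for generated path components.

    import re

    # Same pattern as the updated _VALID_LOCAL_NAME above.
    _VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$")

    for name in ("_named_dense", "network", "-unnamed_1", "-OPTIMIZER_SLOT"):
      print(name, bool(_VALID_LOCAL_NAME.match(name)))
    # _named_dense True        (underscore prefixes are now allowed)
    # network True
    # -unnamed_1 False         (dash prefix reserved for unnamed dependencies)
    # -OPTIMIZER_SLOT False    (dash prefix reserved for the slot marker)
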
diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py
index d823053283..4b92ad59e7 100644
--- a/tensorflow/contrib/eager/python/checkpointable_test.py
+++ b/tensorflow/contrib/eager/python/checkpointable_test.py
@@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.layers import base
from tensorflow.python.layers import core
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
@@ -64,6 +65,13 @@ class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable):
network_lib.Network.__init__(self)
checkpointable.Checkpointable.__init__(self)
+ def __setattr__(self, name, value):
+ if isinstance(value, base.Layer) and value not in self._already_tracked:
+ self.track_layer(value, name=name)
+ # Checkpointable is next in the method resolution order, so this will catch
+ # Checkpointable objects which aren't Layers.
+ super(CheckpointableNetwork, self).__setattr__(name, value)
+
def track_layer(self, layer, name=None):
self.track_checkpointable(layer, name=name)
return super(CheckpointableNetwork, self).track_layer(layer)
@@ -107,18 +115,34 @@ class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable):
return v
+class NonLayerCheckpointable(checkpointable.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ with variable_scope.variable_scope(None, default_name="non_layer"):
+ # Unfortunately, because tf.get_variable is used to implement
+ # self.add_variable (necessary for backwards-compatible naming with
+ # Layers), we can still run into duplicate variable errors (when building
+ # a graph, not when executing eagerly), hence the variable scope.
+ #
+ # TODO(allenl): Consider creating a ResourceVariable directly by
+ # default so that variable reuse isn't an issue.
+ self._a_variable = self.add_variable("a_variable", shape=[])
+
+
class MyNetwork(CheckpointableNetwork):
"""A concrete Network for testing."""
def __init__(self):
super(MyNetwork, self).__init__()
- self._named = self.track_layer(
- CheckpointableDenseLayer(1, use_bias=True), name="named_dense")
+ self._named_dense = CheckpointableDenseLayer(1, use_bias=True)
self._unnamed = self.track_layer(
CheckpointableDenseLayer(1, use_bias=False))
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
def call(self, values):
- return self._unnamed(self._named(values))
+ return self._unnamed(self._named_dense(values))
class Root(checkpointable.Checkpointable):
@@ -126,8 +150,8 @@ class Root(checkpointable.Checkpointable):
def __init__(self, optimizer, network):
super(Root, self).__init__()
- self.track_checkpointable(optimizer, name="optimizer")
- self.track_checkpointable(network, name="network")
+ self._optimizer = optimizer
+ self._network = self.track_checkpointable(network, "network")
self._global_step = None
@property
@@ -177,36 +201,38 @@ class CheckpointNamingTests(test.TestCase):
"global_step",
# No name provided to track_checkpointable(), so the position is used
# instead (one-based).
- "network/_1/kernel",
+ "network/-unnamed_1/kernel",
# track_checkpointable() with a name provided, so that's used
- "network/named_dense/kernel",
- "network/named_dense/bias",
+ "network/_named_dense/kernel",
+ "network/_named_dense/bias",
+ # non-Layer dependency of the network
+ "network/_non_layer/a_variable",
# The optimizer creates two non-slot variables
- "optimizer/beta1_power",
- "optimizer/beta2_power",
+ "_optimizer/beta1_power",
+ "_optimizer/beta2_power",
# Slot variables
- "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m",
- "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v",
- "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m",
- "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v",
- "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m",
- "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v",
+ "network/-unnamed_1/kernel/-OPTIMIZER_SLOT/_optimizer/m",
+ "network/-unnamed_1/kernel/-OPTIMIZER_SLOT/_optimizer/v",
+ "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m",
+ "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v",
+ "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
+ "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v",
)
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual("global_step:0", named_variables["global_step"].name)
self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0",
- named_variables["network/_1/kernel"].name)
+ named_variables["network/-unnamed_1/kernel"].name)
self.assertEqual("my_network/checkpointable_dense_layer/kernel:0",
- named_variables["network/named_dense/kernel"].name)
+ named_variables["network/_named_dense/kernel"].name)
self.assertEqual("beta1_power:0",
- named_variables["optimizer/beta1_power"].name)
+ named_variables["_optimizer/beta1_power"].name)
self.assertEqual("beta2_power:0",
- named_variables["optimizer/beta2_power"].name)
+ named_variables["_optimizer/beta2_power"].name)
# Spot check the generated protocol buffers.
self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid)
- self.assertEqual("optimizer",
+ self.assertEqual("_optimizer",
serialized_graph.nodes[0].children[0].local_name)
optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
0].node_id]
@@ -217,18 +243,18 @@ class CheckpointNamingTests(test.TestCase):
"bias", optimizer_node.slot_variables[0].original_variable_local_name)
original_variable_owner = serialized_graph.nodes[
optimizer_node.slot_variables[0].original_variable_node_id]
- self.assertEqual("network/named_dense/bias",
+ self.assertEqual("network/_named_dense/bias",
original_variable_owner.variables[0].checkpoint_key)
self.assertEqual("bias", original_variable_owner.variables[0].local_name)
self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
- self.assertEqual("network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m",
+ self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m",
optimizer_node.slot_variables[0].checkpoint_key)
# We strip off the :0 suffix, as variable.name-based saving does.
self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam",
optimizer_node.slot_variables[0].full_name)
self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0",
optimizer.get_slot(
- var=named_variables["network/named_dense/bias"],
+ var=named_variables["network/_named_dense/bias"],
name="m").name)
@test_util.run_in_graph_and_eager_modes()
@@ -247,14 +273,14 @@ class CheckpointNamingTests(test.TestCase):
self.evaluate(variables.global_variables_initializer())
self.evaluate(train_op)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
- self.evaluate(state_ops.assign(network._named.variables[1], [42.]))
- m_bias_slot = optimizer.get_slot(network._named.variables[1], "m")
+ self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.]))
+ m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m")
self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
serialized_graph, save_path = checkpointable.save(
file_prefix=prefix,
root_checkpointable=root_checkpointable,
global_step=root_checkpointable.global_step)
- self.evaluate(state_ops.assign(network._named.variables[1], [43.]))
+ self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.]))
self.evaluate(state_ops.assign(root_checkpointable.global_step, 3))
optimizer_variables = self.evaluate(optimizer.variables())
self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
@@ -263,7 +289,7 @@ class CheckpointNamingTests(test.TestCase):
save_path=save_path,
root_checkpointable=root_checkpointable,
object_graph_proto=serialized_graph)
- self.assertAllEqual([42.], self.evaluate(network._named.variables[1]))
+ self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1]))
self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step))
self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
with ops.Graph().as_default():
@@ -281,9 +307,9 @@ class CheckpointNamingTests(test.TestCase):
self.assertAllEqual(1, self.evaluate(on_create_root.global_step))
self.assertAllEqual([42.],
self.evaluate(
- on_create_network._named.variables[1]))
+ on_create_network._named_dense.variables[1]))
on_create_m_bias_slot = on_create_optimizer.get_slot(
- on_create_network._named.variables[1], "m")
+ on_create_network._named_dense.variables[1], "m")
# Optimizer slot variables are created when the original variable is
# restored.
self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
@@ -393,16 +419,16 @@ class CheckpointNamingTests(test.TestCase):
leaf.add_variable(name="v", shape=[])
named_variables, _ = checkpointable._serialize_object_graph(root)
variable_name, = named_variables.keys()
- self.assertEqual(r"_1/v", variable_name)
+ self.assertEqual(r"-unnamed_1/v", variable_name)
@test_util.run_in_graph_and_eager_modes()
def testLocalNameValidation(self):
root = checkpointable.Checkpointable()
leaf = checkpointable.Checkpointable()
with self.assertRaisesRegexp(ValueError, "invalid name"):
- # Leading underscores are reserved, which avoids conflicts with
- # un-named edges in paths and the optimizer slots identifier.
- root.track_checkpointable(leaf, name="_12")
+ # Leading dashes are reserved, which avoids conflicts with un-named edges
+ # in paths and the optimizer slots identifier.
+ root.track_checkpointable(leaf, name="-unnamed-12")
if __name__ == "__main__":
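
Finally, a note on the checkpoint keys asserted in the test above: the hypothetical helper below (not part of this change) spells out how a slot-variable key is assembled from the variable's object path, the reserved -OPTIMIZER_SLOT marker, the optimizer's path, and the slot name.

    def slot_variable_key(variable_path, optimizer_path, slot_name):
      # <path to variable>/-OPTIMIZER_SLOT/<path to optimizer>/<slot name>
      return "/".join([variable_path, "-OPTIMIZER_SLOT", optimizer_path, slot_name])

    print(slot_variable_key("network/_named_dense/bias", "_optimizer", "m"))
    # network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m
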