diff options
author | 2018-02-08 13:38:50 -0800 | |
---|---|---|
committer | 2018-02-08 13:45:42 -0800 | |
commit | 597377fca28c76306e749f78f8073f55726d54c9 (patch) | |
tree | ae29dcb3e422a127a403647fc0731d03c338b760 | |
parent | 166d6e869e4c559829c8ad4d7cc19a792c2bf444 (diff) |
Prototype object-based save/restore syntax sugar
- Overrides __setattr__ to allow implicit dependencies
- Supports any valid Python 2 identifier as a Checkpointable dependency name (was messing with underscore prefixes)
PiperOrigin-RevId: 185044022
-rw-r--r-- | tensorflow/contrib/eager/python/checkpointable.py | 50 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/checkpointable_test.py | 94 |
2 files changed, 98 insertions, 46 deletions
diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py index 47ce5897c0..ce4e07874e 100644 --- a/tensorflow/contrib/eager/python/checkpointable.py +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -46,9 +46,10 @@ _CheckpointableReference = collections.namedtuple( ]) # Validation regular expression for the local names of Checkpointable -# objects. In particular, disallows "/" in names, and reserves -# underscore-prefixed names. -_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") +# objects. In particular, disallows "/" in names, and reserves dash-prefixed +# names (which are not valid Python identifiers, so we're not restricting the +# __setattr__ syntax that way). +_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$") # Keyword for identifying that the next bit of a checkpoint variable name is a # slot name. May not be the local name of a checkpointable. Checkpoint names for @@ -58,7 +59,7 @@ _VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9.][A-Za-z0-9_.-]*$") # # Where <path to variable> is a full path from the checkpoint root to the # variable being slotted for. -_OPTIMIZER_SLOTS_NAME = "_OPTIMIZER_SLOT" +_OPTIMIZER_SLOTS_NAME = "-OPTIMIZER_SLOT" def _assign_existing_variable(variable_to_restore, value_pointer): @@ -89,12 +90,12 @@ class Checkpointable(object): """ def __init__(self): - # Basically a less useful OrderedDict but without the reference cycles. - # TODO(allenl): Switch this to OrderedDict once TensorFlow supports only - # Python 3.6+. # A list of _CheckpointableReference objects. self._checkpoint_dependencies = [] - self._dependency_names = set() + # Maps names -> Checkpointable objects for named dependencies + self._dependency_names = {} + # Set of all tracked Checkpointables + self._already_tracked = set() # Start numbering at 1, since an un-set protocol buffer integer is # indistinguishable from 0. 
self._next_unnamed_checkpoint_dependency_uid = 1 @@ -102,6 +103,27 @@ class Checkpointable(object): self._deferred_restorations = {} # local name -> _VariableRestoration # object + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax. + + `self.foo = checkpointable` is equivalent to + `self.foo = self.track_checkpointable(checkpointable, name='foo')`. + + No new tracking if `value` is not a `Checkpointable`, or if `value` is + already being tracked (either because of an explicit `track_checkpointable` + or a previous `__setattr__`). + + Args: + name: The name of the property being set. + value: The new value for the property. + """ + # Give child classes (e.g. Network) priority, then track only if the object + # hasn't been added to _already_tracked. + super(Checkpointable, self).__setattr__(name, value) + if (isinstance(value, Checkpointable) + and value not in self._already_tracked): + self.track_checkpointable(value, name=name) + def add_variable(self, name, shape, dtype=None, initializer=None, **kwargs): """Create a new variable object to be saved with this `Checkpointable`. @@ -217,12 +239,13 @@ class Checkpointable(object): ("Checkpointable names must match the regular expression '%s', but " "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, name)) - if name in self._dependency_names: + if (name in self._dependency_names + and self._dependency_names[name] is not checkpointable): raise ValueError( ("Called Checkpointable.track_checkpointable() with name='%s', but " "a Checkpointable with this name is already declared as a " "dependency. 
If provided, names must be unique.") % (name,)) - self._dependency_names.add(name) + self._dependency_names[name] = checkpointable local_uid = None else: # TODO(allenl): Should this be exposed to allow users to stop depending on @@ -232,6 +255,7 @@ class Checkpointable(object): self._checkpoint_dependencies.append( _CheckpointableReference( name=name, ref=checkpointable, local_uid=local_uid)) + self._already_tracked.add(checkpointable) return checkpointable def _process_restoration(self, restoration): @@ -305,8 +329,10 @@ def _breadth_first_checkpointable_traversal(root_checkpointable): def _object_prefix_from_path(path_to_root): - return "/".join((checkpointable.name if checkpointable.name else "_%d" % ( - checkpointable.local_uid,)) for checkpointable in path_to_root) + return "/".join( + (checkpointable.name if checkpointable.name + else "-unnamed_%d" % (checkpointable.local_uid,)) + for checkpointable in path_to_root) def _escape_variable_name(variable_name): diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py index d823053283..4b92ad59e7 100644 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -29,6 +29,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.layers import base from tensorflow.python.layers import core from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops @@ -64,6 +65,13 @@ class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): network_lib.Network.__init__(self) checkpointable.Checkpointable.__init__(self) + def __setattr__(self, name, value): + if isinstance(value, base.Layer) and value not in self._already_tracked: + self.track_layer(value, name=name) + # 
+ Checkpointable is next in the method resolution order, so this will catch + # Checkpointable objects which aren't Layers. + super(CheckpointableNetwork, self).__setattr__(name, value) + def track_layer(self, layer, name=None): self.track_checkpointable(layer, name=name) return super(CheckpointableNetwork, self).track_layer(layer) @@ -107,18 +115,34 @@ class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): return v +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + with variable_scope.variable_scope(None, default_name="non_layer"): + # Unfortunately using tf.get_variable to implement self.add_variable + # (necessary for backwards compatible naming with Layers) we can still + # run into duplicate variable errors (when building a graph, not when + # executing eagerly), thus the variable scope. + # + # TODO(allenl): Consider creating a ResourceVariable directly by + # default so that variable reuse isn't an issue. + self._a_variable = self.add_variable("a_variable", shape=[]) + + class MyNetwork(CheckpointableNetwork): """A concrete Network for testing.""" def __init__(self): super(MyNetwork, self).__init__() - self._named = self.track_layer( - CheckpointableDenseLayer(1, use_bias=True), name="named_dense") + self._named_dense = CheckpointableDenseLayer(1, use_bias=True) self._unnamed = self.track_layer( CheckpointableDenseLayer(1, use_bias=False)) + # We can still track Checkpointables which aren't Layers. 
+ self._non_layer = NonLayerCheckpointable() def call(self, values): - return self._unnamed(self._named(values)) + return self._unnamed(self._named_dense(values)) class Root(checkpointable.Checkpointable): @@ -126,8 +150,8 @@ class Root(checkpointable.Checkpointable): def __init__(self, optimizer, network): super(Root, self).__init__() - self.track_checkpointable(optimizer, name="optimizer") - self.track_checkpointable(network, name="network") + self._optimizer = optimizer + self._network = self.track_checkpointable(network, "network") self._global_step = None @property @@ -177,36 +201,38 @@ class CheckpointNamingTests(test.TestCase): "global_step", # No name provided to track_checkpointable(), so the position is used # instead (one-based). - "network/_1/kernel", + "network/-unnamed_1/kernel", # track_checkpointable() with a name provided, so that's used - "network/named_dense/kernel", - "network/named_dense/bias", + "network/_named_dense/kernel", + "network/_named_dense/bias", + # non-Layer dependency of the network + "network/_non_layer/a_variable", # The optimizer creates two non-slot variables - "optimizer/beta1_power", - "optimizer/beta2_power", + "_optimizer/beta1_power", + "_optimizer/beta2_power", # Slot variables - "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/m", - "network/_1/kernel/_OPTIMIZER_SLOT/optimizer/v", - "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/m", - "network/named_dense/kernel/_OPTIMIZER_SLOT/optimizer/v", - "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", - "network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/v", + "network/-unnamed_1/kernel/-OPTIMIZER_SLOT/_optimizer/m", + "network/-unnamed_1/kernel/-OPTIMIZER_SLOT/_optimizer/v", + "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m", + "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v", + "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", + "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v", ) six.assertCountEqual(self, expected_checkpoint_names, 
named_variables.keys()) # Check that we've mapped to the right variable objects (not exhaustive) self.assertEqual("global_step:0", named_variables["global_step"].name) self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/_1/kernel"].name) + named_variables["network/-unnamed_1/kernel"].name) self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/named_dense/kernel"].name) + named_variables["network/_named_dense/kernel"].name) self.assertEqual("beta1_power:0", - named_variables["optimizer/beta1_power"].name) + named_variables["_optimizer/beta1_power"].name) self.assertEqual("beta2_power:0", - named_variables["optimizer/beta2_power"].name) + named_variables["_optimizer/beta2_power"].name) # Spot check the generated protocol buffers. self.assertEqual(0, serialized_graph.nodes[0].children[0].local_uid) - self.assertEqual("optimizer", + self.assertEqual("_optimizer", serialized_graph.nodes[0].children[0].local_name) optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ 0].node_id] @@ -217,18 +243,18 @@ class CheckpointNamingTests(test.TestCase): "bias", optimizer_node.slot_variables[0].original_variable_local_name) original_variable_owner = serialized_graph.nodes[ optimizer_node.slot_variables[0].original_variable_node_id] - self.assertEqual("network/named_dense/bias", + self.assertEqual("network/_named_dense/bias", original_variable_owner.variables[0].checkpoint_key) self.assertEqual("bias", original_variable_owner.variables[0].local_name) self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) - self.assertEqual("network/named_dense/bias/_OPTIMIZER_SLOT/optimizer/m", + self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", optimizer_node.slot_variables[0].checkpoint_key) # We strip off the :0 suffix, as variable.name-based saving does. 
self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam", optimizer_node.slot_variables[0].full_name) self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0", optimizer.get_slot( - var=named_variables["network/named_dense/bias"], + var=named_variables["network/_named_dense/bias"], name="m").name) @test_util.run_in_graph_and_eager_modes() @@ -247,14 +273,14 @@ class CheckpointNamingTests(test.TestCase): self.evaluate(variables.global_variables_initializer()) self.evaluate(train_op) prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(state_ops.assign(network._named.variables[1], [42.])) - m_bias_slot = optimizer.get_slot(network._named.variables[1], "m") + self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") self.evaluate(state_ops.assign(m_bias_slot, [1.5])) serialized_graph, save_path = checkpointable.save( file_prefix=prefix, root_checkpointable=root_checkpointable, global_step=root_checkpointable.global_step) - self.evaluate(state_ops.assign(network._named.variables[1], [43.])) + self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) self.evaluate(state_ops.assign(root_checkpointable.global_step, 3)) optimizer_variables = self.evaluate(optimizer.variables()) self.evaluate(state_ops.assign(m_bias_slot, [-2.])) @@ -263,7 +289,7 @@ class CheckpointNamingTests(test.TestCase): save_path=save_path, root_checkpointable=root_checkpointable, object_graph_proto=serialized_graph) - self.assertAllEqual([42.], self.evaluate(network._named.variables[1])) + self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step)) self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) with ops.Graph().as_default(): @@ -281,9 +307,9 @@ class CheckpointNamingTests(test.TestCase): self.assertAllEqual(1, self.evaluate(on_create_root.global_step)) 
self.assertAllEqual([42.], self.evaluate( - on_create_network._named.variables[1])) + on_create_network._named_dense.variables[1])) on_create_m_bias_slot = on_create_optimizer.get_slot( - on_create_network._named.variables[1], "m") + on_create_network._named_dense.variables[1], "m") # Optimizer slot variables are created when the original variable is # restored. self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) @@ -393,16 +419,16 @@ class CheckpointNamingTests(test.TestCase): leaf.add_variable(name="v", shape=[]) named_variables, _ = checkpointable._serialize_object_graph(root) variable_name, = named_variables.keys() - self.assertEqual(r"_1/v", variable_name) + self.assertEqual(r"-unnamed_1/v", variable_name) @test_util.run_in_graph_and_eager_modes() def testLocalNameValidation(self): root = checkpointable.Checkpointable() leaf = checkpointable.Checkpointable() with self.assertRaisesRegexp(ValueError, "invalid name"): - # Leading underscores are reserved, which avoids conflicts with - # un-named edges in paths and the optimizer slots identifier. - root.track_checkpointable(leaf, name="_12") + # Leading dashes are reserved, which avoids conflicts with un-named edges + # in paths and the optimizer slots identifier. + root.track_checkpointable(leaf, name="-unnamed-12") if __name__ == "__main__": |