From 21d94be838f7a7917f072e5ec0bbe6e3593177b9 Mon Sep 17 00:00:00 2001
From: Allen Lavoie
Date: Thu, 26 Jul 2018 17:53:31 -0700
Subject: Simulate eager variable restoration in tf.keras.Model.load_weights
 when graph building

Previously, the first Model build after load_weights (e.g. a predict()) would
trigger restore ops, and any variables added later (e.g. slot variables from
an added optimizer) would not be restored when graph building. This change
makes behavior consistent between eager execution and graph building by
running new restore ops as they come in.

PiperOrigin-RevId: 206251879
---
 tensorflow/python/keras/engine/base_layer.py      |  9 -----
 tensorflow/python/keras/engine/network.py         | 23 ++-----------
 tensorflow/python/keras/engine/saving_test.py     | 41 ++++++++++++++++-------
 tensorflow/python/keras/engine/training.py        |  3 --
 tensorflow/python/training/checkpointable/base.py |  2 +-
 tensorflow/python/training/checkpointable/util.py | 37 ++++++++++++++++++--
 6 files changed, 68 insertions(+), 47 deletions(-)

diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index b41f6ee03b..7af71f17a9 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -786,17 +786,8 @@ class Layer(checkpointable.CheckpointableBase):
       if hasattr(self, '_initial_weights') and self._initial_weights is not None:
         self.set_weights(self._initial_weights)
         del self._initial_weights
-      self._post_build_cleanup()
     return outputs
 
-  def _post_build_cleanup(self):
-    """Hooks to run after all sub-Layers are built."""
-    # Note that in addition to Layer.__call__, this method is called by Model
-    # after building a graph network (which skips __call__). It should be called
-    # when possible if self.built may have switched from False to True, and is
-    # idempotent.
-    pass  # No-op for Layers which don't override this method.
-
   def apply(self, inputs, *args, **kwargs):
     """Apply the layer on a input.
 
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 752e9963ca..3e7e3e3d21 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import functools
 import json
 import os
 import weakref
@@ -144,10 +143,6 @@ class Network(base_layer.Layer):
 
     self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
         weakref.ref(self))
-    # A zero-argument function which should be called and set back to None as
-    # soon as the network is built (only applicable to subclassed Models). Runs
-    # restore operations when graph building.
-    self._in_progress_restore_finalizer = None
 
   @checkpointable.no_automatic_dependency_tracking
   def _init_graph_network(self, inputs, outputs, name=None):
@@ -1423,13 +1418,9 @@ class Network(base_layer.Layer):
             'load_weights).')
       if not context.executing_eagerly():
         session = backend.get_session()
-        finalizer = functools.partial(status.run_restore_ops, session=session)
-        if self.built:
-          finalizer()
-        else:
-          # Hold on to this status object until the network is built (for
-          # subclassed Models). Then we'll run restore ops if necessary.
-          self._in_progress_restore_finalizer = finalizer
+        # Restore existing variables (if any) immediately, and set up a
+        # streaming restore for any variables created in the future.
+        checkpointable_utils.streaming_restore(status=status, session=session)
       return status
     if h5py is None:
       raise ImportError(
@@ -1447,14 +1438,6 @@ class Network(base_layer.Layer):
         else:
           saving.load_weights_from_hdf5_group(f, self.layers)
 
-  def _post_build_cleanup(self):
-    super(Network, self)._post_build_cleanup()
-    if self._in_progress_restore_finalizer is not None:
-      # Runs queued restore operations left over from load_weights when graph
-      # building.
-      self._in_progress_restore_finalizer()
-      self._in_progress_restore_finalizer = None
-
   def _updated_config(self):
     """Util shared between different serialization methods.
 
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 030328f2a6..e029e614e0 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -722,18 +722,23 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
     self.assertEqual(len(graph.get_operations()), op_count)
 
   def _weight_loading_test_template(self, make_model_fn):
-    with self.test_session() as session:
+    with self.test_session():
       model = make_model_fn()
+      model.compile(
+          loss='mse',
+          optimizer=training_module.RMSPropOptimizer(0.1),
+          metrics=['acc'])
       temp_dir = self.get_temp_dir()
       prefix = os.path.join(temp_dir, 'ckpt')
+      train_x = np.random.random((3, 2))
+      train_y = np.random.random((3,))
+      x = constant_op.constant(train_x, dtype=dtypes.float32)
 
-      x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
-      executing_eagerly = context.executing_eagerly()
-      ref_y_tensor = model(x)
-      if not executing_eagerly:
-        session.run([v.initializer for v in model.variables])
-      ref_y = self.evaluate(ref_y_tensor)
+      model.train_on_batch(train_x, train_y)
       model.save_weights(prefix, save_format='tf')
+      ref_y_before_train = model.predict(train_x)
+      model.train_on_batch(train_x, train_y)
+      ref_y_after_train = model.predict(train_x)
       for v in model.variables:
         self.evaluate(
             v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
@@ -741,16 +746,27 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
       self.addCleanup(shutil.rmtree, temp_dir)
 
       model.load_weights(prefix)
-      y = self.evaluate(model(x))
-      self.assertAllClose(ref_y, y)
+      self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))
 
       # Test restore-on-create if this is a subclassed Model (graph Networks
      # will have already created their variables).
       load_model = make_model_fn()
       load_model.load_weights(prefix)
-      restore_on_create_y_tensor = load_model(x)
-      restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
-      self.assertAllClose(ref_y, restore_on_create_y)
+      self.assertAllClose(
+          ref_y_before_train,
+          self.evaluate(load_model(x)))
+      load_model = make_model_fn()
+      load_model.load_weights(prefix)
+      # We need to run some of the restore ops for predict(), but not all
+      # variables have been created yet (optimizer slot variables). Tests
+      # incremental restore.
+      load_model.predict(train_x)
+      load_model.compile(
+          loss='mse',
+          optimizer=training_module.RMSPropOptimizer(0.1),
+          metrics=['acc'])
+      load_model.train_on_batch(train_x, train_y)
+      self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
 
   @test_util.run_in_graph_and_eager_modes
   def test_weight_loading_graph_model(self):
@@ -858,5 +874,6 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
         SubclassedModel, SubclassedModelRestore,
         _restore_init_fn)
 
+
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 39d207cc6b..315d88d418 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -671,7 +671,6 @@ class Model(Network):
           updates=updates,
           name='train_function',
           **self._function_kwargs)
-      self._post_build_cleanup()
 
   def _make_test_function(self):
     if not hasattr(self, 'test_function'):
@@ -689,7 +688,6 @@ class Model(Network):
           updates=self.state_updates + self.metrics_updates,
           name='test_function',
           **self._function_kwargs)
-      self._post_build_cleanup()
 
   def _make_predict_function(self):
     if not hasattr(self, 'predict_function'):
@@ -708,7 +706,6 @@ class Model(Network):
           updates=self.state_updates,
           name='predict_function',
           **kwargs)
-      self._post_build_cleanup()
 
   def _get_iterator_get_next_tensors(self, iterator):
     get_next_op = self._iterator_get_next.get(iterator, None)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index f0703c8af4..66837ee52f 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -144,7 +144,7 @@ class _CheckpointPosition(object):
       # process deferred restorations for it and its dependencies.
       restore_ops = checkpointable._restore_from_checkpoint_position(self)  # pylint: disable=protected-access
       if restore_ops:
-        self._checkpoint.restore_ops.extend(restore_ops)
+        self._checkpoint.new_restore_ops(restore_ops)
 
   def bind_object(self, checkpointable):
     """Set a checkpoint<->object correspondence and process slot variables.
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 5d26a817d4..664b2348c0 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -101,6 +101,7 @@ class _CheckpointRestoreCoordinator(object):
     # this checkpoint.
     self.restore_ops = []
     self.restore_ops_by_name = {}
+    self.new_restore_ops_callback = None
     # A mapping from optimizer proto ids to lists of slot variables to be
     # restored when the optimizer is tracked. Only includes slot variables whose
     # regular variables have already been created, and only for optimizer
@@ -121,6 +122,11 @@ class _CheckpointRestoreCoordinator(object):
                 slot_variable_id=slot_reference.slot_variable_node_id,
                 slot_name=slot_reference.slot_name))
 
+  def new_restore_ops(self, new_ops):
+    self.restore_ops.extend(new_ops)
+    if self.new_restore_ops_callback:
+      self.new_restore_ops_callback(new_ops)  # pylint: disable=not-callable
+
 
 class _NameBasedRestoreCoordinator(object):
   """Keeps the status of a name-based checkpoint restore."""
 
@@ -821,6 +827,31 @@ class _LoadStatus(object):
     pass
 
 
+def streaming_restore(status, session=None):
+  """When graph building, runs restore ops as soon as they come in.
+
+  Args:
+    status: A _LoadStatus object from an object-based saver's restore().
+      Streaming restore from name-based checkpoints is not currently
+      supported.
+    session: A session to run new restore ops in.
+  """
+  if context.executing_eagerly():
+    # Streaming restore is the default/only behavior when executing eagerly.
+    return
+  if session is None:
+    session = ops.get_default_session()
+  if isinstance(status, NameBasedSaverStatus):
+    raise NotImplementedError(
+        "Streaming restore not supported from name-based checkpoints. File a "
+        "feature request if this limitation bothers you.")
+  status.run_restore_ops(session=session)
+  # pylint: disable=protected-access
+  status._checkpoint.new_restore_ops_callback = (
+      lambda ops: session.run(ops, feed_dict=status._feed_dict))
+  # pylint: enable=protected-access
+
+
 class CheckpointLoadStatus(_LoadStatus):
   """Checks the status of checkpoint loading and manages restore ops.
 
@@ -992,11 +1023,13 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = (
     "one this message is coming from) and use that checkpoint in the future.")
 
 
-@deprecation.deprecated(
-    date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
 class NameBasedSaverStatus(_LoadStatus):
   """Status for loading a name-based training checkpoint."""
 
+  # Ideally this deprecation decorator would be on the class, but that
+  # interferes with isinstance checks.
+  @deprecation.deprecated(
+      date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
   def __init__(self, checkpoint, root_checkpointable):
     self._checkpoint = checkpoint
    self._root_checkpointable = root_checkpointable
-- 
cgit v1.2.3
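
The user-visible effect of the change is easiest to see end to end. Below is a
minimal sketch of the scenario the new test exercises, assuming a TF 1.x
graph-mode program running against a build with this patch; the model, data,
and checkpoint location are illustrative and not part of the patch:

import os
import tempfile

import numpy as np
import tensorflow as tf

x = np.random.random((3, 2)).astype(np.float32)
y = np.random.random((3, 1)).astype(np.float32)

def make_model():
  return tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(2,))])

# Train briefly before saving so the checkpoint also contains the
# optimizer's slot variables.
model = make_model()
model.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(0.1))
model.train_on_batch(x, y)
prefix = os.path.join(tempfile.mkdtemp(), 'ckpt')
model.save_weights(prefix, save_format='tf')

# Fresh model: with this patch, load_weights() runs restore ops right away
# for variables that already exist, and registers a callback so restore ops
# for variables created later run as soon as those variables appear.
fresh = make_model()
fresh.load_weights(prefix)
fresh.predict(x)  # uses the already-restored layer weights
fresh.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(0.1))
fresh.train_on_batch(x, y)  # slot variables are restored as they are created

Previously, only the first Model build after load_weights() triggered restore
ops when graph building, so the RMSProp slot variables created by compile()
and train_on_batch() above would have stayed unrestored.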
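The mechanism underneath is the new new_restore_ops_callback on
_CheckpointRestoreCoordinator: streaming_restore() first runs every restore op
queued so far, then registers a callback that runs each new batch of ops as
_CheckpointPosition records it. Here is a simplified, framework-free sketch of
that pattern; the class and function names are hypothetical stand-ins, not the
TensorFlow API:

class RestoreCoordinator(object):
  """Collects restore ops; optionally streams new ops to a callback."""

  def __init__(self):
    self.restore_ops = []
    self.new_restore_ops_callback = None

  def new_restore_ops(self, new_ops):
    self.restore_ops.extend(new_ops)
    if self.new_restore_ops_callback:
      self.new_restore_ops_callback(new_ops)

def streaming_restore(coordinator, run_ops):
  # Run everything queued so far, then stream future ops as they arrive.
  run_ops(coordinator.restore_ops)
  coordinator.new_restore_ops_callback = run_ops

coord = RestoreCoordinator()
coord.new_restore_ops(['restore/dense/kernel', 'restore/dense/bias'])  # queued
streaming_restore(coord, lambda ops: print('running', ops))  # runs the queue
coord.new_restore_ops(['restore/rms_prop/slot'])  # runs immediately via callback

This is why variables created after load_weights(), such as optimizer slot
variables, no longer remain uninitialized when graph building.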