author    Allen Lavoie <allenl@google.com>    2018-07-26 17:53:31 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-07-26 17:57:30 -0700
commit    21d94be838f7a7917f072e5ec0bbe6e3593177b9 (patch)
tree      cdffec9ca48e7a619903acdc71811e2952fe23b6
parent    c7fc5b013ba9379e2211acc0b08bdd6774dd468d (diff)
Simulate eager variable restoration in tf.keras.Model.load_weights when graph building

Previously, when graph building, the first Model build after load_weights (e.g. a predict()) would trigger restore ops only for the variables that existed at that point; any variables added later (e.g. slot variables from a subsequently added optimizer) were never restored. This change makes the behavior consistent between eager execution and graph building by running new restore ops as they come in.

PiperOrigin-RevId: 206251879
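To make the fixed behavior concrete, the following is a minimal sketch of the graph-building flow this change supports; the TinyModel class, the data shapes, and the '/tmp/ckpt' path are illustrative, not part of the commit:

    import numpy as np
    import tensorflow as tf

    class TinyModel(tf.keras.Model):  # hypothetical subclassed Model
      def __init__(self):
        super(TinyModel, self).__init__()
        self.dense = tf.keras.layers.Dense(1)

      def call(self, inputs):
        return self.dense(inputs)

    x = np.random.random((3, 2)).astype(np.float32)
    y = np.random.random((3, 1)).astype(np.float32)

    model = TinyModel()
    model.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(0.1))
    model.train_on_batch(x, y)
    model.save_weights('/tmp/ckpt', save_format='tf')

    fresh = TinyModel()
    fresh.load_weights('/tmp/ckpt')  # No variables exist yet when graph building.
    fresh.predict(x)                 # Layer variables are created and restored.
    fresh.compile(loss='mse', optimizer=tf.train.RMSPropOptimizer(0.1))
    fresh.train_on_batch(x, y)       # Optimizer slot variables are created here;
                                     # with this change their restore ops run too.

Before this change, only the restore ops available at the first build (the predict()) ran; the slot variables created by the later compile() kept their initial values when graph building.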
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py        9
-rw-r--r--  tensorflow/python/keras/engine/network.py          23
-rw-r--r--  tensorflow/python/keras/engine/saving_test.py      41
-rw-r--r--  tensorflow/python/keras/engine/training.py          3
-rw-r--r--  tensorflow/python/training/checkpointable/base.py   2
-rw-r--r--  tensorflow/python/training/checkpointable/util.py  37
6 files changed, 68 insertions(+), 47 deletions(-)
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index b41f6ee03b..7af71f17a9 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -786,17 +786,8 @@ class Layer(checkpointable.CheckpointableBase):
if hasattr(self, '_initial_weights') and self._initial_weights is not None:
self.set_weights(self._initial_weights)
del self._initial_weights
- self._post_build_cleanup()
return outputs
- def _post_build_cleanup(self):
- """Hooks to run after all sub-Layers are built."""
- # Note that in addition to Layer.__call__, this method is called by Model
- # after building a graph network (which skips __call__). It should be called
- # when possible if self.built may have switched from False to True, and is
- # idempotent.
- pass # No-op for Layers which don't override this method.
-
def apply(self, inputs, *args, **kwargs):
"""Apply the layer on a input.
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 752e9963ca..3e7e3e3d21 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import copy
-import functools
import json
import os
import weakref
@@ -144,10 +143,6 @@ class Network(base_layer.Layer):
self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
weakref.ref(self))
- # A zero-argument function which should be called and set back to None as
- # soon as the network is built (only applicable to subclassed Models). Runs
- # restore operations when graph building.
- self._in_progress_restore_finalizer = None
@checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
@@ -1423,13 +1418,9 @@ class Network(base_layer.Layer):
'load_weights).')
if not context.executing_eagerly():
session = backend.get_session()
- finalizer = functools.partial(status.run_restore_ops, session=session)
- if self.built:
- finalizer()
- else:
- # Hold on to this status object until the network is built (for
- # subclassed Models). Then we'll run restore ops if necessary.
- self._in_progress_restore_finalizer = finalizer
+ # Restore existing variables (if any) immediately, and set up a
+ # streaming restore for any variables created in the future.
+ checkpointable_utils.streaming_restore(status=status, session=session)
return status
if h5py is None:
raise ImportError(
@@ -1447,14 +1438,6 @@ class Network(base_layer.Layer):
else:
saving.load_weights_from_hdf5_group(f, self.layers)
- def _post_build_cleanup(self):
- super(Network, self)._post_build_cleanup()
- if self._in_progress_restore_finalizer is not None:
- # Runs queued restore operations left over from load_weights when graph
- # building.
- self._in_progress_restore_finalizer()
- self._in_progress_restore_finalizer = None
-
def _updated_config(self):
"""Util shared between different serialization methods.
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 030328f2a6..e029e614e0 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -722,18 +722,23 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.assertEqual(len(graph.get_operations()), op_count)
def _weight_loading_test_template(self, make_model_fn):
- with self.test_session() as session:
+ with self.test_session():
model = make_model_fn()
+ model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
temp_dir = self.get_temp_dir()
prefix = os.path.join(temp_dir, 'ckpt')
+ train_x = np.random.random((3, 2))
+ train_y = np.random.random((3,))
+ x = constant_op.constant(train_x, dtype=dtypes.float32)
- x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
- executing_eagerly = context.executing_eagerly()
- ref_y_tensor = model(x)
- if not executing_eagerly:
- session.run([v.initializer for v in model.variables])
- ref_y = self.evaluate(ref_y_tensor)
+ model.train_on_batch(train_x, train_y)
model.save_weights(prefix, save_format='tf')
+ ref_y_before_train = model.predict(train_x)
+ model.train_on_batch(train_x, train_y)
+ ref_y_after_train = model.predict(train_x)
for v in model.variables:
self.evaluate(
v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
@@ -741,16 +746,27 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
self.addCleanup(shutil.rmtree, temp_dir)
model.load_weights(prefix)
- y = self.evaluate(model(x))
- self.assertAllClose(ref_y, y)
+ self.assertAllClose(ref_y_before_train, self.evaluate(model(x)))
# Test restore-on-create if this is a subclassed Model (graph Networks
# will have already created their variables).
load_model = make_model_fn()
load_model.load_weights(prefix)
- restore_on_create_y_tensor = load_model(x)
- restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
- self.assertAllClose(ref_y, restore_on_create_y)
+ self.assertAllClose(
+ ref_y_before_train,
+ self.evaluate(load_model(x)))
+ load_model = make_model_fn()
+ load_model.load_weights(prefix)
+ # We need to run some of the restore ops for predict(), but not all
+ # variables have been created yet (optimizer slot variables). Tests
+ # incremental restore.
+ load_model.predict(train_x)
+ load_model.compile(
+ loss='mse',
+ optimizer=training_module.RMSPropOptimizer(0.1),
+ metrics=['acc'])
+ load_model.train_on_batch(train_x, train_y)
+ self.assertAllClose(ref_y_after_train, self.evaluate(load_model(x)))
@test_util.run_in_graph_and_eager_modes
def test_weight_loading_graph_model(self):
@@ -858,5 +874,6 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 39d207cc6b..315d88d418 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -671,7 +671,6 @@ class Model(Network):
updates=updates,
name='train_function',
**self._function_kwargs)
- self._post_build_cleanup()
def _make_test_function(self):
if not hasattr(self, 'test_function'):
@@ -689,7 +688,6 @@ class Model(Network):
updates=self.state_updates + self.metrics_updates,
name='test_function',
**self._function_kwargs)
- self._post_build_cleanup()
def _make_predict_function(self):
if not hasattr(self, 'predict_function'):
@@ -708,7 +706,6 @@ class Model(Network):
updates=self.state_updates,
name='predict_function',
**kwargs)
- self._post_build_cleanup()
def _get_iterator_get_next_tensors(self, iterator):
get_next_op = self._iterator_get_next.get(iterator, None)
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index f0703c8af4..66837ee52f 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -144,7 +144,7 @@ class _CheckpointPosition(object):
# process deferred restorations for it and its dependencies.
restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access
if restore_ops:
- self._checkpoint.restore_ops.extend(restore_ops)
+ self._checkpoint.new_restore_ops(restore_ops)
def bind_object(self, checkpointable):
"""Set a checkpoint<->object correspondence and process slot variables.
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 5d26a817d4..664b2348c0 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -101,6 +101,7 @@ class _CheckpointRestoreCoordinator(object):
# this checkpoint.
self.restore_ops = []
self.restore_ops_by_name = {}
+ self.new_restore_ops_callback = None
# A mapping from optimizer proto ids to lists of slot variables to be
# restored when the optimizer is tracked. Only includes slot variables whose
# regular variables have already been created, and only for optimizer
@@ -121,6 +122,11 @@ class _CheckpointRestoreCoordinator(object):
slot_variable_id=slot_reference.slot_variable_node_id,
slot_name=slot_reference.slot_name))
+ def new_restore_ops(self, new_ops):
+ self.restore_ops.extend(new_ops)
+ if self.new_restore_ops_callback:
+ self.new_restore_ops_callback(new_ops) # pylint: disable=not-callable
+
class _NameBasedRestoreCoordinator(object):
"""Keeps the status of a name-based checkpoint restore."""
@@ -821,6 +827,31 @@ class _LoadStatus(object):
pass
+def streaming_restore(status, session=None):
+ """When graph building, runs restore ops as soon as they come in.
+
+ Args:
+ status: A _LoadStatus object from an object-based saver's restore().
+ Streaming restore from name-based checkpoints is not currently
+ supported.
+ session: A session to run new restore ops in.
+ """
+ if context.executing_eagerly():
+ # Streaming restore is the default/only behavior when executing eagerly.
+ return
+ if session is None:
+ session = ops.get_default_session()
+ if isinstance(status, NameBasedSaverStatus):
+ raise NotImplementedError(
+ "Streaming restore not supported from name-based checkpoints. File a "
+ "feature request if this limitation bothers you.")
+ status.run_restore_ops(session=session)
+ # pylint: disable=protected-access
+ status._checkpoint.new_restore_ops_callback = (
+ lambda ops: session.run(ops, feed_dict=status._feed_dict))
+ # pylint: enable=protected-access
+
+
class CheckpointLoadStatus(_LoadStatus):
"""Checks the status of checkpoint loading and manages restore ops.
@@ -992,11 +1023,13 @@ _DEPRECATED_RESTORE_INSTRUCTIONS = (
"one this message is coming from) and use that checkpoint in the future.")
-@deprecation.deprecated(
- date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
class NameBasedSaverStatus(_LoadStatus):
"""Status for loading a name-based training checkpoint."""
+ # Ideally this deprecation decorator would be on the class, but that
+ # interferes with isinstance checks.
+ @deprecation.deprecated(
+ date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
def __init__(self, checkpoint, root_checkpointable):
self._checkpoint = checkpoint
self._root_checkpointable = root_checkpointable
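The decorator was moved off the class because streaming_restore (above) now does isinstance(status, NameBasedSaverStatus), and a decorator that replaces the class with a wrapper function breaks that check. A generic illustration of the failure mode, not TensorFlow's deprecation module:

    def deprecated(cls):
      def wrapper(*args, **kwargs):
        print('Warning: deprecated')
        return cls(*args, **kwargs)
      return wrapper  # the name 'Status' will now refer to a function

    @deprecated
    class Status(object):
      pass

    s = Status()           # warns, then returns a Status instance
    isinstance(s, Status)  # TypeError: isinstance() arg 2 must be a type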