aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/saver_test.py
diff options
context:
space:
mode:
author Allen Lavoie <allenl@google.com> 2018-04-13 14:32:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-04-13 14:35:26 -0700
commit 8600d918a63c658b9b79ba96ee821c903ba3ee94 (patch)
tree a8d1fd808c2311b0f2c2e31611aaff21f12649f9 /tensorflow/python/training/saver_test.py
parent bf724a8ced3710ed2234f25748ed7719e319d78c (diff)
Allow tf.train.Saver to load object-based checkpoints (using names)
This is the second part of the compatibility story. Object-based checkpointing APIs can already read name-based checkpoints, and now the name-based APIs can read object-based checkpoints by looking up the modified keys in the object graph proto.
PiperOrigin-RevId: 192824907
Diffstat (limited to 'tensorflow/python/training/saver_test.py')
-rw-r--r-- tensorflow/python/training/saver_test.py 150
1 files changed, 150 insertions, 0 deletions
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 14dda79979..3867c0d8da 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
+import functools
import math
import os
import random
@@ -50,6 +51,8 @@ from tensorflow.python.framework import graph_io
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -69,10 +72,12 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
@@ -2948,6 +2953,29 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
return self.non_dep_variable.name
+class NonLayerCheckpointable(checkpointable.Checkpointable):
+  """A Checkpointable subclass that owns one variable, `a_variable`."""
+
+  def __init__(self):
+    super(NonLayerCheckpointable, self).__init__()
+    # Created through checkpointable_utils.add_variable with an empty shape
+    # list, then exposed as the `a_variable` attribute.
+    variable = checkpointable_utils.add_variable(
+        self, shape=[], name="a_variable")
+    self.a_variable = variable
+
+
+class MyModel(training.Model):
+  """A concrete Model for testing.
+
+  Holds two `core.Dense` layers with one unit each (the first with a bias,
+  the second without) plus a `NonLayerCheckpointable` attribute.
+  """
+
+  def __init__(self):
+    super(MyModel, self).__init__()
+    self._named_dense = core.Dense(1, use_bias=True)
+    self._second = core.Dense(1, use_bias=False)
+    # We can still track Checkpointables which aren't Layers.
+    self._non_layer = NonLayerCheckpointable()
+
+  def call(self, values):
+    """Applies `_named_dense` to `values`, then `_second` to that result."""
+    first_output = self._named_dense(values)
+    return self._second(first_output)
+
+
@test_util.with_c_api
class CheckpointableCompatibilityTests(test.TestCase):
@@ -3011,6 +3039,128 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver.restore(sess, save_path)
self.assertEqual(1, v.eval_count)
+  def _initialized_model(self):
+    """Builds a MyModel, runs one optimizer step and pins three known values.
+
+    After the training op is evaluated, the `_named_dense` bias is assigned
+    [1.], its Adam "m" slot [2.] and the optimizer's beta1_power 3.
+
+    Returns:
+      The `checkpointable_utils.Checkpoint` grouping the optimizer, the model
+      and the global step (under the name `optimizer_step`).
+    """
+    inputs = constant_op.constant([[3.]])
+    model = MyModel()
+    optimizer = adam.AdamOptimizer(0.001)
+    global_step = training_util.get_or_create_global_step()
+    root = checkpointable_utils.Checkpoint(
+        optimizer=optimizer, model=model, optimizer_step=global_step)
+    loss_fn = functools.partial(model, inputs)
+    train_op = optimizer.minimize(loss_fn, global_step=global_step)
+    initializers = checkpointable_utils.gather_initializers(root)
+    self.evaluate(initializers)
+    self.evaluate(train_op)
+    # A regular variable, a slot variable, and a non-slot Optimizer variable
+    # with known values to check when loading.
+    bias = model._named_dense.bias
+    self.evaluate(bias.assign([1.]))
+    m_slot = optimizer.get_slot(var=bias, name="m")
+    self.evaluate(m_slot.assign([2.]))
+    beta1_power, _ = optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(3.))
+    return root
+
+  def _set_sentinels(self, root_checkpointable):
+    """Overwrites the three checked values with 101., 102. and 103.
+
+    Args:
+      root_checkpointable: Object exposing `model` and `optimizer` attributes,
+        as returned by `_initialized_model`.
+    """
+    optimizer = root_checkpointable.optimizer
+    bias = root_checkpointable.model._named_dense.bias
+    self.evaluate(bias.assign([101.]))
+    m_slot = optimizer.get_slot(var=bias, name="m")
+    self.evaluate(m_slot.assign([102.]))
+    beta1_power, _ = optimizer._get_beta_accumulators()
+    self.evaluate(beta1_power.assign(103.))
+
+  def _check_sentinels(self, root_checkpointable):
+    """Asserts the bias, its "m" slot and beta1_power hold 1., 2. and 3.
+
+    Args:
+      root_checkpointable: Object exposing `model` and `optimizer` attributes,
+        as returned by `_initialized_model`.
+    """
+    optimizer = root_checkpointable.optimizer
+    bias = root_checkpointable.model._named_dense.bias
+    self.assertAllEqual([1.], self.evaluate(bias))
+    m_slot = optimizer.get_slot(var=bias, name="m")
+    self.assertAllEqual([2.], self.evaluate(m_slot))
+    beta1_power, _ = optimizer._get_beta_accumulators()
+    self.assertAllEqual(3., self.evaluate(beta1_power))
+
+  def testVariableNotFoundErrorRaised(self):
+    """Restoring `b` from a checkpoint that only holds `a` must raise."""
+    # Restore does some tricky exception handling to figure out if it should
+    # load an object-based checkpoint. Tests that the exception handling isn't
+    # too broad.
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+
+    a = resource_variable_ops.ResourceVariable(1., name="a")
+    b = resource_variable_ops.ResourceVariable(1., name="b")
+    a_saver = saver_module.Saver([a])
+    b_saver = saver_module.Saver([b])
+    with self.test_session() as sess:
+      sess.run(a.initializer)
+      a_only_path = a_saver.save(sess=sess, save_path=prefix)
+      expected_error = self.assertRaisesRegexp(
+          errors.NotFoundError, "Key b not found in checkpoint")
+      with expected_error:
+        b_saver.restore(sess=sess, save_path=a_only_path)
+
+  def testCheckpointNotFoundErrorRaised(self):
+    """Restoring from a path with no matching files must raise NotFoundError."""
+    # Restore does some tricky exception handling to figure out if it should
+    # load an object-based checkpoint. Tests that the exception handling isn't
+    # too broad.
+    variable = resource_variable_ops.ResourceVariable(1., name="a")
+    saver = saver_module.Saver([variable])
+    with self.test_session() as sess:
+      expected_error = self.assertRaisesRegexp(
+          errors.NotFoundError,
+          "Failed to find any matching files for path_which_does_not_exist")
+      with expected_error:
+        saver.restore(sess=sess, save_path="path_which_does_not_exist")
+
+  def testLoadFromObjectBasedGraph(self):
+    """Restores object-based checkpoints with a name-based Saver in a graph.
+
+    Two checkpoints are written with CheckpointableSaver in one graph: one for
+    the object returned by `_initialized_model` and one for a lone variable
+    named "a". In a second graph a `saver_module.Saver` restores the first
+    checkpoint twice (the second restore must leave the graph's operations
+    unchanged) and is expected to raise NotFoundError for the second one.
+    """
+    ckpt_dir = self.get_temp_dir()
+    model_prefix = os.path.join(ckpt_dir, "ckpt")
+    second_prefix = os.path.join(ckpt_dir, "second")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.test_session(graph=save_graph):
+      root = self._initialized_model()
+      object_saver = checkpointable_utils.CheckpointableSaver(root)
+      save_path = object_saver.save(file_prefix=model_prefix)
+
+      # An incompatible object-based checkpoint to check error messages
+      var = resource_variable_ops.ResourceVariable(1., name="a")
+      self.evaluate(var.initializer)
+      second_saver = checkpointable_utils.CheckpointableSaver(var)
+      second_path = second_saver.save(file_prefix=second_prefix)
+
+    restore_graph = ops_lib.Graph()
+    with restore_graph.as_default(), self.test_session(
+        graph=restore_graph) as sess:
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      saver = saver_module.Saver()
+      saver.restore(sess=sess, save_path=save_path)
+      self._check_sentinels(root)
+      ops_before = restore_graph.get_operations()
+      # Test that multiple restores do not pollute the graph
+      saver.restore(sess=sess, save_path=save_path)
+      ops_after = restore_graph.get_operations()
+      self.assertEqual(ops_before, ops_after)
+      with self.assertRaisesRegexp(errors.NotFoundError,
+                                   "could not find a_variable"):
+        saver.restore(sess=sess, save_path=second_path)
+
+  def testLoadFromObjectBasedEager(self):
+    """Restores an object-based checkpoint with a name-based Saver eagerly.
+
+    The checkpoint is written with CheckpointableSaver inside a graph; the
+    restore runs under `context.eager_mode()` with a Saver built from the
+    model's and optimizer's variables and `sess=None`.
+    """
+    prefix = os.path.join(self.get_temp_dir(), "ckpt")
+
+    save_graph = ops_lib.Graph()
+    with save_graph.as_default(), self.test_session(graph=save_graph):
+      saved_root = self._initialized_model()
+      object_saver = checkpointable_utils.CheckpointableSaver(saved_root)
+      save_path = object_saver.save(file_prefix=prefix)
+
+    with context.eager_mode():
+      root = self._initialized_model()
+      self._set_sentinels(root)
+      variables = root.model.variables + root.optimizer.variables()
+      saver = saver_module.Saver(variables)
+      saver.restore(sess=None, save_path=save_path)
+      self._check_sentinels(root)
+
if __name__ == "__main__":
  # Only when executed as a script: hand control to test.main().
  test.main()