 tensorflow/python/training/checkpointable.py            |   5 +
 tensorflow/python/training/checkpointable_utils.py      |  14 +-
 tensorflow/python/training/checkpointable_utils_test.py |   3 -
 tensorflow/python/training/saver.py                     |  70 ++++++-
 tensorflow/python/training/saver_test.py                | 150 ++++++++++++
 5 files changed, 227 insertions(+), 15 deletions(-)
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index 9bf48df22e..0b8473742c 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -26,6 +26,11 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
from tensorflow.python.util import nest
+
+# Key where the object graph proto is saved in a TensorBundle
+OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
+
+
# A key indicating a variable's value in an object's checkpointed Tensors
# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and
# the object has no dependencies, then its value may be restored on object
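
The constant added above names the bundle key under which the serialized
CheckpointableObjectGraph is stored. A minimal sketch of inspecting that proto
in an existing checkpoint, assuming an object-based checkpoint was already
written at the hypothetical prefix below; it mirrors the reads performed in
checkpointable_utils.py and saver.py later in this change:

    from tensorflow.core.protobuf import checkpointable_object_graph_pb2
    from tensorflow.python import pywrap_tensorflow
    from tensorflow.python.training import checkpointable

    # Hypothetical prefix; any checkpoint written by an object-based saver works.
    reader = pywrap_tensorflow.NewCheckpointReader("/tmp/object_based-1")
    object_graph = checkpointable_object_graph_pb2.CheckpointableObjectGraph()
    object_graph.ParseFromString(
        reader.get_tensor(checkpointable.OBJECT_GRAPH_PROTO_KEY))
    for node in object_graph.nodes:
      for attribute in node.attributes:
        # Maps e.g. a variable's full_name ("dense/bias") to an object path
        # such as "model/_named_dense/bias/.ATTRIBUTES/VARIABLE_VALUE".
        print(attribute.full_name, "->", attribute.checkpoint_key)
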
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index da99d2ec31..2c4677a278 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -54,8 +54,6 @@ _OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
# attribute in checkpoint names. Used like:
# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
-# Key where the object graph proto is saved in a TensorBundle
-_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH"
class _CheckpointRestoreCoordinator(object):
@@ -680,10 +678,11 @@ class CheckpointableSaver(object):
object_graph_tensor = constant_op.constant(
graph_proto.SerializeToString(), dtype=dtypes.string)
feed_additions = None
- assert _OBJECT_GRAPH_PROTO_KEY not in named_variables
- named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=_OBJECT_GRAPH_PROTO_KEY)
+ assert checkpointable_lib.OBJECT_GRAPH_PROTO_KEY not in named_variables
+ named_variables[checkpointable_lib.OBJECT_GRAPH_PROTO_KEY] = (
+ _NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=checkpointable_lib.OBJECT_GRAPH_PROTO_KEY))
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
# save() is called so they pick up new Tensors passed to their
@@ -786,7 +785,8 @@ class CheckpointableSaver(object):
file_prefix_feed_dict = None
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
try:
- object_graph_string = reader.get_tensor(_OBJECT_GRAPH_PROTO_KEY)
+ object_graph_string = reader.get_tensor(
+ checkpointable_lib.OBJECT_GRAPH_PROTO_KEY)
except errors_impl.NotFoundError:
# The object graph proto does not exist in this checkpoint. Try again with
# name-based saving.
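
For context, the except branch above covers the reverse direction of the
compatibility story: an object-based CheckpointableSaver reading a plain
name-based checkpoint. A rough sketch, assuming graph mode, a hypothetical
prefix, and the status API used by the existing tests:

    from tensorflow.python.client import session
    from tensorflow.python.ops import resource_variable_ops
    from tensorflow.python.training import checkpointable_utils
    from tensorflow.python.training import saver as saver_module

    v = resource_variable_ops.ResourceVariable(1., name="v")
    name_saver = saver_module.Saver([v])
    with session.Session() as sess:
      sess.run(v.initializer)
      # A name-based checkpoint contains no object graph proto.
      path = name_saver.save(sess, "/tmp/name_based")  # hypothetical prefix
      # restore() hits NotFoundError on the proto key internally and falls
      # back to matching variables by name.
      status = checkpointable_utils.CheckpointableSaver(v).restore(path)
      status.initialize_or_restore()
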
diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py
index ddf9820616..29fcdb70b4 100644
--- a/tensorflow/python/training/checkpointable_utils_test.py
+++ b/tensorflow/python/training/checkpointable_utils_test.py
@@ -1268,9 +1268,6 @@ class CheckpointCompatibilityTests(test.TestCase):
status.initialize_or_restore()
self._check_sentinels(root)
- # TODO(allenl): Test for the core name-based saver loading object-based
- # checkpoints once object-based checkpointing is in core.
-
def testSaveGraphLoadEager(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index e40b8d22ed..79d278cf90 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -22,6 +22,7 @@ from __future__ import print_function
import collections
import os.path
import re
+import sys
import time
import uuid
@@ -30,8 +31,10 @@ import six
from google.protobuf import text_format
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -1340,6 +1343,9 @@ class Saver(object):
self._check_saver_def()
self._write_version = self.saver_def.version
self._save_relative_paths = save_relative_paths
+ # For compatibility with object-based checkpoints, we may build a second
+ # Saver to read the renamed keys.
+ self._object_restore_saver = None
def build(self):
if context.executing_eagerly():
@@ -1795,11 +1801,65 @@ class Saver(object):
if save_path is None:
raise ValueError("Can't load save_path when it is None.")
logging.info("Restoring parameters from %s", save_path)
- if context.executing_eagerly():
- self._build_eager(save_path, build_save=False, build_restore=True)
- else:
- sess.run(self.saver_def.restore_op_name,
- {self.saver_def.filename_tensor_name: save_path})
+ try:
+ if context.executing_eagerly():
+ self._build_eager(save_path, build_save=False, build_restore=True)
+ else:
+ sess.run(self.saver_def.restore_op_name,
+ {self.saver_def.filename_tensor_name: save_path})
+ except errors.NotFoundError:
+ exception_type, exception_value, exception_traceback = sys.exc_info()
+ # The checkpoint would not be loaded successfully as is. Try to parse it
+ # as an object-based checkpoint.
+ try:
+ reader = pywrap_tensorflow.NewCheckpointReader(save_path)
+ object_graph_string = reader.get_tensor(
+ checkpointable.OBJECT_GRAPH_PROTO_KEY)
+ except errors.NotFoundError:
+ # This is not an object-based checkpoint, or the checkpoint doesn't
+ # exist. Re-raise the original exception.
+ six.reraise(exception_type, exception_value, exception_traceback)
+ del exception_traceback # avoid reference cycles
+
+ # This is an object-based checkpoint. We'll print a warning and then do
+ # the restore.
+ logging.warning(
+ # TODO(allenl): Modify instructions for using the object-based saver
+ # once that's in core.
+ "Restoring an object-based checkpoint using a name-based saver. This "
+ "may be somewhat fragile, and will re-build the Saver. Instead, "
+ "consider loading object-based checkpoints using "
+ "tf.contrib.eager.Checkpoint().")
+ self._restore_from_object_based_checkpoint(
+ sess=sess, save_path=save_path,
+ object_graph_string=object_graph_string)
+
+ def _restore_from_object_based_checkpoint(self, sess, save_path,
+ object_graph_string):
+ """A compatibility mode for reading object-based checkpoints."""
+ object_graph_proto = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+ object_graph_proto.ParseFromString(object_graph_string)
+ names_to_keys = {}
+ for node in object_graph_proto.nodes:
+ for attribute in node.attributes:
+ names_to_keys[attribute.full_name] = attribute.checkpoint_key
+ saveables = self._builder._ValidateAndSliceInputs(self._var_list) # pylint: disable=protected-access
+ for saveable in saveables:
+ for spec in saveable.specs:
+ if spec.name not in names_to_keys:
+ raise errors.NotFoundError(
+ None, None,
+ message=("Attempting to load an object-based checkpoint using "
+ "variable names, but could not find %s in the "
+ "checkpoint.") % spec.name)
+ spec.name = names_to_keys[spec.name]
+ if self._object_restore_saver is None:
+ # Cache the Saver so multiple restore() calls don't pollute the graph when
+ # graph building. This assumes keys are consistent (i.e. this is the same
+ # type of object-based checkpoint we saw previously).
+ self._object_restore_saver = Saver(saveables)
+ self._object_restore_saver.restore(sess=sess, save_path=save_path)
@staticmethod
def _add_collection_def(meta_graph_def, key, export_scope=None):
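
The net effect of the saver.py changes, as a usage sketch: a name-based Saver
can now load a checkpoint written by the object-based saver. Eager mode is
assumed; MyModel is the test model defined in saver_test.py below, and the
prefix is made up:

    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.training import checkpointable_utils
    from tensorflow.python.training import saver as saver_module

    with context.eager_mode():
      model = MyModel()  # defined in the tests below
      model(constant_op.constant([[3.]]))  # build the variables
      object_saver = checkpointable_utils.CheckpointableSaver(model)
      path = object_saver.save(file_prefix="/tmp/object_based")  # hypothetical
      # Saver.restore() catches the NotFoundError, reads the object graph
      # proto, rewrites its save-spec names to object-based keys, and retries.
      saver_module.Saver(model.variables).restore(sess=None, save_path=path)
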
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 14dda79979..3867c0d8da 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
+import functools
import math
import os
import random
@@ -50,6 +51,8 @@ from tensorflow.python.framework import graph_io
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.keras._impl.keras.layers import core
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -69,10 +72,12 @@ from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import adam
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_module
from tensorflow.python.training import saver_test_utils
+from tensorflow.python.training import training_util
from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState
from tensorflow.python.util import compat
@@ -2948,6 +2953,29 @@ class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
return self.non_dep_variable.name
+class NonLayerCheckpointable(checkpointable.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ self.a_variable = checkpointable_utils.add_variable(
+ self, name="a_variable", shape=[])
+
+
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
@test_util.with_c_api
class CheckpointableCompatibilityTests(test.TestCase):
@@ -3011,6 +3039,128 @@ class CheckpointableCompatibilityTests(test.TestCase):
saver.restore(sess, save_path)
self.assertEqual(1, v.eval_count)
+ def _initialized_model(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = checkpointable_utils.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ train_op = optimizer.minimize(
+ functools.partial(model, input_value),
+ global_step=optimizer_step)
+ self.evaluate(checkpointable_utils.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ # A regular variable, a slot variable, and a non-slot Optimizer variable
+ # with known values to check when loading.
+ self.evaluate(model._named_dense.bias.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=model._named_dense.bias, name="m").assign([2.]))
+ beta1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta1_power.assign(3.))
+ return root_checkpointable
+
+ def _set_sentinels(self, root_checkpointable):
+ self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+ self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")
+ .assign([102.]))
+ beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta1_power.assign(103.))
+
+ def _check_sentinels(self, root_checkpointable):
+ self.assertAllEqual(
+ [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+ self.assertAllEqual([2.], self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")))
+ beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta1_power))
+
+ def testVariableNotFoundErrorRaised(self):
+ # Restore does some tricky exception handling to figure out if it should
+ # load an object-based checkpoint. Tests that the exception handling isn't
+ # too broad.
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ a = resource_variable_ops.ResourceVariable(1., name="a")
+ b = resource_variable_ops.ResourceVariable(1., name="b")
+ a_saver = saver_module.Saver([a])
+ b_saver = saver_module.Saver([b])
+ with self.test_session() as sess:
+ sess.run(a.initializer)
+ save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
+ with self.assertRaisesRegexp(
+ errors.NotFoundError, "Key b not found in checkpoint"):
+ b_saver.restore(sess=sess, save_path=save_path)
+
+ def testCheckpointNotFoundErrorRaised(self):
+ # Restore does some tricky exception handling to figure out if it should
+ # load an object-based checkpoint. Tests that the exception handling isn't
+ # too broad.
+ a = resource_variable_ops.ResourceVariable(1., name="a")
+ saver = saver_module.Saver([a])
+ with self.test_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.NotFoundError,
+ "Failed to find any matching files for path_which_does_not_exist"):
+ saver.restore(sess=sess, save_path="path_which_does_not_exist")
+
+ def testLoadFromObjectBasedGraph(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ save_graph = ops_lib.Graph()
+ with save_graph.as_default(), self.test_session(graph=save_graph) as sess:
+ root = self._initialized_model()
+ object_saver = checkpointable_utils.CheckpointableSaver(root)
+ save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+ # An incompatible object-based checkpoint to check error messages
+ var = resource_variable_ops.ResourceVariable(1., name="a")
+ self.evaluate(var.initializer)
+ second_saver = checkpointable_utils.CheckpointableSaver(var)
+ second_path = second_saver.save(file_prefix=os.path.join(
+ checkpoint_directory, "second"))
+
+ restore_graph = ops_lib.Graph()
+ with restore_graph.as_default(), self.test_session(
+ graph=restore_graph) as sess:
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ saver = saver_module.Saver()
+ saver.restore(sess=sess, save_path=save_path)
+ self._check_sentinels(root)
+ before_second_restore_ops = restore_graph.get_operations()
+ # Test that multiple restores do not pollute the graph
+ saver.restore(sess=sess, save_path=save_path)
+ self.assertEqual(before_second_restore_ops,
+ restore_graph.get_operations())
+ with self.assertRaisesRegexp(errors.NotFoundError,
+ "could not find a_variable"):
+ saver.restore(sess=sess, save_path=second_path)
+
+ def testLoadFromObjectBasedEager(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+
+ save_graph = ops_lib.Graph()
+ with save_graph.as_default(), self.test_session(graph=save_graph):
+ root = self._initialized_model()
+ object_saver = checkpointable_utils.CheckpointableSaver(root)
+ save_path = object_saver.save(file_prefix=checkpoint_prefix)
+
+ with context.eager_mode():
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ saver = saver_module.Saver(
+ root.model.variables + root.optimizer.variables())
+ saver.restore(sess=None, save_path=save_path)
+ self._check_sentinels(root)
+
if __name__ == "__main__":
test.main()