aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-09 15:56:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-09 15:59:21 -0700
commitef58a46b730155717f1b03abb20767c1924ad05e (patch)
tree6d9509f18b878d07f9e320a566a559c83bac2613 /tensorflow/contrib/optimizer_v2
parent22b8b9a528c658144a16dce19ba506561abae2ee (diff)
Support saving Python state with object-based checkpoints
Allows SaveableObjects to specify feed dict addition callbacks for object-based saving. For now just saves get_config() with Layers. Doesn't do any loading, and there isn't quite enough information to reconstruct a Model yet (needs topology). My plan is to get Models to the point where they can be reconstructed from object-based checkpoints (probably one more change), add in SavedModel export (assuming no dynamic control flow for now), then add this "SavedModel+Python" format to Model.save / load_model. PiperOrigin-RevId: 196043183
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py43
1 files changed, 19 insertions, 24 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 9e2858d00f..87b2ecf565 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -31,7 +31,6 @@ from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl.keras.engine import training
@@ -139,8 +138,9 @@ class CheckpointingTests(test.TestCase):
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph = (
- checkpointable_utils._serialize_object_graph(root_checkpointable))
+ named_variables, serialized_graph, _ = (
+ checkpointable_utils._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@@ -163,24 +163,29 @@ class CheckpointingTests(test.TestCase):
suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
expected_checkpoint_names = [
name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
six.assertCountEqual(self, expected_checkpoint_names,
named_variables.keys())
# Check that we've mapped to the right variable objects (not exhaustive)
self.assertEqual(
- "global_step:0",
- named_variables["optimizer_step" + suffix].name)
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
self.assertEqual(
- "my_model/dense_1/kernel:0",
- named_variables["model/_second/kernel" + suffix].name)
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
self.assertEqual(
- "my_model/dense/kernel:0",
- named_variables["model/_named_dense/kernel" + suffix].name)
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
self.assertEqual(
- "beta1_power:0",
- named_variables["optimizer/beta1_power" + suffix].name)
+ "beta1_power",
+ named_variables["optimizer/beta1_power" + suffix].full_name)
self.assertEqual(
- "beta2_power:0",
- named_variables["optimizer/beta2_power" + suffix].name)
+ "beta2_power",
+ named_variables["optimizer/beta2_power" + suffix].full_name)
# Spot check the generated protocol buffers.
self.assertEqual("optimizer",
serialized_graph.nodes[0].children[1].local_name)
@@ -205,7 +210,7 @@ class CheckpointingTests(test.TestCase):
self.assertEqual(
"my_model/dense/kernel/Adam:0",
optimizer.get_slot(
- var=named_variables["model/_named_dense/kernel" + suffix],
+ var=model._named_dense.kernel,
name="m").name)
self.assertEqual(
"model/_named_dense/kernel" + suffix,
@@ -417,16 +422,6 @@ class CheckpointingTests(test.TestCase):
self.evaluate(root.save_counter))
# pylint: enable=cell-var-from-loop
- def _get_checkpoint_name(self, name):
- root = checkpointable.Checkpointable()
- checkpointable_utils.add_variable(
- root, name=name, shape=[1, 2], dtype=dtypes.float64)
- named_variables, _ = checkpointable_utils._serialize_object_graph(root)
- checkpoint_name, = named_variables.keys()
- with ops.name_scope("root/" + checkpoint_name):
- pass # Make sure we can use this as an op name if we prefix it.
- return checkpoint_name
-
def testAnonymousVarsInInit(self):
class Model(training.Model):