diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-09 15:56:43 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-09 15:59:21 -0700 |
commit | ef58a46b730155717f1b03abb20767c1924ad05e (patch) | |
tree | 6d9509f18b878d07f9e320a566a559c83bac2613 /tensorflow/contrib/optimizer_v2 | |
parent | 22b8b9a528c658144a16dce19ba506561abae2ee (diff) |
Support saving Python state with object-based checkpoints
Allows SaveableObjects to specify feed dict addition callbacks for object-based saving.
For now just saves get_config() with Layers. Doesn't do any loading, and there isn't quite enough information to reconstruct a Model yet (needs topology).
My plan is to get Models to the point where they can be reconstructed from object-based checkpoints (probably one more change), add in SavedModel export (assuming no dynamic control flow for now), then add this "SavedModel+Python" format to Model.save / load_model.
PiperOrigin-RevId: 196043183
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 43 |
1 files changed, 19 insertions, 24 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 9e2858d00f..87b2ecf565 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.keras._impl.keras.engine import training @@ -139,8 +138,9 @@ class CheckpointingTests(test.TestCase): self.evaluate(checkpointable_utils.gather_initializers( root_checkpointable)) self.evaluate(train_op) - named_variables, serialized_graph = ( - checkpointable_utils._serialize_object_graph(root_checkpointable)) + named_variables, serialized_graph, _ = ( + checkpointable_utils._serialize_object_graph( + root_checkpointable, saveables_cache=None)) expected_checkpoint_names = ( # Created in the root node, so no prefix. "optimizer_step", @@ -163,24 +163,29 @@ class CheckpointingTests(test.TestCase): suffix = "/.ATTRIBUTES/VARIABLE_VALUE" expected_checkpoint_names = [ name + suffix for name in expected_checkpoint_names] + # The Dense layers also save get_config() JSON + expected_checkpoint_names.extend( + ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON", + "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"]) + named_variables = {v.name: v for v in named_variables} six.assertCountEqual(self, expected_checkpoint_names, named_variables.keys()) # Check that we've mapped to the right variable objects (not exhaustive) self.assertEqual( - "global_step:0", - named_variables["optimizer_step" + suffix].name) + "global_step", + named_variables["optimizer_step" + suffix].full_name) self.assertEqual( - "my_model/dense_1/kernel:0", - named_variables["model/_second/kernel" + suffix].name) + "my_model/dense_1/kernel", + named_variables["model/_second/kernel" + suffix].full_name) self.assertEqual( - "my_model/dense/kernel:0", - named_variables["model/_named_dense/kernel" + suffix].name) + "my_model/dense/kernel", + named_variables["model/_named_dense/kernel" + suffix].full_name) self.assertEqual( - "beta1_power:0", - named_variables["optimizer/beta1_power" + suffix].name) + "beta1_power", + named_variables["optimizer/beta1_power" + suffix].full_name) self.assertEqual( - "beta2_power:0", - named_variables["optimizer/beta2_power" + suffix].name) + "beta2_power", + named_variables["optimizer/beta2_power" + suffix].full_name) # Spot check the generated protocol buffers. self.assertEqual("optimizer", serialized_graph.nodes[0].children[1].local_name) @@ -205,7 +210,7 @@ class CheckpointingTests(test.TestCase): self.assertEqual( "my_model/dense/kernel/Adam:0", optimizer.get_slot( - var=named_variables["model/_named_dense/kernel" + suffix], + var=model._named_dense.kernel, name="m").name) self.assertEqual( "model/_named_dense/kernel" + suffix, @@ -417,16 +422,6 @@ class CheckpointingTests(test.TestCase): self.evaluate(root.save_counter)) # pylint: enable=cell-var-from-loop - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - checkpointable_utils.add_variable( - root, name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable_utils._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - def testAnonymousVarsInInit(self): class Model(training.Model): |