author | 2018-06-22 11:13:43 -0700
---|---
committer | 2018-06-22 11:16:40 -0700
commit | 51695446ec779bd069c5623bf696e4223ff2af48
tree | be13cb094132252b1f0712ed651f76954b414b1e /tensorflow/contrib/optimizer_v2
parent | 49934f7d1a289325480c618c658b7c4cdb8584c6
Split dependency tracking out from CheckpointableBase
Some unit test fiddling, but otherwise just moving code around.
My goal is to be able to use checkpointable data structures (or something like them) in Checkpointable's __setattr__ override. Checkpointable data structures depend on Layer, so Checkpointable and CheckpointableBase need to be in separate files (so we can have the dependency chain CheckpointableBase -> Layer -> CheckpointableDataStructure -> Checkpointable).
This will also require changes to python/keras/engine/__init__.py (which currently requires Network and Layer be imported together), but I'll do that in a separate change.
PiperOrigin-RevId: 201712549
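For downstream code, the visible effect of the split is an import rename: `Checkpointable` now comes from the new `tracking` module instead of `base`, and the test below also drops its `checkpointable_utils` alias in favor of plain `util`. A before/after sketch of the change, using only the module paths that appear in the diff below:

```python
# Before this change: the base module provided Checkpointable directly.
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.training.checkpointable import util as checkpointable_utils

root = checkpointable.Checkpointable()
root.var = checkpointable_utils.add_variable(root, name="var", shape=[])

# After this change: Checkpointable (with dependency tracking) lives in
# tracking.py, leaving CheckpointableBase in base.py.
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util

root = tracking.Checkpointable()
root.var = util.add_variable(root, name="var", shape=[])
```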
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 101
1 file changed, 39 insertions(+), 62 deletions(-)
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index b6972a7a45..06ab58188a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -43,15 +43,15 @@
 from tensorflow.python.ops import template
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import saver as core_saver
 from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
 
 
-class NonLayerCheckpointable(checkpointable.Checkpointable):
+class NonLayerCheckpointable(tracking.Checkpointable):
 
   def __init__(self):
     super(NonLayerCheckpointable, self).__init__()
-    self.a_variable = checkpointable_utils.add_variable(
+    self.a_variable = util.add_variable(
         self, name="a_variable", shape=[])
@@ -88,29 +88,6 @@ class _MirroringSaveable(
           self._mirrored_variable.assign(tensor))
 
 
-class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
-  """A Checkpointable object which returns a more complex SaveableObject."""
-
-  def __init__(self):
-    self.non_dep_variable = variable_scope.get_variable(
-        name="non_dep_variable", initializer=6., use_resource=True)
-    self.mirrored = variable_scope.get_variable(
-        name="mirrored", initializer=15., use_resource=True)
-
-  def _gather_saveables_for_checkpoint(self):
-    def _saveable_factory(name=self.non_dep_variable.name):
-      return _MirroringSaveable(
-          primary_variable=self.non_dep_variable,
-          mirrored_variable=self.mirrored,
-          name=name)
-    return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
-
-  # The Saver sorts by name before parsing, so we need a name property.
-  @property
-  def name(self):
-    return self.non_dep_variable.name
-
-
 class CheckpointingTests(test.TestCase):
 
   @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -122,7 +99,7 @@ class CheckpointingTests(test.TestCase):
     other_model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
     optimizer_step = training_util.get_or_create_global_step()
-    root_checkpointable = checkpointable_utils.Checkpoint(
+    root_checkpointable = util.Checkpoint(
         optimizer=optimizer, model=model, optimizer_step=optimizer_step)
     if context.executing_eagerly():
       optimizer.minimize(
@@ -137,11 +114,11 @@ class CheckpointingTests(test.TestCase):
       optimizer.minimize(
           other_model(input_value), global_step=optimizer_step)
-    self.evaluate(checkpointable_utils.gather_initializers(
+    self.evaluate(util.gather_initializers(
         root_checkpointable))
     self.evaluate(train_op)
     named_variables, serialized_graph, _ = (
-        checkpointable_utils._serialize_object_graph(
+        util._serialize_object_graph(
            root_checkpointable, saveables_cache=None))
     expected_checkpoint_names = (
         # Created in the root node, so no prefix.
@@ -230,7 +207,7 @@ class CheckpointingTests(test.TestCase):
   def testSaveRestore(self):
     model = MyModel()
     optimizer = adam.AdamOptimizer(0.001)
-    root_checkpointable = checkpointable_utils.Checkpoint(
+    root_checkpointable = util.Checkpoint(
         optimizer=optimizer, model=model)
     input_value = constant_op.constant([[3.]])
     if context.executing_eagerly():
@@ -240,7 +217,7 @@ class CheckpointingTests(test.TestCase):
       train_op = optimizer.minimize(model(input_value))
       # TODO(allenl): Make initialization more pleasant when graph building.
       root_checkpointable.save_counter  # pylint: disable=pointless-statement
-      self.evaluate(checkpointable_utils.gather_initializers(
+      self.evaluate(util.gather_initializers(
          root_checkpointable))
     self.evaluate(train_op)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
@@ -266,7 +243,7 @@ class CheckpointingTests(test.TestCase):
         # Preserve beta1_power and beta2_power when appying gradients so we can
         # test that they've been restored correctly.
         beta1=1.0, beta2=1.0)
-    on_create_root = checkpointable_utils.Checkpoint(
+    on_create_root = util.Checkpoint(
         optimizer=on_create_optimizer, model=on_create_model)
     # Deferred restoration
     status = on_create_root.restore(save_path=save_path)
@@ -298,7 +275,7 @@ class CheckpointingTests(test.TestCase):
     for training_continuation in range(3):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
-      root = checkpointable_utils.Checkpoint(
+      root = util.Checkpoint(
          optimizer=optimizer, model=model,
          optimizer_step=training_util.get_or_create_global_step())
       root.restore(core_saver.latest_checkpoint(checkpoint_directory))
@@ -322,7 +299,7 @@ class CheckpointingTests(test.TestCase):
       with ops.Graph().as_default():
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
-        root = checkpointable_utils.Checkpoint(
+        root = util.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
         input_value = constant_op.constant([[3.]])
@@ -359,7 +336,7 @@ class CheckpointingTests(test.TestCase):
           graph=ops.get_default_graph()), test_util.device(use_gpu=True):
         model = MyModel()
         optimizer = adam.AdamOptimizer(0.001)
-        root = checkpointable_utils.Checkpoint(
+        root = util.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -392,7 +369,7 @@ class CheckpointingTests(test.TestCase):
         model = MyModel()
         # Don't actually train so we can test variable values
         optimizer = adam.AdamOptimizer(0.)
-        root = checkpointable_utils.Checkpoint(
+        root = util.Checkpoint(
            optimizer=optimizer, model=model,
            global_step=training_util.get_or_create_global_step())
         checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -442,7 +419,7 @@ class CheckpointingTests(test.TestCase):
     optimizer = adam.AdamOptimizer(learning_rate=0.05)
     checkpoint_directory = self.get_temp_dir()
     checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-    checkpoint = checkpointable_utils.Checkpoint(
+    checkpoint = util.Checkpoint(
        model=model, optimizer=optimizer)
     for _ in range(2):
       checkpoint.save(checkpoint_prefix)
@@ -457,8 +434,8 @@ class CheckpointingTests(test.TestCase):
   def testDeferredSlotRestoration(self):
     checkpoint_directory = self.get_temp_dir()
 
-    root = checkpointable.Checkpointable()
-    root.var = checkpointable_utils.add_variable(
+    root = tracking.Checkpointable()
+    root.var = util.add_variable(
        root, name="var", initializer=0.)
     optimizer = adam.AdamOptimizer(0.1)
     if context.executing_eagerly():
@@ -468,28 +445,28 @@ class CheckpointingTests(test.TestCase):
     # Note that `optimizer` has not been added as a dependency of
     # `root`. Create a one-off grouping so that slot variables for `root.var`
     # get initialized too.
-    self.evaluate(checkpointable_utils.gather_initializers(
-        checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
+    self.evaluate(util.gather_initializers(
+        util.Checkpoint(root=root, optimizer=optimizer)))
     self.evaluate(train_op)
     self.evaluate(state_ops.assign(root.var, 12.))
-    no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+    no_slots_path = util.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "no_slots"))
     root.optimizer = optimizer
     self.evaluate(state_ops.assign(root.var, 13.))
     self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
                                    14.))
-    slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+    slots_path = util.CheckpointableSaver(root).save(
        os.path.join(checkpoint_directory, "with_slots"))
-    new_root = checkpointable.Checkpointable()
+    new_root = tracking.Checkpointable()
     # Load the slot-containing checkpoint (deferred), then immediately overwrite
     # the non-slot variable (also deferred).
-    slot_status = checkpointable_utils.CheckpointableSaver(
+    slot_status = util.CheckpointableSaver(
        new_root).restore(slots_path)
-    no_slot_status = checkpointable_utils.CheckpointableSaver(
+    no_slot_status = util.CheckpointableSaver(
        new_root).restore(no_slots_path)
     with self.assertRaises(AssertionError):
       no_slot_status.assert_consumed()
-    new_root.var = checkpointable_utils.add_variable(
+    new_root.var = util.add_variable(
        new_root, name="var", shape=[])
     no_slot_status.assert_consumed()
     no_slot_status.run_restore_ops()
@@ -525,12 +502,12 @@ class CheckpointingTests(test.TestCase):
     with graph.as_default(), self.test_session(graph):
       checkpoint_directory = self.get_temp_dir()
       checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-      obj = checkpointable.Checkpointable()
+      obj = tracking.Checkpointable()
       obj.var = variable_scope.get_variable(name="v", initializer=0.)
       obj.opt = adam.AdamOptimizer(0.1)
       obj.opt.minimize(obj.var.read_value())
-      self.evaluate(checkpointable_utils.gather_initializers(obj))
-      saver = checkpointable_utils.CheckpointableSaver(obj)
+      self.evaluate(util.gather_initializers(obj))
+      saver = util.CheckpointableSaver(obj)
       saver.save(checkpoint_prefix)
       before_ops = graph.get_operations()
       saver.save(checkpoint_prefix)
@@ -543,12 +520,12 @@ class CheckpointingTests(test.TestCase):
     with graph.as_default(), self.test_session(graph):
       checkpoint_directory = self.get_temp_dir()
       checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
-      obj = checkpointable.Checkpointable()
+      obj = tracking.Checkpointable()
       obj.var = variable_scope.get_variable(name="v", initializer=0.)
       obj.opt = adam.AdamOptimizer(0.1)
       obj.opt.minimize(obj.var.read_value())
-      self.evaluate(checkpointable_utils.gather_initializers(obj))
-      saver = checkpointable_utils.CheckpointableSaver(obj)
+      self.evaluate(util.gather_initializers(obj))
+      saver = util.CheckpointableSaver(obj)
       save_path = saver.save(checkpoint_prefix)
       saver.restore(save_path)
       before_ops = graph.get_operations()
@@ -565,10 +542,10 @@ class CheckpointingTests(test.TestCase):
     first_session = session_lib.Session(graph=first_graph)
     with first_graph.as_default(), first_session.as_default():
       first_variable = resource_variable_ops.ResourceVariable([1.])
-      first_root_checkpointable = checkpointable_utils.Checkpoint(
+      first_root_checkpointable = util.Checkpoint(
          optimizer=optimizer, variable=first_variable)
       train_op = optimizer.minimize(first_variable.read_value)
-      self.evaluate(checkpointable_utils.gather_initializers(
+      self.evaluate(util.gather_initializers(
          first_root_checkpointable))
       self.evaluate(train_op)
       self.evaluate(first_variable.assign([1.]))
@@ -581,7 +558,7 @@ class CheckpointingTests(test.TestCase):
     second_graph = ops.Graph()
     with second_graph.as_default(), session_lib.Session(graph=second_graph):
       second_variable = resource_variable_ops.ResourceVariable([1.])
-      second_root_checkpointable = checkpointable_utils.Checkpoint(
+      second_root_checkpointable = util.Checkpoint(
          optimizer=optimizer, variable=second_variable)
       train_op = optimizer.minimize(second_variable.read_value)
       second_root_checkpointable.restore(None).initialize_or_restore()
@@ -631,7 +608,7 @@ class TemplateTests(test.TestCase):
       save_template = template.make_template("s1", _templated)
       v1_save, _, v2_save = save_template()
       optimizer = adam.AdamOptimizer(0.0)
-      save_root = checkpointable_utils.Checkpoint(
+      save_root = util.Checkpoint(
          my_template=save_template, optimizer=optimizer)
       optimizer.minimize(v1_save.read_value)
       self.evaluate([v.initializer for v in optimizer.variables()])
@@ -643,7 +620,7 @@ class TemplateTests(test.TestCase):
 
       load_template = template.make_template("s2", _templated)
       load_optimizer = adam.AdamOptimizer(0.0)
-      load_root = checkpointable_utils.Checkpoint(
+      load_root = util.Checkpoint(
          my_template=load_template, optimizer=load_optimizer)
       status = load_root.restore(save_path)
       var, var_plus_one, var2 = load_template()
@@ -664,12 +641,12 @@ class CheckpointCompatibilityTests(test.TestCase):
       model = MyModel()
       optimizer = adam.AdamOptimizer(0.001)
       optimizer_step = training_util.get_or_create_global_step()
-      root_checkpointable = checkpointable_utils.Checkpoint(
+      root_checkpointable = util.Checkpoint(
          optimizer=optimizer, model=model, optimizer_step=optimizer_step)
       train_op = optimizer.minimize(
          functools.partial(model, input_value),
          global_step=optimizer_step)
-      self.evaluate(checkpointable_utils.gather_initializers(
+      self.evaluate(util.gather_initializers(
          root_checkpointable))
       self.evaluate(train_op)
       # A regular variable, a slot variable, and a non-slot Optimizer variable
@@ -721,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase):
       self._set_sentinels(root)
       with self.assertRaises(AssertionError):
         self._check_sentinels(root)
-      object_saver = checkpointable_utils.CheckpointableSaver(root)
+      object_saver = util.CheckpointableSaver(root)
       self._set_sentinels(root)
       status = object_saver.restore(save_path)
       if context.executing_eagerly():
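The renames above are mechanical, but they trace the object-based save/restore pattern these tests exercise. A minimal sketch of that pattern under the new module names, assuming eager execution and a `model` and `checkpoint_directory` defined as in the tests (both are stand-ins here, not part of the diff; the model's variables are assumed to already exist, e.g. because it has been called on an input):

```python
import os

from tensorflow.contrib.optimizer_v2 import adam
from tensorflow.python.training.checkpointable import util

# Keyword arguments to Checkpoint become named dependencies in the object
# graph, so checkpoint keys are derived from object structure, not variable
# names.
root = util.Checkpoint(optimizer=adam.AdamOptimizer(0.001), model=model)

prefix = os.path.join(checkpoint_directory, "ckpt")
save_path = root.save(prefix)  # saves model variables plus optimizer slots

# Restores may be deferred: variables created after restore() are matched as
# they appear, and assert_consumed() raises if any saved value found no object.
status = root.restore(save_path)
status.assert_consumed()
```

When graph building, the tests additionally run the restore through a session via `status.run_restore_ops()` or `status.initialize_or_restore()`, as in testDeferredSlotRestoration above.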