diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-11 15:58:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-11 16:01:24 -0700 |
commit | 5828842e5956825a65a5423b1ca503f72b084e62 (patch) | |
tree | f26d7c2b326f6acb16aba7d63031fd42e00733f8 /tensorflow/contrib/optimizer_v2 | |
parent | 4ca7a9157863a6d57879c598cc370583d60018d3 (diff) |
Checkpointable: Remove overzealous error checking from tf.make_template
It was checking that all variables in the Template's scope were dependencies, but Optimizer slot variables are created with the same prefix (and should not be dependencies).
Conversely, eager execution's eager slot variable creation meant that Templates created unnecessary/somewhat harmful dependencies on restored slot variables. Fixes that.
PiperOrigin-RevId: 196321999
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r-- | tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py | 45 | ||||
-rw-r--r-- | tensorflow/contrib/optimizer_v2/optimizer_v2.py | 11 |
2 files changed, 55 insertions, 1 deletion
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py index 87b2ecf565..b1f2e9d860 100644 --- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py +++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py @@ -36,8 +36,10 @@ from tensorflow.python.framework import test_util from tensorflow.python.keras._impl.keras.engine import training from tensorflow.python.keras._impl.keras.layers import core from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops +from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope from tensorflow.python.training import checkpointable from tensorflow.python.training import checkpointable_utils @@ -612,6 +614,49 @@ class CheckpointingTests(test.TestCase): self.assertAllEqual(3., self.evaluate(beta1_power)) +class TemplateTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_checkpointable_save_restore(self): + + def _templated(): + v = variable_scope.get_variable( + "v", shape=[1], initializer=init_ops.zeros_initializer(), + use_resource=True) + v2 = variable_scope.get_variable( + "v2", shape=[1], initializer=init_ops.zeros_initializer(), + use_resource=True) + return v, v + 1., v2 + + save_template = template.make_template("s1", _templated) + v1_save, _, v2_save = save_template() + optimizer = adam.AdamOptimizer(0.0) + save_root = checkpointable_utils.Checkpoint( + my_template=save_template, optimizer=optimizer) + optimizer.minimize(v1_save.read_value) + self.evaluate([v.initializer for v in optimizer.variables()]) + self.evaluate(v1_save.assign([12.])) + self.evaluate(v2_save.assign([14.])) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = 
save_root.save(checkpoint_prefix) + + load_template = template.make_template("s2", _templated) + load_optimizer = adam.AdamOptimizer(0.0) + load_root = checkpointable_utils.Checkpoint( + my_template=load_template, optimizer=load_optimizer) + status = load_root.restore(save_path) + var, var_plus_one, var2 = load_template() + load_optimizer.minimize(var.read_value) + self.assertEqual(2, len(load_template._checkpoint_dependencies)) + self.assertEqual("v", load_template._checkpoint_dependencies[0].name) + self.assertEqual("v2", load_template._checkpoint_dependencies[1].name) + status.assert_consumed().run_restore_ops() + self.assertAllEqual([12.], self.evaluate(var)) + self.assertAllEqual([13.], self.evaluate(var_plus_one)) + self.assertAllEqual([14.], self.evaluate(var2)) + + class CheckpointCompatibilityTests(test.TestCase): def _initialized_model(self): diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py index 46bfbb729f..694a3cebd6 100644 --- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py +++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py @@ -360,7 +360,16 @@ class _OptimizerV2State(object): """ slot_variable = self.get_slot(var=variable, name=slot_name) if (slot_variable is None and context.executing_eagerly() and - slot_variable_position.is_simple_variable()): + slot_variable_position.is_simple_variable() + # Defer slot variable creation if there is an active variable creator + # scope. Generally we'd like to eagerly create/restore slot variables + # when possible, but this may mean that scopes intended to catch + # `variable` also catch its eagerly created slot variable + # unintentionally (specifically make_template would add a dependency on + # a slot variable if not for this case). Deferring is mostly harmless + # (aside from double initialization), and makes variable creator scopes + # behave the same way they do when graph building. 
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access initializer = checkpointable.CheckpointInitialValue( checkpoint_position=slot_variable_position) slot_variable = self.create_slot( |