aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/optimizer_v2
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-11 15:58:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-11 16:01:24 -0700
commit5828842e5956825a65a5423b1ca503f72b084e62 (patch)
treef26d7c2b326f6acb16aba7d63031fd42e00733f8 /tensorflow/contrib/optimizer_v2
parent4ca7a9157863a6d57879c598cc370583d60018d3 (diff)
Checkpointable: Remove overzealous error checking from tf.make_template
It was checking that all variables in the Template's scope were dependencies, but Optimizer slot variables are created with the same prefix (and should not be dependencies). Conversely, eager execution's eager slot variable creation meant that Templates create unnecessary/somewhat harmful dependencies on restored slot variables. Fixes that. PiperOrigin-RevId: 196321999
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py45
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py11
2 files changed, 55 insertions, 1 deletions
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 87b2ecf565..b1f2e9d860 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -36,8 +36,10 @@ from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl.keras.engine import training
from tensorflow.python.keras._impl.keras.layers import core
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import checkpointable
from tensorflow.python.training import checkpointable_utils
@@ -612,6 +614,49 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual(3., self.evaluate(beta1_power))
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.AdamOptimizer(0.0)
+ save_root = checkpointable_utils.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.AdamOptimizer(0.0)
+ load_root = checkpointable_utils.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
class CheckpointCompatibilityTests(test.TestCase):
def _initialized_model(self):
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 46bfbb729f..694a3cebd6 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -360,7 +360,16 @@ class _OptimizerV2State(object):
"""
slot_variable = self.get_slot(var=variable, name=slot_name)
if (slot_variable is None and context.executing_eagerly() and
- slot_variable_position.is_simple_variable()):
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
initializer = checkpointable.CheckpointInitialValue(
checkpoint_position=slot_variable_position)
slot_variable = self.create_slot(