author    Allen Lavoie <allenl@google.com>    2018-06-22 11:13:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-22 11:16:40 -0700
commit    51695446ec779bd069c5623bf696e4223ff2af48 (patch)
tree      be13cb094132252b1f0712ed651f76954b414b1e /tensorflow/contrib/optimizer_v2
parent    49934f7d1a289325480c618c658b7c4cdb8584c6 (diff)
Split dependency tracking out from CheckpointableBase
Some unit test fiddling, but otherwise just moving code around. My goal is to be able to use checkpointable data structures (or something like them) in Checkpointable's __setattr__ override. Checkpointable data structures depend on Layer, so Checkpointable and CheckpointableBase need to be in separate files (so we can have the dependency chain CheckpointableBase -> Layer -> CheckpointableDataStructure -> Checkpointable). This will also require changes to python/keras/engine/__init__.py (which currently requires Network and Layer to be imported together), but I'll do that in a separate change.

PiperOrigin-RevId: 201712549
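The layering described in the commit message can be sketched as follows. This is a minimal illustration only: the class bodies are assumptions for demonstration, not the actual TensorFlow implementation; only the dependency direction (base class without tracking, subclass with automatic tracking) follows the commit message.

# Minimal sketch of the intended layering; class contents are illustrative
# assumptions, not the real TensorFlow code.

class CheckpointableBase(object):
  """Bottom layer: serialization hooks and explicit tracking only.

  Layer can inherit from this without creating an import cycle.
  """

  def _track_checkpointable(self, checkpointable, name):
    # Explicit, opt-in dependency registration.
    if not hasattr(self, "_dependencies"):
      object.__setattr__(self, "_dependencies", {})
    self._dependencies[name] = checkpointable
    return checkpointable


class Checkpointable(CheckpointableBase):
  """Top layer: tracks dependencies automatically via __setattr__.

  Sitting above Layer and the checkpointable data structures, it can wrap
  assigned containers before tracking them.
  """

  def __setattr__(self, name, value):
    if isinstance(value, CheckpointableBase):
      self._track_checkpointable(value, name=name)
    super(Checkpointable, self).__setattr__(name, value)


# Attribute assignment alone creates a checkpoint dependency:
root = Checkpointable()
root.child = Checkpointable()
assert "child" in root._dependencies

With tracking split into its own module, Layer only needs the base class, while the data-structure wrappers can depend on Layer without the cycle the commit message describes.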
Diffstat (limited to 'tensorflow/contrib/optimizer_v2')
-rw-r--r--  tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py  101
1 file changed, 39 insertions(+), 62 deletions(-)
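In the diff below, the test switches from the old checkpointable/checkpointable_utils aliases to the tracking and util modules. As a rough usage sketch (the checkpoint prefix and the minimal object graph are hypothetical; only the two imports and the add_variable/Checkpoint calls mirror what the test uses):

# Rough usage sketch of the renamed modules, assuming eager execution is
# enabled; the prefix and object graph here are hypothetical.
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util

root = tracking.Checkpointable()            # was checkpointable.Checkpointable()
root.var = util.add_variable(root, name="var", initializer=0.)
checkpoint = util.Checkpoint(root=root)     # was checkpointable_utils.Checkpoint(...)
save_path = checkpoint.save("/tmp/ckpt")    # hypothetical prefix
checkpoint.restore(save_path).assert_consumed()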
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index b6972a7a45..06ab58188a 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -43,15 +43,15 @@ from tensorflow.python.ops import template
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
-class NonLayerCheckpointable(checkpointable.Checkpointable):
+class NonLayerCheckpointable(tracking.Checkpointable):
def __init__(self):
super(NonLayerCheckpointable, self).__init__()
- self.a_variable = checkpointable_utils.add_variable(
+ self.a_variable = util.add_variable(
self, name="a_variable", shape=[])
@@ -88,29 +88,6 @@ class _MirroringSaveable(
self._mirrored_variable.assign(tensor))
-class _OwnsMirroredVariables(checkpointable.CheckpointableBase):
- """A Checkpointable object which returns a more complex SaveableObject."""
-
- def __init__(self):
- self.non_dep_variable = variable_scope.get_variable(
- name="non_dep_variable", initializer=6., use_resource=True)
- self.mirrored = variable_scope.get_variable(
- name="mirrored", initializer=15., use_resource=True)
-
- def _gather_saveables_for_checkpoint(self):
- def _saveable_factory(name=self.non_dep_variable.name):
- return _MirroringSaveable(
- primary_variable=self.non_dep_variable,
- mirrored_variable=self.mirrored,
- name=name)
- return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
-
- # The Saver sorts by name before parsing, so we need a name property.
- @property
- def name(self):
- return self.non_dep_variable.name
-
-
class CheckpointingTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
@@ -122,7 +99,7 @@ class CheckpointingTests(test.TestCase):
other_model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
if context.executing_eagerly():
optimizer.minimize(
@@ -137,11 +114,11 @@ class CheckpointingTests(test.TestCase):
optimizer.minimize(
other_model(input_value),
global_step=optimizer_step)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
named_variables, serialized_graph, _ = (
- checkpointable_utils._serialize_object_graph(
+ util._serialize_object_graph(
root_checkpointable, saveables_cache=None))
expected_checkpoint_names = (
# Created in the root node, so no prefix.
@@ -230,7 +207,7 @@ class CheckpointingTests(test.TestCase):
def testSaveRestore(self):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model)
input_value = constant_op.constant([[3.]])
if context.executing_eagerly():
@@ -240,7 +217,7 @@ class CheckpointingTests(test.TestCase):
train_op = optimizer.minimize(model(input_value))
# TODO(allenl): Make initialization more pleasant when graph building.
root_checkpointable.save_counter # pylint: disable=pointless-statement
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
prefix = os.path.join(self.get_temp_dir(), "ckpt")
@@ -266,7 +243,7 @@ class CheckpointingTests(test.TestCase):
# Preserve beta1_power and beta2_power when applying gradients so we can
# test that they've been restored correctly.
beta1=1.0, beta2=1.0)
- on_create_root = checkpointable_utils.Checkpoint(
+ on_create_root = util.Checkpoint(
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
@@ -298,7 +275,7 @@ class CheckpointingTests(test.TestCase):
for training_continuation in range(3):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
optimizer_step=training_util.get_or_create_global_step())
root.restore(core_saver.latest_checkpoint(checkpoint_directory))
@@ -322,7 +299,7 @@ class CheckpointingTests(test.TestCase):
with ops.Graph().as_default():
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
input_value = constant_op.constant([[3.]])
@@ -359,7 +336,7 @@ class CheckpointingTests(test.TestCase):
graph=ops.get_default_graph()), test_util.device(use_gpu=True):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -392,7 +369,7 @@ class CheckpointingTests(test.TestCase):
model = MyModel()
# Don't actually train so we can test variable values
optimizer = adam.AdamOptimizer(0.)
- root = checkpointable_utils.Checkpoint(
+ root = util.Checkpoint(
optimizer=optimizer, model=model,
global_step=training_util.get_or_create_global_step())
checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory)
@@ -442,7 +419,7 @@ class CheckpointingTests(test.TestCase):
optimizer = adam.AdamOptimizer(learning_rate=0.05)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- checkpoint = checkpointable_utils.Checkpoint(
+ checkpoint = util.Checkpoint(
model=model, optimizer=optimizer)
for _ in range(2):
checkpoint.save(checkpoint_prefix)
@@ -457,8 +434,8 @@ class CheckpointingTests(test.TestCase):
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = checkpointable.Checkpointable()
- root.var = checkpointable_utils.add_variable(
+ root = tracking.Checkpointable()
+ root.var = util.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
if context.executing_eagerly():
@@ -468,28 +445,28 @@ class CheckpointingTests(test.TestCase):
# Note that `optimizer` has not been added as a dependency of
# `root`. Create a one-off grouping so that slot variables for `root.var`
# get initialized too.
- self.evaluate(checkpointable_utils.gather_initializers(
- checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
+ self.evaluate(util.gather_initializers(
+ util.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
self.evaluate(state_ops.assign(root.var, 12.))
- no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+ no_slots_path = util.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "no_slots"))
root.optimizer = optimizer
self.evaluate(state_ops.assign(root.var, 13.))
self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
14.))
- slots_path = checkpointable_utils.CheckpointableSaver(root).save(
+ slots_path = util.CheckpointableSaver(root).save(
os.path.join(checkpoint_directory, "with_slots"))
- new_root = checkpointable.Checkpointable()
+ new_root = tracking.Checkpointable()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
- slot_status = checkpointable_utils.CheckpointableSaver(
+ slot_status = util.CheckpointableSaver(
new_root).restore(slots_path)
- no_slot_status = checkpointable_utils.CheckpointableSaver(
+ no_slot_status = util.CheckpointableSaver(
new_root).restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
- new_root.var = checkpointable_utils.add_variable(
+ new_root.var = util.add_variable(
new_root, name="var", shape=[])
no_slot_status.assert_consumed()
no_slot_status.run_restore_ops()
@@ -525,12 +502,12 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
saver.save(checkpoint_prefix)
before_ops = graph.get_operations()
saver.save(checkpoint_prefix)
@@ -543,12 +520,12 @@ class CheckpointingTests(test.TestCase):
with graph.as_default(), self.test_session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = checkpointable.Checkpointable()
+ obj = tracking.Checkpointable()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
- self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
save_path = saver.save(checkpoint_prefix)
saver.restore(save_path)
before_ops = graph.get_operations()
@@ -565,10 +542,10 @@ class CheckpointingTests(test.TestCase):
first_session = session_lib.Session(graph=first_graph)
with first_graph.as_default(), first_session.as_default():
first_variable = resource_variable_ops.ResourceVariable([1.])
- first_root_checkpointable = checkpointable_utils.Checkpoint(
+ first_root_checkpointable = util.Checkpoint(
optimizer=optimizer, variable=first_variable)
train_op = optimizer.minimize(first_variable.read_value)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
first_root_checkpointable))
self.evaluate(train_op)
self.evaluate(first_variable.assign([1.]))
@@ -581,7 +558,7 @@ class CheckpointingTests(test.TestCase):
second_graph = ops.Graph()
with second_graph.as_default(), session_lib.Session(graph=second_graph):
second_variable = resource_variable_ops.ResourceVariable([1.])
- second_root_checkpointable = checkpointable_utils.Checkpoint(
+ second_root_checkpointable = util.Checkpoint(
optimizer=optimizer, variable=second_variable)
train_op = optimizer.minimize(second_variable.read_value)
second_root_checkpointable.restore(None).initialize_or_restore()
@@ -631,7 +608,7 @@ class TemplateTests(test.TestCase):
save_template = template.make_template("s1", _templated)
v1_save, _, v2_save = save_template()
optimizer = adam.AdamOptimizer(0.0)
- save_root = checkpointable_utils.Checkpoint(
+ save_root = util.Checkpoint(
my_template=save_template, optimizer=optimizer)
optimizer.minimize(v1_save.read_value)
self.evaluate([v.initializer for v in optimizer.variables()])
@@ -643,7 +620,7 @@ class TemplateTests(test.TestCase):
load_template = template.make_template("s2", _templated)
load_optimizer = adam.AdamOptimizer(0.0)
- load_root = checkpointable_utils.Checkpoint(
+ load_root = util.Checkpoint(
my_template=load_template, optimizer=load_optimizer)
status = load_root.restore(save_path)
var, var_plus_one, var2 = load_template()
@@ -664,12 +641,12 @@ class CheckpointCompatibilityTests(test.TestCase):
model = MyModel()
optimizer = adam.AdamOptimizer(0.001)
optimizer_step = training_util.get_or_create_global_step()
- root_checkpointable = checkpointable_utils.Checkpoint(
+ root_checkpointable = util.Checkpoint(
optimizer=optimizer, model=model, optimizer_step=optimizer_step)
train_op = optimizer.minimize(
functools.partial(model, input_value),
global_step=optimizer_step)
- self.evaluate(checkpointable_utils.gather_initializers(
+ self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
# A regular variable, a slot variable, and a non-slot Optimizer variable
@@ -721,7 +698,7 @@ class CheckpointCompatibilityTests(test.TestCase):
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = util.CheckpointableSaver(root)
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():