aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/checkpoint
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-06-22 11:13:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-22 11:16:40 -0700
commit51695446ec779bd069c5623bf696e4223ff2af48 (patch)
treebe13cb094132252b1f0712ed651f76954b414b1e /tensorflow/contrib/checkpoint
parent49934f7d1a289325480c618c658b7c4cdb8584c6 (diff)
Split dependency tracking out from CheckpointableBase
Some unit test fiddling, but otherwise just moving code around. My goal is to be able to use checkpointable data structures (or something like them) in Checkpointable's __setattr__ override. Checkpointable data structures depend on Layer, so Checkpointable and CheckpointableBase need to be in seprate files (so we can have the dependency chain CheckpointableBase->Layer->CheckpointableDataStructure->Checkpointable). This will also require changes to python/keras/engine/__init__.py (which currently requires Network and Layer be imported together), but I'll do that in a separate change. PiperOrigin-RevId: 201712549
Diffstat (limited to 'tensorflow/contrib/checkpoint')
-rw-r--r--tensorflow/contrib/checkpoint/__init__.py4
-rw-r--r--tensorflow/contrib/checkpoint/python/containers_test.py16
-rw-r--r--tensorflow/contrib/checkpoint/python/split_dependency_test.py19
3 files changed, 20 insertions, 19 deletions
diff --git a/tensorflow/contrib/checkpoint/__init__.py b/tensorflow/contrib/checkpoint/__init__.py
index 38856417c0..8c1ce5c2a2 100644
--- a/tensorflow/contrib/checkpoint/__init__.py
+++ b/tensorflow/contrib/checkpoint/__init__.py
@@ -41,11 +41,11 @@ from tensorflow.contrib.checkpoint.python.containers import UniqueNameTracker
from tensorflow.contrib.checkpoint.python.split_dependency import split_dependency
from tensorflow.contrib.checkpoint.python.visualize import dot_graph_from_checkpoint
from tensorflow.core.protobuf.checkpointable_object_graph_pb2 import CheckpointableObjectGraph
-from tensorflow.python.training.checkpointable.base import Checkpointable
from tensorflow.python.training.checkpointable.base import CheckpointableBase
-from tensorflow.python.training.checkpointable.base import NoDependency
from tensorflow.python.training.checkpointable.data_structures import List
from tensorflow.python.training.checkpointable.data_structures import Mapping
+from tensorflow.python.training.checkpointable.tracking import Checkpointable
+from tensorflow.python.training.checkpointable.tracking import NoDependency
from tensorflow.python.training.checkpointable.util import capture_dependencies
from tensorflow.python.training.checkpointable.util import list_objects
from tensorflow.python.training.checkpointable.util import object_metadata
diff --git a/tensorflow/contrib/checkpoint/python/containers_test.py b/tensorflow/contrib/checkpoint/python/containers_test.py
index 12b99d3e22..64d056bd68 100644
--- a/tensorflow/contrib/checkpoint/python/containers_test.py
+++ b/tensorflow/contrib/checkpoint/python/containers_test.py
@@ -26,8 +26,8 @@ from tensorflow.python.keras import layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
class UniqueNameTrackerTests(test.TestCase):
@@ -48,11 +48,11 @@ class UniqueNameTrackerTests(test.TestCase):
slots.track(y, "y")
self.evaluate((x1.initializer, x2.initializer, x3.initializer,
y.initializer))
- save_root = checkpointable_utils.Checkpoint(slots=slots)
+ save_root = util.Checkpoint(slots=slots)
save_path = save_root.save(checkpoint_prefix)
- restore_slots = checkpointable.Checkpointable()
- restore_root = checkpointable_utils.Checkpoint(
+ restore_slots = tracking.Checkpointable()
+ restore_root = util.Checkpoint(
slots=restore_slots)
status = restore_root.restore(save_path)
restore_slots.x = resource_variable_ops.ResourceVariable(0.)
@@ -67,7 +67,7 @@ class UniqueNameTrackerTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testExample(self):
- class SlotManager(checkpointable.Checkpointable):
+ class SlotManager(tracking.Checkpointable):
def __init__(self):
self.slotdeps = containers.UniqueNameTracker()
@@ -83,11 +83,11 @@ class UniqueNameTrackerTests(test.TestCase):
manager = SlotManager()
self.evaluate([v.initializer for v in manager.slots])
- checkpoint = checkpointable_utils.Checkpoint(slot_manager=manager)
+ checkpoint = util.Checkpoint(slot_manager=manager)
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
save_path = checkpoint.save(checkpoint_prefix)
- metadata = checkpointable_utils.object_metadata(save_path)
+ metadata = util.object_metadata(save_path)
dependency_names = []
for node in metadata.nodes:
for child in node.children:
diff --git a/tensorflow/contrib/checkpoint/python/split_dependency_test.py b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
index 43c5d6515b..00a805af25 100644
--- a/tensorflow/contrib/checkpoint/python/split_dependency_test.py
+++ b/tensorflow/contrib/checkpoint/python/split_dependency_test.py
@@ -23,8 +23,9 @@ from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.training.checkpointable import base as checkpointable
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
def _split_variable_closure(variable):
@@ -43,7 +44,7 @@ def _combine_variable_closure(variable):
return _consume_restore_buffer_fn
-class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
+class SaveTensorSlicesAsDeps(base.CheckpointableBase):
def __init__(self):
self.combined = resource_variable_ops.ResourceVariable([0., 0., 0., 0.])
@@ -58,14 +59,14 @@ class SaveTensorSlicesAsDeps(checkpointable.CheckpointableBase):
self._track_checkpointable(dep, name=name)
-class HasRegularDeps(checkpointable.Checkpointable):
+class HasRegularDeps(tracking.Checkpointable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
self.second_half = resource_variable_ops.ResourceVariable([0., 0.])
-class OnlyOneDep(checkpointable.Checkpointable):
+class OnlyOneDep(tracking.Checkpointable):
def __init__(self):
self.first_half = resource_variable_ops.ResourceVariable([0., 0.])
@@ -75,7 +76,7 @@ class SplitTests(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def testSaveRestoreSplitDep(self):
- save_checkpoint = checkpointable_utils.Checkpoint(
+ save_checkpoint = util.Checkpoint(
dep=SaveTensorSlicesAsDeps())
self.evaluate(save_checkpoint.dep.combined.assign([1., 2., 3., 4.]))
checkpoint_directory = self.get_temp_dir()
@@ -83,7 +84,7 @@ class SplitTests(test.TestCase):
save_path = save_checkpoint.save(checkpoint_prefix)
regular_deps = HasRegularDeps()
- regular_restore_checkpoint = checkpointable_utils.Checkpoint(
+ regular_restore_checkpoint = util.Checkpoint(
dep=regular_deps)
regular_restore_checkpoint.restore(
save_path).assert_consumed().run_restore_ops()
@@ -91,7 +92,7 @@ class SplitTests(test.TestCase):
self.assertAllEqual([3., 4.], self.evaluate(regular_deps.second_half))
one_dep = OnlyOneDep()
- one_dep_restore_checkpoint = checkpointable_utils.Checkpoint(dep=one_dep)
+ one_dep_restore_checkpoint = util.Checkpoint(dep=one_dep)
status = one_dep_restore_checkpoint.restore(save_path)
with self.assertRaises(AssertionError):
# Missing the second dependency.
@@ -99,7 +100,7 @@ class SplitTests(test.TestCase):
status.run_restore_ops()
self.assertAllEqual([1., 2.], self.evaluate(one_dep.first_half))
- restore_checkpoint = checkpointable_utils.Checkpoint()
+ restore_checkpoint = util.Checkpoint()
status = restore_checkpoint.restore(save_path)
restore_checkpoint.dep = SaveTensorSlicesAsDeps()
status.assert_consumed().run_restore_ops()