aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variables.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-02-15 13:43:47 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-15 13:48:35 -0800
commit8745e3426713068e7061b3aae368ebb4db8dc2cc (patch)
treef5d19eaa1f926da4f5872aa9661b2cd3840a107d /tensorflow/python/ops/variables.py
parent82d67d0af2ed13bdf003e69486f3f477961ef407 (diff)
Object-based saving: Switch to "everything is Checkpointable"
The only sane way to use/test this is to have Variables be Checkpointable, so this CL includes a move of the base class to core. No public methods are exposed, and I've attempted to not throw any errors on __setattr__. Allows dynamic dependencies (track after restore) and restoring variables on assignment to a Checkpointable object, and includes the protocol buffer modifications necessary for saving information with each object. There are still some prominent TODOs: - Stop modifying the graph after the first save/restore (likely cache ops in Checkpointable objects) - Add some overridable methods for saving Python strings when restore() is called, fed when graph building rather than embedded as constants in the graph - Work on the initialization story for graph building. Currently the unit tests rely on collections for this. - Support for more objects, move the prototype modifications in checkpointable_test to core. The diff is larger than I was hoping (mostly deletions and unit tests); that could be reduced a bit (or at least "lines added" converted to "lines deleted") by diffbasing on cl/180950921, which was my first attempt at dynamic dependencies. This CL is more of a re-write than a modification, so sending that one out seems a bit silly. The unit tests are still good, though. PiperOrigin-RevId: 185893387
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r--tensorflow/python/ops/variables.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 19e3298e40..125922e296 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
@@ -36,7 +37,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export("Variable")
-class Variable(object):
+class Variable(checkpointable.Checkpointable):
"""See the @{$variables$Variables How To} for a high level overview.
A variable maintains state in the graph across calls to `run()`. You add a
@@ -306,6 +307,11 @@ class Variable(object):
if constraint is not None and not callable(constraint):
raise ValueError("The `constraint` argument must be a callable.")
+ if isinstance(initial_value, checkpointable.CheckpointInitialValue):
+ self._maybe_initialize_checkpointable()
+ self._update_uid = initial_value.checkpoint_position.restore_uid
+ initial_value = initial_value.wrapped_value
+
if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections:
collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES]
with ops.init_scope():
@@ -786,6 +792,20 @@ class Variable(object):
setattr(Variable, operator, _run_op)
+ def _scatter_tensors_from_checkpoint(self, attributes):
+ """For implementing `Checkpointable`. Return an assignment op to run."""
+ if (len(attributes) != 1
+ or checkpointable.VARIABLE_VALUE_KEY not in attributes):
+ raise ValueError(
+ ("The variable %s was restored with unexpected values (expected one "
+ "with key %s, got %s)") % (
+ self, checkpointable.VARIABLE_VALUE_KEY, attributes))
+ return self.assign(attributes[checkpointable.VARIABLE_VALUE_KEY])
+
+ def _gather_tensors_for_checkpoint(self):
+ """For implementing `Checkpointable`. This object is saveable on its own."""
+ return {checkpointable.VARIABLE_VALUE_KEY: self}
+
def _try_guard_against_uninitialized_dependencies(self, initial_value):
"""Attempt to guard against dependencies on uninitialized variables.