diff options
author | Allen Lavoie <allenl@google.com> | 2018-02-15 13:43:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-15 13:48:35 -0800 |
commit | 8745e3426713068e7061b3aae368ebb4db8dc2cc (patch) | |
tree | f5d19eaa1f926da4f5872aa9661b2cd3840a107d /tensorflow/python/ops/variables.py | |
parent | 82d67d0af2ed13bdf003e69486f3f477961ef407 (diff) |
Object-based saving: Switch to "everything is Checkpointable"
The only sane way to use/test this is to have Variables be Checkpointable, so this CL includes a move of the base class to core. No public methods are exposed, and I've attempted to not throw any errors on __setattr__.
Allows dynamic dependencies (track after restore) and restoring variables on assignment to a Checkpointable object, and includes the protocol buffer modifications necessary for saving information with each object.
There are still some prominent TODOs:
- Stop modifying the graph after the first save/restore (likely cache ops in Checkpointable objects)
- Add some overridable methods for saving Python strings when restore() is called, fed when graph building rather than embedded as constants in the graph
- Work on the initialization story for graph building. Currently the unit tests rely on collections for this.
- Support for more objects, move the prototype modifications in checkpointable_test to core.
The diff is larger than I was hoping (mostly deletions and unit tests); that could be reduced a bit (or at least "lines added" converted to "lines deleted") by diffbasing on cl/180950921, which was my first attempt at dynamic dependencies. This CL is more of a re-write than a modification, so sending that one out seems a bit silly. The unit tests are still good, though.
PiperOrigin-RevId: 185893387
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r-- | tensorflow/python/ops/variables.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 19e3298e40..125922e296 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import compat from tensorflow.python.util import tf_should_use from tensorflow.python.util.deprecation import deprecated @@ -36,7 +37,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export("Variable") -class Variable(object): +class Variable(checkpointable.Checkpointable): """See the @{$variables$Variables How To} for a high level overview. A variable maintains state in the graph across calls to `run()`. You add a @@ -306,6 +307,11 @@ class Variable(object): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): @@ -786,6 +792,20 @@ class Variable(object): setattr(Variable, operator, _run_op) + def _scatter_tensors_from_checkpoint(self, attributes): + """For implementing `Checkpointable`. Return an assignment op to run.""" + if (len(attributes) != 1 + or checkpointable.VARIABLE_VALUE_KEY not in attributes): + raise ValueError( + ("The variable %s was restored with unexpected values (expected one " + "with key %s, got %s)") % ( + self, checkpointable.VARIABLE_VALUE_KEY, attributes)) + return self.assign(attributes[checkpointable.VARIABLE_VALUE_KEY]) + + def _gather_tensors_for_checkpoint(self): + """For implementing `Checkpointable`. This object is saveable on its own.""" + return {checkpointable.VARIABLE_VALUE_KEY: self} + def _try_guard_against_uninitialized_dependencies(self, initial_value): """Attempt to guard against dependencies on uninitialized variables. |