diff options
author | 2018-02-27 12:34:17 -0800 | |
---|---|---|
committer | 2018-02-27 12:38:11 -0800 | |
commit | 78376e4077f4e9d293811bdbc453c6d1b93db453 (patch) | |
tree | 0d1ffec6fd741833b0169c69d6080816c8700c56 /tensorflow/python/layers | |
parent | 24a1c89187e49847fbd3575d626f1e374ce9ed18 (diff) |
Make Layers Checkpointable
(This change is mostly API goldens by volume)
Layers will inherit from CheckpointableBase since they do variable management
themselves. A __setattr__ override would also likely slow down functional layers
significantly.
I believe the plan for Model is to piggyback on its existing __setattr__
override rather than having Model inherit from CheckpointableBase through Layer
and Checkpointable itself.
PiperOrigin-RevId: 187215512
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 21 |
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 8314c4aa87..2ec9971b88 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -36,12 +36,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.ops import variables as tf_variables from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @tf_export('layers.Layer') -class Layer(object): +class Layer(checkpointable.CheckpointableBase): """Base layer class. This is the class from which all layers inherit, implementing common @@ -532,13 +533,17 @@ class Layer(object): with vs.variable_scope( self._scope, reuse=reuse, auxiliary_name_scope=False) as scope: with ops.name_scope(self._name_scope_name(scope)): - variable = vs.get_variable(name, - shape=shape, - initializer=initializer, - dtype=dtypes.as_dtype(dtype), - constraint=constraint, - trainable=trainable and self.trainable, - partitioner=partitioner) + variable = self._add_variable_with_custom_getter( + name=name, + shape=shape, + getter=vs.get_variable, + # Manage errors in Layer rather than Checkpointable. + overwrite=True, + initializer=initializer, + dtype=dtypes.as_dtype(dtype), + constraint=constraint, + trainable=trainable and self.trainable, + partitioner=partitioner) if init_graph is not None: # pylint: disable=protected-access # The variable was created and initialized in a graph. |