author Allen Lavoie <allenl@google.com> 2018-02-27 12:34:17 -0800
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-02-27 12:38:11 -0800
commit 78376e4077f4e9d293811bdbc453c6d1b93db453 (patch)
tree 0d1ffec6fd741833b0169c69d6080816c8700c56
parent 24a1c89187e49847fbd3575d626f1e374ce9ed18 (diff)
Make Layers Checkpointable
(This change is mostly API goldens by volume.)

Layers will inherit from CheckpointableBase, since they do variable management themselves. A __setattr__ override would also likely slow down functional layers significantly.

I believe the plan for Model is to piggyback on its existing __setattr__ override rather than having Model inherit from CheckpointableBase (through Layer) and from Checkpointable itself.

PiperOrigin-RevId: 187215512
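The commit message describes the pattern in prose; the plain-Python sketch below (with hypothetical stand-in names, not the real TensorFlow classes) illustrates the idea behind it, assuming the base class does little more than record each variable as a checkpoint dependency while delegating actual creation to a caller-supplied getter. In the diff below, that getter is vs.get_variable and the base class is checkpointable.CheckpointableBase.

class CheckpointableBaseSketch(object):
  """Hypothetical stand-in for checkpointable.CheckpointableBase."""

  def __init__(self):
    self._checkpoint_dependencies = {}

  def _add_variable_with_custom_getter(self, name, getter, overwrite=False,
                                       **kwargs):
    if not overwrite and name in self._checkpoint_dependencies:
      raise ValueError("A variable named %r already exists." % name)
    # Delegate creation to the caller-supplied getter, then track the
    # result by name so save/restore logic can later find it.
    variable = getter(name, **kwargs)
    self._checkpoint_dependencies[name] = variable
    return variable


class LayerSketch(CheckpointableBaseSketch):
  """Mirrors how Layer.add_variable routes through the base class."""

  def add_variable(self, name, shape, initializer):
    def fake_get_variable(name, shape, initializer):
      # Stand-in for vs.get_variable: build an initialized value.
      return [initializer] * shape[0]

    # overwrite=True mirrors the diff: Layer handles duplicate-name
    # errors itself rather than deferring to the base class.
    return self._add_variable_with_custom_getter(
        name=name, getter=fake_get_variable, overwrite=True,
        shape=shape, initializer=initializer)


layer = LayerSketch()
kernel = layer.add_variable("kernel", shape=(3,), initializer=0.0)
assert layer._checkpoint_dependencies["kernel"] is kernel

Because the dependency tracking lives in the base class and creation stays in the subclass, Layer gets checkpointability without a __setattr__ override, which is the performance concern the commit message raises.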
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- tensorflow/python/layers/base.py | 21
1 file changed, 13 insertions(+), 8 deletions(-)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 8314c4aa87..2ec9971b88 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -36,12 +36,13 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export('layers.Layer')
-class Layer(object):
+class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
This is the class from which all layers inherit, implementing common
@@ -532,13 +533,17 @@ class Layer(object):
with vs.variable_scope(
self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
with ops.name_scope(self._name_scope_name(scope)):
- variable = vs.get_variable(name,
- shape=shape,
- initializer=initializer,
- dtype=dtypes.as_dtype(dtype),
- constraint=constraint,
- trainable=trainable and self.trainable,
- partitioner=partitioner)
+ variable = self._add_variable_with_custom_getter(
+ name=name,
+ shape=shape,
+ getter=vs.get_variable,
+ # Manage errors in Layer rather than Checkpointable.
+ overwrite=True,
+ initializer=initializer,
+ dtype=dtypes.as_dtype(dtype),
+ constraint=constraint,
+ trainable=trainable and self.trainable,
+ partitioner=partitioner)
if init_graph is not None: # pylint: disable=protected-access
# The variable was created and initialized in a graph.