aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r--tensorflow/python/layers/base.py21
1 files changed, 13 insertions, 8 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 8314c4aa87..2ec9971b88 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -36,12 +36,13 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import checkpointable
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export('layers.Layer')
-class Layer(object):
+class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
This is the class from which all layers inherit, implementing common
@@ -532,13 +533,17 @@ class Layer(object):
with vs.variable_scope(
self._scope, reuse=reuse, auxiliary_name_scope=False) as scope:
with ops.name_scope(self._name_scope_name(scope)):
- variable = vs.get_variable(name,
- shape=shape,
- initializer=initializer,
- dtype=dtypes.as_dtype(dtype),
- constraint=constraint,
- trainable=trainable and self.trainable,
- partitioner=partitioner)
+ variable = self._add_variable_with_custom_getter(
+ name=name,
+ shape=shape,
+ getter=vs.get_variable,
+ # Manage errors in Layer rather than Checkpointable.
+ overwrite=True,
+ initializer=initializer,
+ dtype=dtypes.as_dtype(dtype),
+ constraint=constraint,
+ trainable=trainable and self.trainable,
+ partitioner=partitioner)
if init_graph is not None: # pylint: disable=protected-access
# The variable was created and initialized in a graph.