path: root/tensorflow/python/keras/engine/network.py
Diffstat (limited to 'tensorflow/python/keras/engine/network.py')
-rw-r--r--  tensorflow/python/keras/engine/network.py | 60
1 file changed, 47 insertions(+), 13 deletions(-)
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index aa84eaa8ab..752e9963ca 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -81,6 +81,20 @@ class Network(base_layer.Layer):
# Subclassed network
self._init_subclassed_network(**kwargs)
+ # Several Network methods have "no_automatic_dependency_tracking"
+ # annotations. Since Network does automatic dependency tracking on attribute
+ # assignment, including for common data structures such as lists, by default
+ # we'd have quite a few empty dependencies which users don't care about (or
+ # would need some way to ignore dependencies automatically, which is confusing
+ # when applied to user code). Some attributes, such as _layers, would cause
+ # structural issues (_layers being the place where Layers assigned to tracked
+ # attributes are stored).
+ #
+ # Aside from these aesthetic and structural issues, useless dependencies on
+ # empty lists shouldn't cause issues; adding or removing them will not break
+ # checkpoints, but may cause "all Python objects matched" assertions to fail
+ # (in which case less strict assertions may be substituted if necessary).
+ @checkpointable.no_automatic_dependency_tracking
def _base_init(self, name=None):
# The following are implemented as property functions:
# self.trainable_weights
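The decorator added above relies on the same `_setattr_tracking` flag that the rewritten `__setattr__` checks further down (hunk at -362/+378): while a decorated method runs, attribute assignments bypass dependency tracking. Below is a minimal, TensorFlow-free sketch of that pattern, with a toy `TrackingExample` class standing in for `Network`; everything except the `_setattr_tracking` name is hypothetical and for illustration only.

import functools


def no_automatic_dependency_tracking(method):
  @functools.wraps(method)
  def wrapper(self, *args, **kwargs):
    previous = getattr(self, '_setattr_tracking', True)
    object.__setattr__(self, '_setattr_tracking', False)
    try:
      return method(self, *args, **kwargs)
    finally:
      object.__setattr__(self, '_setattr_tracking', previous)
  return wrapper


class TrackingExample(object):

  def __init__(self):
    object.__setattr__(self, '_tracked', {})
    self._setup()

  @no_automatic_dependency_tracking
  def _setup(self):
    # Assignments here skip tracking, so empty containers do not become
    # spurious dependencies.
    self.inputs = []
    self.outputs = []

  def __setattr__(self, name, value):
    # Mirrors the check added to Network.__setattr__ in this change.
    if not getattr(self, '_setattr_tracking', True):
      object.__setattr__(self, name, value)
      return
    self._tracked[name] = value  # naive stand-in for dependency tracking
    object.__setattr__(self, name, value)


example = TrackingExample()
example.extra = []                      # assigned with tracking enabled
assert 'inputs' not in example._tracked
assert 'extra' in example._tracked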
@@ -135,6 +149,7 @@ class Network(base_layer.Layer):
# restore operations when graph building.
self._in_progress_restore_finalizer = None
+ @checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
self._call_convention = base_layer.CallConvention.EXPLICIT_INPUTS_ARGUMENT
# Normalize and set self.inputs, self.outputs.
@@ -293,6 +308,7 @@ class Network(base_layer.Layer):
for layer in self._output_layers:
self.output_names.append(layer.name)
+ @checkpointable.no_automatic_dependency_tracking
def _init_subclassed_network(self, name=None):
self._base_init(name=name)
self._is_graph_network = False
@@ -302,8 +318,8 @@ class Network(base_layer.Layer):
else:
self._expects_training_arg = False
self._call_convention = self._determine_call_convention(call_argspec)
- self.outputs = None
- self.inputs = None
+ self.outputs = []
+ self.inputs = []
self.built = False
def _determine_call_convention(self, call_argspec):
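Initializing `outputs` and `inputs` to empty lists rather than `None` lets downstream code treat them uniformly as sequences. A trivial illustration with a hypothetical helper (not from the codebase):

def describe_outputs(model):
  # With the empty-list default there is no None check needed before
  # len() or iteration; truthiness still distinguishes "no outputs yet".
  if not model.outputs:
    return 'no symbolic outputs yet'
  return '%d symbolic output(s)' % len(model.outputs)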
@@ -362,10 +378,31 @@ class Network(base_layer.Layer):
self._track_checkpointable(
layer, name='layer-%d' % layer_index, overwrite=True)
+ def _no_dependency(self, value):
+ """Override to allow `Layer` to disable dependency tracking.
+
+ `CheckpointableBase` defines this method, whose semantics are "if a subclass
+ does dependency tracking, this method exempts `value`." Layer uses
+ `_no_dependency` to exempt some of its attribute assignments (conditional on
+ attribute assignment causing tracking in the subclass).
+
+ Args:
+ value: An object which will be assigned to an object attribute, whose
+ value should not be tracked.
+
+ Returns:
+ A wrapped object which, when assigned to an attribute, will not be
+ tracked (`value` will be stored in the attribute).
+ """
+ return data_structures.NoDependency(value)
+
def __setattr__(self, name, value):
- no_dependency = isinstance(value, checkpointable.NoDependency)
- if no_dependency:
- value = value.value
+ if not getattr(self, '_setattr_tracking', True):
+ super(Network, self).__setattr__(name, value)
+ return
+ no_dependency = isinstance(value, data_structures.NoDependency)
+ value = data_structures.sticky_attribute_assignment(
+        checkpointable=self, value=value, name=name)
if isinstance(value, (
base_layer.Layer,
Network,
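The rewritten `__setattr__` hands the value to `data_structures.sticky_attribute_assignment`, which, per the surrounding code, unwraps `NoDependency` markers and takes over the tracking previously done by the explicit `_track_checkpointable` call (removed in the next hunk). A rough, TensorFlow-free sketch of that contract, using hypothetical stand-ins rather than the real helpers:

class NoDependency(object):
  """Marker wrapper: store this value, but do not create a dependency."""

  def __init__(self, value):
    self.value = value


class Trackable(object):
  """Stand-in for CheckpointableBase: anything worth checkpointing."""


def sticky_attribute_assignment(checkpointable, name, value):
  """Unwrap NoDependency markers and record a dependency for trackable values."""
  if isinstance(value, NoDependency):
    return value.value                        # raw value stored, nothing recorded
  if isinstance(value, Trackable):
    checkpointable._dependencies[name] = value
  return value


class Owner(object):

  def __init__(self):
    object.__setattr__(self, '_dependencies', {})

  def _no_dependency(self, value):
    # Mirrors Network._no_dependency in the hunk above.
    return NoDependency(value)

  def __setattr__(self, name, value):
    value = sticky_attribute_assignment(self, name=name, value=value)
    object.__setattr__(self, name, value)


owner = Owner()
owner.layer = Trackable()                        # recorded as a dependency
owner.cache = owner._no_dependency(Trackable())  # stored, but not recorded
assert 'layer' in owner._dependencies
assert 'cache' not in owner._dependencies
assert isinstance(owner.cache, Trackable)        # the raw value was stored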
@@ -377,7 +414,9 @@ class Network(base_layer.Layer):
'forgot to call `super(YourClass, self).__init__()`.'
' Always start with this line.')
if not is_graph_network:
- if value not in self._layers:
+ # We need to check object identity to avoid de-duplicating empty
+ # container types which compare equal.
+ if not any((layer is value for layer in self._layers)):
self._layers.append(value)
if hasattr(value, '_use_resource_variables'):
# In subclassed models, legacy layers (tf.layers) must always use
@@ -385,12 +424,6 @@ class Network(base_layer.Layer):
value._use_resource_variables = True
if (not no_dependency
and isinstance(value, checkpointable.CheckpointableBase)):
- # Layer (and therefore Network/Model) inherit from CheckpointableBase
- # rather than Checkpointable, which means there is no Checkpointable
- # __setattr__ override (it would be a performance issue for functional
- # layers). Therefore Model tracks Checkpointable objects itself.
- self._track_checkpointable(
- checkpointable=value, name=name, overwrite=True)
if ( # For subclassed models only, users may add extra weights/variables
# simply by assigning them to attributes.
not self._is_graph_network
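The identity check in the hunk above exists because the values being tracked can be empty container wrappers, and empty containers compare equal. Plain Python lists are enough to show the pitfall that `value not in self._layers` would hit:

a, b = [], []
assert a == b                   # equal...
assert a is not b               # ...but distinct objects
layers = [a]
assert b in layers              # `in` uses ==, so b looks "already present"
assert not any(layer is b for layer in layers)  # identity check keeps both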
@@ -493,7 +526,8 @@ class Network(base_layer.Layer):
@property
def layers(self):
- return self._layers
+ return checkpointable_layer_utils.filter_empty_layer_containers(
+ self._layers)
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.