Diffstat (limited to 'tensorflow/python/keras/engine/base_layer.py')
 tensorflow/python/keras/engine/base_layer.py | 111 ++++++++++++++++++--------
 1 file changed, 82 insertions(+), 29 deletions(-)
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 4814275fd5..b41f6ee03b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase):
constraints on inputs that can be accepted by the layer.
"""
+ @checkpointable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase):
@activity_regularizer.setter
def activity_regularizer(self, regularizer):
"""Optional regularizer function for the output of this layer."""
- self._activity_regularizer = regularizer
+ self._activity_regularizer = self._no_dependency(regularizer)
@property
def trainable_weights(self):
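
For background, both `no_automatic_dependency_tracking` and `_no_dependency` opt an attribute assignment out of object-based checkpoint tracking. A rough, hedged illustration of the machinery being bypassed (run against a recent TensorFlow; the wrapper type is an implementation detail and may differ between versions):

    import tensorflow as tf

    layer = tf.keras.layers.Dense(4)
    # Attribute assignments on a layer are normally inspected for checkpointable
    # state, so plain containers holding variables get wrapped and tracked.
    layer.extra_state = [tf.Variable(1.0)]
    print(type(layer.extra_state).__name__)  # typically a tracked list wrapper
    # The decorator on __init__ suppresses exactly this tracking for assignments
    # made inside Layer.__init__, and _no_dependency does the same for one value.
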
@@ -459,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
- def add_weight(self, name, shape,
+ def add_weight(self,
+ name,
+ shape,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
constraint=None,
partitioner=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
getter=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -481,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
getter: Variable getter argument to be passed to the `Checkpointable` API.
Returns:
@@ -495,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
- ValueError: When giving unsupported dtype and no initializer.
+ ValueError: When giving unsupported dtype and no initializer or when
+ trainable has been set to True with synchronization set to `ON_READ`.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
@@ -504,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
# Initialize variable when no initializer provided
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
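
Taken on its own, the trainable/synchronization resolution added above behaves like this minimal standalone sketch (`resolve_trainable` is an illustrative name, not a function in this file; assumes a TensorFlow version that exposes `tf.VariableSynchronization`):

    import tensorflow as tf

    def resolve_trainable(trainable, synchronization):
        # ON_READ variables (e.g. replica-local accumulators) can never be trainable.
        if synchronization == tf.VariableSynchronization.ON_READ:
            if trainable:
                raise ValueError(
                    'Synchronization value can be set to '
                    'VariableSynchronization.ON_READ only for non-trainable variables.')
            return False
        # For any other synchronization mode, an unspecified trainable means True.
        return True if trainable is None else trainable

    assert resolve_trainable(None, tf.VariableSynchronization.AUTO) is True
    assert resolve_trainable(None, tf.VariableSynchronization.ON_READ) is False
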
@@ -531,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase):
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
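
For context, a hedged usage sketch of the new keywords as seen from a user-defined layer. The `SumTracker` layer and its variable are made up for illustration; only the `add_weight` keyword arguments mirror this diff:

    import tensorflow as tf

    class SumTracker(tf.keras.layers.Layer):
        """Hypothetical layer keeping a per-replica running total of its inputs."""

        def build(self, input_shape):
            # ON_READ: each replica updates its own copy; reads aggregate with SUM.
            # trainable is left unset and therefore resolves to False, as required
            # for ON_READ variables.
            self.total = self.add_weight(
                name='total',
                shape=(),
                initializer=tf.zeros_initializer(),
                synchronization=tf.VariableSynchronization.ON_READ,
                aggregation=tf.VariableAggregation.SUM)
            super(SumTracker, self).build(input_shape)

        def call(self, inputs):
            self.total.assign_add(tf.reduce_sum(inputs))
            return inputs
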
@@ -654,11 +685,12 @@ class Layer(checkpointable.CheckpointableBase):
# Handle Keras mask propagation from previous layer to current layer.
previous_mask = None
- if (not hasattr(self, '_compute_previous_mask') or
- self._compute_previous_mask):
+ if build_graph and (not hasattr(self, '_compute_previous_mask') or
+ self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
- self._call_fn_args = function_utils.fn_args(self.call)
+ self._call_fn_args = self._no_dependency(
+ function_utils.fn_args(self.call))
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
# The previous layer generated a mask, and mask was not explicitly pass
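
`function_utils.fn_args` simply yields the argument names of `self.call`, which is what the `'mask' in self._call_fn_args` check above relies on; roughly equivalent to the following hedged sketch using the standard library instead of the internal helper:

    import inspect

    def fn_args_sketch(fn):
        # Rough stand-in for function_utils.fn_args: the callable's argument names.
        return tuple(inspect.signature(fn).parameters)

    class MaskAwareLayer(object):
        def call(self, inputs, mask=None):
            return inputs

    assert 'mask' in fn_args_sketch(MaskAwareLayer.call)
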
@@ -691,9 +723,18 @@ class Layer(checkpointable.CheckpointableBase):
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
- if all(hasattr(x, 'get_shape') for x in input_list):
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
- self.build(input_shapes)
+
+ if all(hasattr(x, 'shape') for x in input_list):
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
+
+ if (not hasattr(self, '_is_graph_network') or
+ self.__class__.__name__ == 'Sequential'):
+ # Only if self is a layer or an instance of a sequential model do we
+ # need to build it.
+ self.build(input_shapes)
+ # We must set self.built since user defined build functions are not
+ # constrained to set self.built.
+ self.built = True
# Check input assumptions set after layer building, e.g. input shape.
if build_graph or in_deferred_mode:
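
The effect of marking `self.built` here: a user-defined `build` that forgets to set the flag is still treated as built after the first call. A small hedged sketch (the `Scale` layer is hypothetical):

    import tensorflow as tf

    class Scale(tf.keras.layers.Layer):
        def build(self, input_shape):
            self.alpha = self.add_weight(
                name='alpha', shape=(), initializer=tf.ones_initializer())
            # Deliberately no `self.built = True`: __call__ now sets it for us.

        def call(self, inputs):
            return self.alpha * inputs

    layer = Scale()
    _ = layer(tf.ones([2, 3]))
    assert layer.built  # set by __call__, not by build()
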
@@ -709,7 +750,7 @@ class Layer(checkpointable.CheckpointableBase):
# Deferred mode behavior: use `compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
if input_shapes is None:
- input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+ input_shapes = nest.map_structure(lambda x: x.shape, inputs)
output_shapes = self.compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
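
In deferred mode the layer's static shape inference does the work; for reference, `compute_output_shape` on a stock layer (an illustrative call, unrelated to the mechanics of this diff):

    import tensorflow as tf

    layer = tf.keras.layers.Dense(8)
    print(layer.compute_output_shape((None, 16)))  # TensorShape of (None, 8)
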
@@ -729,8 +770,6 @@ class Layer(checkpointable.CheckpointableBase):
if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
inputs, outputs = self._set_connectivity_metadata_(
inputs, outputs, args, kwargs)
-
- self.built = True
if context.executing_eagerly():
return outputs
@@ -1293,7 +1332,7 @@ class Layer(checkpointable.CheckpointableBase):
', but the layer isn\'t built. '
'You can build it manually via: `' + self.name +
'.build(batch_input_shape)`.')
- weight_shapes = [w.get_shape().as_list() for w in self.weights]
+ weight_shapes = [w.shape.as_list() for w in self.weights]
return int(sum([np.prod(w) for w in weight_shapes]))
@property
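
The repeated `.get_shape()` to `.shape` substitutions in this file are mechanical; both accessors return the same `TensorShape`:

    import tensorflow as tf

    w = tf.zeros([3, 4])
    assert w.shape.as_list() == w.get_shape().as_list() == [3, 4]
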
@@ -1376,7 +1415,7 @@ class Layer(checkpointable.CheckpointableBase):
if (spec.ndim is not None or
spec.min_ndim is not None or
spec.max_ndim is not None):
- if x.get_shape().ndims is None:
+ if x.shape.ndims is None:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'its rank is undefined, but the layer requires a '
@@ -1384,29 +1423,29 @@ class Layer(checkpointable.CheckpointableBase):
# Check ndim.
if spec.ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim != spec.ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected ndim=' + str(spec.ndim) + ', found ndim=' +
str(ndim) + '. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
if spec.max_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim > spec.max_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
'expected max_ndim=' + str(spec.max_ndim) +
', found ndim=' + str(ndim))
if spec.min_ndim is not None:
- ndim = x.get_shape().ndims
+ ndim = x.shape.ndims
if ndim is not None and ndim < spec.min_ndim:
raise ValueError('Input ' + str(input_index) + ' of layer ' +
self.name + ' is incompatible with the layer: '
': expected min_ndim=' + str(spec.min_ndim) +
', found ndim=' + str(ndim) +
'. Full shape received: ' +
- str(x.get_shape().as_list()))
+ str(x.shape.as_list()))
# Check dtype.
if spec.dtype is not None:
if x.dtype != spec.dtype:
@@ -1416,7 +1455,7 @@ class Layer(checkpointable.CheckpointableBase):
', found dtype=' + str(x.dtype))
# Check specific shape axes.
if spec.axes:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for axis, value in spec.axes.items():
if hasattr(value, 'value'):
@@ -1429,7 +1468,7 @@ class Layer(checkpointable.CheckpointableBase):
' but received input with shape ' + str(shape))
# Check shape.
if spec.shape is not None:
- shape = x.get_shape().as_list()
+ shape = x.shape.as_list()
if shape is not None:
for spec_dim, dim in zip(spec.shape, shape):
if spec_dim is not None and dim is not None:
@@ -1704,12 +1743,12 @@ class DeferredTensor(object):
def __str__(self):
return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
def __repr__(self):
return "<DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
- self.get_shape(),
+ self.shape,
self.dtype.name)
@@ -1804,11 +1843,13 @@ def make_variable(name,
dtype=dtypes.float32,
initializer=None,
partition_info=None,
- trainable=True,
+ trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1834,11 +1875,21 @@ def make_variable(name,
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
caching_device: Passed to `vs.variable`.
validate_shape: Passed to `vs.variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: Not handled at this time.
Returns:
@@ -1870,5 +1921,7 @@ def make_variable(name,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
return v
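
Finally, a hedged sketch of calling the helper directly with the new arguments (`make_variable` is internal and is normally reached through `add_weight`; the import paths assume a TensorFlow checkout matching this file, and the variable name is illustrative):

    from tensorflow.python.keras.engine import base_layer
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import variable_scope as vs

    v = base_layer.make_variable(
        'replica_total',
        shape=(),
        initializer=init_ops.zeros_initializer(),
        synchronization=vs.VariableSynchronization.ON_READ,
        aggregation=vs.VariableAggregation.SUM)
    # trainable was left unspecified; per the docstring above it defaults to
    # False because synchronization is ON_READ.
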