Diffstat (limited to 'tensorflow/python/keras/engine/base_layer.py')
-rw-r--r--   tensorflow/python/keras/engine/base_layer.py | 111
1 file changed, 82 insertions, 29 deletions
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 4814275fd5..b41f6ee03b 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -116,6 +116,7 @@ class Layer(checkpointable.CheckpointableBase):
       constraints on inputs that can be accepted by the layer.
   """
 
+  @checkpointable.no_automatic_dependency_tracking
   def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
     # These properties should be set by the user via keyword arguments.
     # note that 'dtype', 'input_shape' and 'batch_input_shape'
@@ -217,7 +218,7 @@ class Layer(checkpointable.CheckpointableBase):
   @activity_regularizer.setter
   def activity_regularizer(self, regularizer):
     """Optional regularizer function for the output of this layer."""
-    self._activity_regularizer = regularizer
+    self._activity_regularizer = self._no_dependency(regularizer)
 
   @property
   def trainable_weights(self):
@@ -459,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase):
     """Alias for `add_weight`."""
     return self.add_weight(*args, **kwargs)
 
-  def add_weight(self, name, shape,
+  def add_weight(self,
+                 name,
+                 shape,
                  dtype=None,
                  initializer=None,
                  regularizer=None,
-                 trainable=True,
+                 trainable=None,
                  constraint=None,
                  partitioner=None,
                  use_resource=None,
+                 synchronization=vs.VariableSynchronization.AUTO,
+                 aggregation=vs.VariableAggregation.NONE,
                  getter=None):
     """Adds a new variable to the layer, or gets an existing one; returns it.
 
@@ -481,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase):
         or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
         Note, if the current variable scope is marked as non-trainable
         then this parameter is ignored and any added variables are also
-        marked as non-trainable.
+        marked as non-trainable. `trainable` defaults to `True` unless
+        `synchronization` is set to `ON_READ`.
       constraint: constraint instance (callable).
       partitioner: Partitioner to be passed to the `Checkpointable` API.
       use_resource: Whether to use `ResourceVariable`.
+      synchronization: Indicates when a distributed a variable will be
+        aggregated. Accepted values are constants defined in the class
+        @{tf.VariableSynchronization}. By default the synchronization is set to
+        `AUTO` and the current `DistributionStrategy` chooses
+        when to synchronize. If `synchronization` is set to `ON_READ`,
+        `trainable` must not be set to `True`.
+      aggregation: Indicates how a distributed variable will be aggregated.
+        Accepted values are constants defined in the class
+        @{tf.VariableAggregation}.
       getter: Variable getter argument to be passed to the `Checkpointable` API.
 
     Returns:
@@ -495,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase):
     Raises:
       RuntimeError: If called with partioned variable regularization and
        eager execution is enabled.
-      ValueError: When giving unsupported dtype and no initializer.
+      ValueError: When giving unsupported dtype and no initializer or when
+        trainable has been set to True with synchronization set as `ON_READ`.
     """
     if dtype is None:
       dtype = self.dtype or backend.floatx()
@@ -504,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase):
     regularizer = regularizers.get(regularizer)
     constraint = constraints.get(constraint)
 
+    if synchronization == vs.VariableSynchronization.ON_READ:
+      if trainable:
+        raise ValueError(
+            'Synchronization value can be set to '
+            'VariableSynchronization.ON_READ only for non-trainable variables. '
+            'You have specified trainable=True and '
+            'synchronization=VariableSynchronization.ON_READ.')
+      else:
+        # Set trainable to be false when variable is to be synced on read.
+        trainable = False
+    elif trainable is None:
+      trainable = True
+
     # Initialize variable when no initializer provided
     if initializer is None:
       # If dtype is DT_FLOAT, provide a uniform unit scaling initializer
@@ -531,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase):
         constraint=constraint,
         trainable=trainable and self.trainable,
         partitioner=partitioner,
-        use_resource=use_resource)
+        use_resource=use_resource,
+        synchronization=synchronization,
+        aggregation=aggregation)
 
     if regularizer is not None:
       # TODO(fchollet): in the future, this should be handled at the
@@ -654,11 +685,12 @@ class Layer(checkpointable.CheckpointableBase):
     # Handle Keras mask propagation from previous layer to current layer.
     previous_mask = None
-    if (not hasattr(self, '_compute_previous_mask') or
-        self._compute_previous_mask):
+    if build_graph and (not hasattr(self, '_compute_previous_mask') or
+                        self._compute_previous_mask):
       previous_mask = collect_previous_mask(inputs)
       if not hasattr(self, '_call_fn_args'):
-        self._call_fn_args = function_utils.fn_args(self.call)
+        self._call_fn_args = self._no_dependency(
+            function_utils.fn_args(self.call))
       if ('mask' in self._call_fn_args and 'mask' not in kwargs and
           not generic_utils.is_all_none(previous_mask)):
         # The previous layer generated a mask, and mask was not explicitly pass
@@ -691,9 +723,18 @@ class Layer(checkpointable.CheckpointableBase):
             self._dtype = input_list[0].dtype.base_dtype.name
           except AttributeError:
             pass
-        if all(hasattr(x, 'get_shape') for x in input_list):
-          input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
-        self.build(input_shapes)
+
+        if all(hasattr(x, 'shape') for x in input_list):
+          input_shapes = nest.map_structure(lambda x: x.shape, inputs)
+
+        if (not hasattr(self, '_is_graph_network') or
+            self.__class__.__name__ == 'Sequential'):
+          # Only if self is a layer or an instance of a sequential model do we
+          # need to build it.
+          self.build(input_shapes)
+        # We must set self.built since user defined build functions are not
+        # constrained to set self.built.
+        self.built = True
 
       # Check input assumptions set after layer building, e.g. input shape.
       if build_graph or in_deferred_mode:
@@ -709,7 +750,7 @@ class Layer(checkpointable.CheckpointableBase):
         # Deferred mode behavior: use `compute_output_shape` to
         # infer the number of outputs of the layer and their shapes.
         if input_shapes is None:
-          input_shapes = nest.map_structure(lambda x: x.get_shape(), inputs)
+          input_shapes = nest.map_structure(lambda x: x.shape, inputs)
         output_shapes = self.compute_output_shape(input_shapes)
         output_shapes = nest.flatten(output_shapes)
@@ -729,8 +770,6 @@ class Layer(checkpointable.CheckpointableBase):
     if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
       inputs, outputs = self._set_connectivity_metadata_(
           inputs, outputs, args, kwargs)
-
-    self.built = True
     if context.executing_eagerly():
       return outputs
 
@@ -1293,7 +1332,7 @@ class Layer(checkpointable.CheckpointableBase):
                        ', but the layer isn\'t built. '
                        'You can build it manually via: `' + self.name +
                        '.build(batch_input_shape)`.')
-    weight_shapes = [w.get_shape().as_list() for w in self.weights]
+    weight_shapes = [w.shape.as_list() for w in self.weights]
    return int(sum([np.prod(w) for w in weight_shapes]))
 
   @property
@@ -1376,7 +1415,7 @@ class Layer(checkpointable.CheckpointableBase):
       if (spec.ndim is not None or
           spec.min_ndim is not None or
           spec.max_ndim is not None):
-        if x.get_shape().ndims is None:
+        if x.shape.ndims is None:
           raise ValueError('Input ' + str(input_index) + ' of layer ' +
                            self.name + ' is incompatible with the layer: '
                            'its rank is undefined, but the layer requires a '
@@ -1384,29 +1423,29 @@ class Layer(checkpointable.CheckpointableBase):
       # Check ndim.
       if spec.ndim is not None:
-        ndim = x.get_shape().ndims
+        ndim = x.shape.ndims
         if ndim != spec.ndim:
           raise ValueError('Input ' + str(input_index) + ' of layer ' +
                            self.name + ' is incompatible with the layer: '
                            'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                            str(ndim) + '. Full shape received: ' +
-                           str(x.get_shape().as_list()))
+                           str(x.shape.as_list()))
       if spec.max_ndim is not None:
-        ndim = x.get_shape().ndims
+        ndim = x.shape.ndims
         if ndim is not None and ndim > spec.max_ndim:
           raise ValueError('Input ' + str(input_index) + ' of layer ' +
                            self.name + ' is incompatible with the layer: '
                            'expected max_ndim=' + str(spec.max_ndim) +
                            ', found ndim=' + str(ndim))
       if spec.min_ndim is not None:
-        ndim = x.get_shape().ndims
+        ndim = x.shape.ndims
         if ndim is not None and ndim < spec.min_ndim:
           raise ValueError('Input ' + str(input_index) + ' of layer ' +
                            self.name + ' is incompatible with the layer: '
                            ': expected min_ndim=' + str(spec.min_ndim) +
                            ', found ndim=' + str(ndim) +
                            '. Full shape received: ' +
-                           str(x.get_shape().as_list()))
+                           str(x.shape.as_list()))
       # Check dtype.
       if spec.dtype is not None:
         if x.dtype != spec.dtype:
@@ -1416,7 +1455,7 @@ class Layer(checkpointable.CheckpointableBase):
                            ', found dtype=' + str(x.dtype))
       # Check specific shape axes.
       if spec.axes:
-        shape = x.get_shape().as_list()
+        shape = x.shape.as_list()
         if shape is not None:
           for axis, value in spec.axes.items():
             if hasattr(value, 'value'):
@@ -1429,7 +1468,7 @@ class Layer(checkpointable.CheckpointableBase):
                              ' but received input with shape ' + str(shape))
       # Check shape.
       if spec.shape is not None:
-        shape = x.get_shape().as_list()
+        shape = x.shape.as_list()
         if shape is not None:
           for spec_dim, dim in zip(spec.shape, shape):
             if spec_dim is not None and dim is not None:
@@ -1704,12 +1743,12 @@ class DeferredTensor(object):
   def __str__(self):
     return "DeferredTensor('%s', shape=%s, dtype=%s)" % (self.name,
-                                                          self.get_shape(),
+                                                          self.shape,
                                                           self.dtype.name)
 
   def __repr__(self):
     return "<DeferredTensor '%s' shape=%s dtype=%s>" % (self.name,
-                                                         self.get_shape(),
+                                                         self.shape,
                                                          self.dtype.name)
 
@@ -1804,11 +1843,13 @@ def make_variable(name,
                   dtype=dtypes.float32,
                   initializer=None,
                   partition_info=None,
-                  trainable=True,
+                  trainable=None,
                   caching_device=None,
                   validate_shape=True,
                   constraint=None,
                   use_resource=None,
+                  synchronization=vs.VariableSynchronization.AUTO,
+                  aggregation=vs.VariableAggregation.NONE,
                   partitioner=None):  # pylint: disable=unused-argument
   """Temporary util to create a variable (relies on `variable_scope.variable`).
 
@@ -1834,11 +1875,21 @@ def make_variable(name,
       or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
       Note, if the current variable scope is marked as non-trainable
       then this parameter is ignored and any added variables are also
-      marked as non-trainable.
+      marked as non-trainable. `trainable` defaults to `True` unless
+      `synchronization` is set to `ON_READ`.
     caching_device: Passed to `vs.variable`.
     validate_shape: Passed to `vs.variable`.
    constraint: Constraint instance (callable).
    use_resource: Whether to use a `ResourceVariable`.
+    synchronization: Indicates when a distributed a variable will be
+      aggregated. Accepted values are constants defined in the class
+      @{tf.VariableSynchronization}. By default the synchronization is set to
+      `AUTO` and the current `DistributionStrategy` chooses
+      when to synchronize. If `synchronization` is set to `ON_READ`,
+      `trainable` must not be set to `True`.
+    aggregation: Indicates how a distributed variable will be aggregated.
+      Accepted values are constants defined in the class
+      @{tf.VariableAggregation}.
     partitioner: Not handled at this time.
 
   Returns:
@@ -1870,5 +1921,7 @@ def make_variable(name,
       dtype=variable_dtype,
       validate_shape=validate_shape,
       constraint=constraint,
-      use_resource=use_resource)
+      use_resource=use_resource,
+      synchronization=synchronization,
+      aggregation=aggregation)
   return v
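
Editor's note: for context on how the new `add_weight` arguments are intended to be used, below is a minimal sketch (not part of the commit) of a custom layer that creates one ordinary trainable weight and one replica-local accumulator. It assumes a TensorFlow build that includes this change and that exports `tf.VariableSynchronization` / `tf.VariableAggregation`; the layer and variable names are invented for illustration.

import tensorflow as tf


class ScaleWithRunningMean(tf.keras.layers.Layer):
  """Toy layer: one trainable scale plus one ON_READ accumulator."""

  def build(self, input_shape):
    # `trainable` is left unset (None); with synchronization=AUTO it now
    # defaults to True, so this behaves exactly like the old default.
    self.scale = self.add_weight(
        name='scale', shape=(), initializer='ones')
    # With synchronization=ON_READ the new logic forces trainable=False,
    # and MEAN aggregation averages per-replica values when the variable
    # is read (the BatchNorm moving-statistics pattern this change targets).
    self.running_mean = self.add_weight(
        name='running_mean',
        shape=(),
        initializer='zeros',
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.MEAN)
    super(ScaleWithRunningMean, self).build(input_shape)

  def call(self, inputs):
    return inputs * self.scale


layer = ScaleWithRunningMean()
layer.build((None, 4))
assert layer.trainable_weights == [layer.scale]
assert layer.non_trainable_weights == [layer.running_mean]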
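
The new validation path can be exercised directly as well. The hypothetical snippet below (again, not from the commit) shows the ValueError raised when an explicit trainable=True is combined with ON_READ synchronization:

import tensorflow as tf

layer = tf.keras.layers.Layer(name='demo')
try:
  layer.add_weight(
      name='bad',
      shape=(),
      trainable=True,  # conflicts with ON_READ, see the check added above
      synchronization=tf.VariableSynchronization.ON_READ)
except ValueError as err:
  print(err)  # Synchronization value can be set to ... only for non-trainable variables ...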