diff options
author | Francois Chollet <fchollet@google.com> | 2017-04-26 16:12:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-04-26 17:31:01 -0700 |
commit | 1e3e5d424eaa6332314f8ad1d54089eb0f9e02e7 (patch) | |
tree | 7ab6e50810f4dbd843d087cb4fa2e4288e4eafb9 /tensorflow/contrib/keras | |
parent | 7314ba8f5d8419892b23a5c89fd809c8d86fdcb8 (diff) |
Refactor Keras layers to rely on core TF layers.
API change: for users of custom Keras layers built using `tf.contrib.keras`, the `add_weight` method of the Keras base layer now has a new API (synced with the main Keras GitHub repo).
Change: 154366685
Diffstat (limited to 'tensorflow/contrib/keras')
15 files changed, 199 insertions, 341 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 5166ba37a3..b1b8fc49b6 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -119,6 +119,7 @@ py_library( "//tensorflow/python:gradients", "//tensorflow/python:image_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:layers", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index e52b23843a..905ef13e14 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -21,7 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import defaultdict import json import os import warnings @@ -245,17 +244,40 @@ def set_image_data_format(data_format): def get_uid(prefix=''): - global _GRAPH_UID_DICTS # pylint: disable=global-variable-not-assigned - graph = ops.get_default_graph() - if graph not in _GRAPH_UID_DICTS: - _GRAPH_UID_DICTS[graph] = defaultdict(int) - _GRAPH_UID_DICTS[graph][prefix] += 1 - return _GRAPH_UID_DICTS[graph][prefix] + """Associates a string prefix with an integer counter in a TensorFlow graph. + + Arguments: + prefix: String prefix to index. + + Returns: + Unique integer ID. 
+ + Example: + + ``` + >>> get_uid('dense') + 1 + >>> get_uid('dense') + 2 + ``` + """ + layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS') + if not layer_name_uids_collection: + layer_name_uids = {} + ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids) + else: + layer_name_uids = layer_name_uids_collection[0] + if prefix not in layer_name_uids: + layer_name_uids[prefix] = 1 + else: + layer_name_uids[prefix] += 1 + return layer_name_uids[prefix] def reset_uids(): - global _GRAPH_UID_DICTS - _GRAPH_UID_DICTS = {} + layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS') + if layer_name_uids_collection: + layer_name_uids_collection.pop() def clear_session(): diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 7848e5982d..0336fc4bf4 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -29,11 +29,12 @@ import numpy as np from six.moves import zip # pylint: disable=redefined-builtin from tensorflow.contrib.keras.python.keras import backend as K -from tensorflow.contrib.keras.python.keras import initializers from tensorflow.contrib.keras.python.keras.utils import conv_utils from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.util import tf_inspect @@ -207,7 +208,7 @@ class Node(object): } -class Layer(object): +class Layer(tf_base_layers.Layer): """Abstract base layer class. 
# Properties @@ -276,24 +277,6 @@ class Layer(object): """ def __init__(self, **kwargs): - self.input_spec = None - self.supports_masking = False - - # These properties will be set upon call of self.build() - self._trainable_weights = [] - self._non_trainable_weights = [] - self._constraints = {} # dict {tensor: constraint instance} - self._losses = [] - self._updates = [] - self._per_input_losses = {} - self._per_input_updates = {} - self._built = False - - # These lists will be filled via successive calls - # to self._add_inbound_node(). - self.inbound_nodes = [] - self.outbound_nodes = [] - # These properties should be set by the user via keyword arguments. # note that 'dtype', 'input_shape' and 'batch_input_shape' # are only applicable to input layers: do not pass these keywords @@ -306,18 +289,38 @@ class Layer(object): 'name', 'trainable', 'weights', - 'input_dtype', # legacy } + # Validate optional keyword arguments. for kwarg in kwargs: if kwarg not in allowed_kwargs: raise TypeError('Keyword argument not understood:', kwarg) + + # Get layer name. name = kwargs.get('name') - if not name: - prefix = self.__class__.__name__ - name = _to_snake_case(prefix) + '_' + str(K.get_uid(prefix)) - self.name = name - self.trainable = kwargs.get('trainable', True) + # Get `trainable` status. + trainable = kwargs.get('trainable', True) + + # Get `dtype`. + dtype = kwargs.get('dtype') + if dtype is None: + dtype = K.floatx() + + # Call super, which will set all properties common to Keras layers + # and core TF layers. + super(Layer, self).__init__(name=name, dtype=dtype, trainable=trainable) + + # Add properties that are Keras-only for now. + self.input_spec = None + self.supports_masking = False + self._constraints = {} # dict {tensor: constraint instance} + + # These lists will be filled via successive calls + # to self._add_inbound_node(). + self.inbound_nodes = [] + self.outbound_nodes = [] + + # Manage input shape information if passed. 
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: # In this case we will later create an input layer # to insert before the current layer @@ -331,36 +334,13 @@ class Layer(object): batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) self.batch_input_shape = batch_input_shape - # Set dtype. - dtype = kwargs.get('dtype') - if dtype is None: - dtype = kwargs.get('input_dtype') - if dtype is None: - dtype = K.floatx() - self.dtype = dtype - + # Manage initial weight values if passed. if 'weights' in kwargs: self._initial_weights = kwargs['weights'] else: self._initial_weights = None @property - def losses(self): - return self._losses - - @property - def updates(self): - return self._updates - - @property - def built(self): - return self._built - - @built.setter - def built(self, value): - self._built = value - - @property def constraints(self): return self._constraints @@ -368,63 +348,37 @@ class Layer(object): def constraints(self, constraints): self._constraints = constraints - @property - def trainable_weights(self): - trainable = getattr(self, 'trainable', True) - if trainable: - return self._trainable_weights - else: - return [] - - @trainable_weights.setter - def trainable_weights(self, weights): - self._trainable_weights = weights - - @property - def non_trainable_weights(self): - trainable = getattr(self, 'trainable', True) - if not trainable: - return self._trainable_weights + self._non_trainable_weights - else: - return self._non_trainable_weights - - @non_trainable_weights.setter - def non_trainable_weights(self, weights): - self._non_trainable_weights = weights - def add_weight(self, + name, shape, - initializer, - name=None, - trainable=True, + dtype=None, + initializer=None, regularizer=None, + trainable=True, constraint=None): """Adds a weight variable to the layer. Arguments: + name: String, the name for the weight variable. shape: The shape tuple of the weight. + dtype: The dtype of the weight. 
initializer: An Initializer instance (callable). - name: String, the name for the weight variable. + regularizer: An optional Regularizer instance. trainable: A boolean, whether the weight should be trained via backprop or not (assuming that the layer itself is also trainable). - regularizer: An optional Regularizer instance. constraint: An optional Constraint instance. Returns: The created weight variable. """ - shape = tuple(tensor_shape.TensorShape(shape).as_list()) - initializer = initializers.get(initializer) - weight = K.variable(initializer(shape), dtype=K.floatx(), name=name) - if regularizer is not None: - self.add_loss(regularizer(weight)) + if dtype is None: + dtype = K.floatx() + weight = self.add_variable( + name, shape, dtype=dtype, + initializer=initializer, regularizer=regularizer, trainable=trainable) if constraint is not None: self.constraints[weight] = constraint - if trainable: - self._trainable_weights.append(weight) - else: - self._non_trainable_weights.append(weight) return weight def assert_input_compatibility(self, inputs): @@ -554,66 +508,46 @@ class Layer(object): """ if isinstance(inputs, list): inputs = inputs[:] + + # Raise exceptions in case the input is not compatible + # with the input_spec set at build time. + # TODO(fchollet): call after the layer is built, too. + self.assert_input_compatibility(inputs) + + # Handle mask propagation. + previous_mask = _collect_previous_mask(inputs) + user_kwargs = copy.copy(kwargs) + if not _is_all_none(previous_mask): + # The previous layer generated a mask. + if 'mask' in tf_inspect.getargspec(self.call).args: + if 'mask' not in kwargs: + # If mask is explicitly passed to __call__, + # we should override the default mask. + kwargs['mask'] = previous_mask + + # Actually call the layer (optionally building it). + output = super(Layer, self).__call__(inputs, **kwargs) + + # Handle mask computation. with K.name_scope(self.name): - # Handle laying building (weight creating, input spec locking). 
- if not self.built: - # Raise exceptions in case the input is not compatible - # with the input_spec specified in the layer constructor. - self.assert_input_compatibility(inputs) - - # Collect input shapes to build layer. - input_shapes = [] - for x_elem in _to_list(inputs): - input_shapes.append(K.int_shape(x_elem)) - if len(input_shapes) == 1: - self.build(input_shapes[0]) - else: - self.build(input_shapes) - self.built = True - - # Load weights that were specified at layer instantiation. - if self._initial_weights is not None: - self.set_weights(self._initial_weights) - - # Raise exceptions in case the input is not compatible - # with the input_spec set at build time. - self.assert_input_compatibility(inputs) - - # Handle mask propagation. - previous_mask = _collect_previous_mask(inputs) - user_kwargs = copy.copy(kwargs) - if not _is_all_none(previous_mask): - # The previous layer generated a mask. - if 'mask' in tf_inspect.getargspec(self.call).args: - if 'mask' not in kwargs: - # If mask is explicitly passed to __call__, - # we should override the default mask. - kwargs['mask'] = previous_mask - - # Actually call the layer, collecting output(s), mask(s), and shape(s). - output = self.call(inputs, **kwargs) output_mask = self.compute_mask(inputs, previous_mask) - # Add an inbound node to the layer, so that it keeps track - # of the call and of all new variables created during the call. - # This also updates the layer history of the output tensor(s). - # If the input tensor(s) had not previous Keras history, - # this does nothing. 
- self._add_inbound_node( - input_tensors=inputs, - output_tensors=output, - input_masks=previous_mask, - output_masks=output_mask, - arguments=user_kwargs) - - # Apply activity regularizer if any: - if hasattr( - self, - 'activity_regularizer') and self.activity_regularizer is not None: - regularization_losses = [ - self.activity_regularizer(x) for x in _to_list(output) - ] - self.add_loss(regularization_losses, _to_list(inputs)) + # Add an inbound node to the layer, so that it keeps track + # of the call and of all new variables created during the call. + # This also updates the layer history of the output tensor(s). + # If the input tensor(s) had not previous Keras history, + # this does nothing. + self._add_inbound_node( + input_tensors=inputs, + output_tensors=output, + input_masks=previous_mask, + output_masks=output_mask, + arguments=user_kwargs) + + # Optionally load weight values that were specified at layer instantiation. + if hasattr(self, '_initial_weights') and self._initial_weights is not None: + self.set_weights(self._initial_weights) + del self._initial_weights return output def _add_inbound_node(self, @@ -959,14 +893,14 @@ class Layer(object): @property def input_shape(self): - """Retrieves the input shape tuple(s) of a layer. + """Retrieves the input shape(s) of a layer. Only applicable if the layer has exactly one inbound node, i.e. if it is connected to one incoming layer. Returns: - Input shape tuple - (or list of input shape tuples, one tuple per input tensor). + Input shape, as `TensorShape` + (or list of `TensorShape`, one tuple per input tensor). Raises: AttributeError: if the layer is connected to @@ -997,14 +931,14 @@ class Layer(object): @property def output_shape(self): - """Retrieves the output shape tuple(s) of a layer. + """Retrieves the output shape(s) of a layer. Only applicable if the layer has one inbound node, or if all inbound nodes have the same output shape. 
Returns: - Output shape tuple - (or list of input shape tuples, one tuple per output tensor). + Output shape, as `TensorShape` + (or list of `TensorShape`, one tuple per output tensor). Raises: AttributeError: if the layer is connected to @@ -1033,94 +967,6 @@ class Layer(object): 'Use `get_output_shape_at(node_index)` ' 'instead.') - def add_loss(self, losses, inputs=None): - """Add losses to the layer. - - The loss may potentially be conditional on some inputs tensors, - for instance activity losses are conditional on the layer's inputs. - - Arguments: - losses: loss tensor or list of loss tensors - to add to the layer. - inputs: input tensor or list of inputs tensors to mark - the losses as conditional on these inputs. - If None is passed, the loss is assumed unconditional - (e.g. L2 weight regularization, which only depends - on the layer's weights variables, not on any inputs tensors). - """ - if losses is None or losses == []: # pylint: disable=g-explicit-bool-comparison - return - # Update self.losses - losses = _to_list(losses) - if hasattr(self, '_losses'): - self._losses += losses - # Update self._per_input_updates - if inputs == []: # pylint: disable=g-explicit-bool-comparison - inputs = None - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - # Updates indexed by None are unconditional - # rather than input-dependent - inputs_hash = None - if inputs_hash not in self._per_input_losses: - self._per_input_losses[inputs_hash] = [] - self._per_input_losses[inputs_hash] += losses - - def add_update(self, updates, inputs=None): - """Add updates to the layer. - - The updates may potentially be conditional on some inputs tensors, - for instance batch norm updates are conditional on the layer's inputs. - - Arguments: - updates: update op or list of update ops - to add to the layer. - inputs: input tensor or list of inputs tensors to mark - the updates as conditional on these inputs. 
- If None is passed, the updates are assumed unconditional. - """ - if updates is None or updates == []: # pylint: disable=g-explicit-bool-comparison - return - # Update self.updates - updates = _to_list(updates) - if hasattr(self, '_updates'): - self._updates += updates - # Update self._per_input_updates - if inputs == []: # pylint: disable=g-explicit-bool-comparison - inputs = None - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - # Updates indexed by None are unconditional - # rather than input-dependent - inputs_hash = None - if inputs_hash not in self._per_input_updates: - self._per_input_updates[inputs_hash] = [] - self._per_input_updates[inputs_hash] += updates - - def get_updates_for(self, inputs): - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - inputs_hash = None - if inputs_hash in self._per_input_updates: - return self._per_input_updates[inputs_hash] - return [] - - def get_losses_for(self, inputs): - if inputs is not None: - inputs_hash = _object_list_uid(inputs) - else: - inputs_hash = None - if inputs_hash in self._per_input_losses: - return self._per_input_losses[inputs_hash] - return [] - - @property - def weights(self): - return self.trainable_weights + self.non_trainable_weights - def set_weights(self, weights): """Sets the weights of the layer, from Numpy arrays. 
@@ -1254,9 +1100,12 @@ class InputLayer(Layer): if not name: prefix = 'input' name = prefix + '_' + str(K.get_uid(prefix)) + if not dtype: + if input_tensor is None: + dtype = K.floatx() + else: + dtype = K.dtype(input_tensor) super(InputLayer, self).__init__(dtype=dtype, name=name) - - self.trainable = False self.built = True self.sparse = sparse @@ -1284,15 +1133,7 @@ class InputLayer(Layer): batch_input_shape = (batch_size,) + tuple(input_shape) else: batch_input_shape = tuple(batch_input_shape) - - if not dtype: - if input_tensor is None: - dtype = K.floatx() - else: - dtype = K.dtype(input_tensor) - self.batch_input_shape = batch_input_shape - self.dtype = dtype if input_tensor is None: self.is_placeholder = True @@ -1446,12 +1287,19 @@ class Container(Layer): prefix = self.__class__.__name__.lower() name = prefix + '_' + str(K.get_uid(prefix)) self.name = name - self.supports_masking = False self.trainable = True self._per_input_losses = {} self._per_input_updates = {} + # The following properties are not actually used by Keras; + # they exist for compatibility with TF. + self._updates = [] + self._scope = None + self._reuse = None + self._base_name = name + self._graph = ops.get_default_graph() + # Container-specific properties. if isinstance(inputs, (list, tuple)): self.inputs = list(inputs) # Tensor or list of tensors. 
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py index eb095b14a9..531ed4be3e 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py @@ -490,8 +490,8 @@ class TopologyConstructionTest(test.TestCase): m, n = model([j, k]) tf_model = keras.models.Model([j, k], [m, n]) - j_tf = array_ops.placeholder(dtype=dtypes.float32) - k_tf = array_ops.placeholder(dtype=dtypes.float32) + j_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32)) + k_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32)) m_tf, n_tf = tf_model([j_tf, k_tf]) self.assertListEqual(m_tf.get_shape().as_list(), [None, 64]) self.assertListEqual(n_tf.get_shape().as_list(), [None, 5]) diff --git a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py index b3abfc29d2..2c957ece44 100644 --- a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py +++ b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py @@ -120,7 +120,7 @@ class PReLU(Layer): param_shape[i - 1] = 1 self.param_broadcast[i - 1] = True self.alpha = self.add_weight( - param_shape, + shape=param_shape, name='alpha', initializer=self.alpha_initializer, regularizer=self.alpha_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/contrib/keras/python/keras/layers/convolutional.py index 38b8fe66a3..16f49c3390 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional.py @@ -140,14 +140,14 @@ class _Conv(Layer): kernel_shape = self.kernel_size + (input_dim, self.filters) self.kernel = self.add_weight( - kernel_shape, + shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', 
regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -734,14 +734,14 @@ class Conv2DTranspose(Conv2D): kernel_shape = self.kernel_size + (self.filters, input_dim) self.kernel = self.add_weight( - kernel_shape, + shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -949,13 +949,13 @@ class SeparableConv2D(Conv2D): self.filters) self.depthwise_kernel = self.add_weight( - depthwise_kernel_shape, + shape=depthwise_kernel_shape, initializer=self.depthwise_initializer, name='depthwise_kernel', regularizer=self.depthwise_regularizer, constraint=self.depthwise_constraint) self.pointwise_kernel = self.add_weight( - pointwise_kernel_shape, + shape=pointwise_kernel_shape, initializer=self.pointwise_initializer, name='pointwise_kernel', regularizer=self.pointwise_regularizer, @@ -963,7 +963,7 @@ class SeparableConv2D(Conv2D): if self.use_bias: self.bias = self.add_weight( - (self.filters,), + shape=(self.filters,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py index 4d8ef44da7..30325b7148 100644 --- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py @@ -369,20 +369,20 @@ class ConvLSTM2D(ConvRecurrent2D): recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4) self.kernel = self.add_weight( - kernel_shape, + 
shape=kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - recurrent_kernel_shape, + shape=recurrent_kernel_shape, initializer=self.recurrent_initializer, name='recurrent_kernel', regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint) if self.use_bias: self.bias = self.add_weight( - (self.filters * 4,), + shape=(self.filters * 4,), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py index 32ada176a4..7a9e0d1736 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core.py +++ b/tensorflow/contrib/keras/python/keras/layers/core.py @@ -34,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import core as tf_core_layers from tensorflow.python.util import tf_inspect @@ -643,7 +644,7 @@ class Lambda(Layer): return cls(**config) -class Dense(Layer): +class Dense(tf_core_layers.Dense, Layer): """Just your regular densely-connected NN layer. 
`Dense` implements the operation: @@ -712,15 +713,20 @@ class Dense(Layer): **kwargs): if 'input_shape' not in kwargs and 'input_dim' in kwargs: kwargs['input_shape'] = (kwargs.pop('input_dim'),) - super(Dense, self).__init__(**kwargs) - self.units = units - self.activation = activations.get(activation) - self.use_bias = use_bias - self.kernel_initializer = initializers.get(kernel_initializer) - self.bias_initializer = initializers.get(bias_initializer) - self.kernel_regularizer = regularizers.get(kernel_regularizer) - self.bias_regularizer = regularizers.get(bias_regularizer) - self.activity_regularizer = regularizers.get(activity_regularizer) + + # Inheritance call order: + # 1) tf.layers.Dense, 2) keras.layers.Layer, 3) tf.layers.Layer + super(Dense, self).__init__( + units, + activation=activations.get(activation), + use_bias=use_bias, + kernel_initializer=initializers.get(kernel_initializer), + bias_initializer=initializers.get(bias_initializer), + kernel_regularizer=regularizers.get(kernel_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + **kwargs) + self.kernel_constraint = constraints.get(kernel_constraint) self.bias_constraint = constraints.get(bias_constraint) self.input_spec = InputSpec(min_ndim=2) @@ -729,40 +735,12 @@ class Dense(Layer): def build(self, input_shape): assert len(input_shape) >= 2 input_dim = input_shape[-1] - - self.kernel = self.add_weight( - (input_dim, self.units), - initializer=self.kernel_initializer, - name='kernel', - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint) - if self.use_bias: - self.bias = self.add_weight( - (self.units,), - initializer=self.bias_initializer, - name='bias', - regularizer=self.bias_regularizer, - constraint=self.bias_constraint) - else: - self.bias = None + super(Dense, self).build(input_shape) self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim}) - self.built = True - - def call(self, 
inputs): - output = K.dot(inputs, self.kernel) - if self.use_bias: - output = K.bias_add(output, self.bias) - if self.activation is not None: - output = self.activation(output) - return output - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).as_list() - assert input_shape and len(input_shape) >= 2 - assert input_shape[-1] - output_shape = list(input_shape) - output_shape[-1] = self.units - return tensor_shape.TensorShape(output_shape) + if self.kernel_constraint: + self.constraints[self.kernel] = self.kernel_constraint + if self.use_bias and self.bias_constraint: + self.constraints[self.bias] = self.bias_constraint def get_config(self): config = { diff --git a/tensorflow/contrib/keras/python/keras/layers/core_test.py b/tensorflow/contrib/keras/python/keras/layers/core_test.py index d7aa8413bb..7066af0ef6 100644 --- a/tensorflow/contrib/keras/python/keras/layers/core_test.py +++ b/tensorflow/contrib/keras/python/keras/layers/core_test.py @@ -165,24 +165,23 @@ class CoreLayersTest(test.TestCase): 3, kernel_regularizer=keras.regularizers.l1(0.01), bias_regularizer='l1', - activity_regularizer='l2') - layer.build((None, 4)) - assert len(layer.losses) == 2 + activity_regularizer='l2', + name='dense_reg') layer(keras.backend.variable(np.ones((2, 4)))) - assert len(layer.losses) == 3 + self.assertEqual(3, len(layer.losses)) # Test constraints with self.test_session(): layer = keras.layers.Dense( 3, kernel_constraint='max_norm', bias_constraint='max_norm') - layer.build((None, 4)) - assert len(layer.constraints) == 2 + layer(keras.backend.variable(np.ones((2, 4)))) + self.assertEqual(2, len(layer.constraints)) def test_activity_regularization(self): with self.test_session(): layer = keras.layers.ActivityRegularization(l1=0.1) layer(keras.backend.variable(np.ones((2, 4)))) - assert len(layer.losses) == 1 + self.assertEqual(1, len(layer.losses)) if __name__ == '__main__': diff --git 
a/tensorflow/contrib/keras/python/keras/layers/embeddings.py b/tensorflow/contrib/keras/python/keras/layers/embeddings.py index 12a2ce39eb..bc0bae67d0 100644 --- a/tensorflow/contrib/keras/python/keras/layers/embeddings.py +++ b/tensorflow/contrib/keras/python/keras/layers/embeddings.py @@ -116,7 +116,7 @@ class Embedding(Layer): def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape).as_list() self.embeddings = self.add_weight( - (self.input_dim, self.output_dim), + shape=(self.input_dim, self.output_dim), initializer=self.embeddings_initializer, name='embeddings', regularizer=self.embeddings_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/contrib/keras/python/keras/layers/local.py index d96ccc4a63..863674c1cb 100644 --- a/tensorflow/contrib/keras/python/keras/layers/local.py +++ b/tensorflow/contrib/keras/python/keras/layers/local.py @@ -130,14 +130,14 @@ class LocallyConnected1D(Layer): self.kernel_shape = (output_length, self.kernel_size[0] * input_dim, self.filters) self.kernel = self.add_weight( - self.kernel_shape, + shape=self.kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (output_length, self.filters), + shape=(output_length, self.filters), initializer=self.bias_initializer, name='bias', regularizer=self.bias_regularizer, @@ -340,14 +340,14 @@ class LocallyConnected2D(Layer): output_row * output_col, self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters) self.kernel = self.add_weight( - self.kernel_shape, + shape=self.kernel_shape, initializer=self.kernel_initializer, name='kernel', regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) if self.use_bias: self.bias = self.add_weight( - (output_row, output_col, self.filters), + shape=(output_row, output_col, self.filters), initializer=self.bias_initializer, 
name='bias', regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py index 7c6482d0de..25921979bd 100644 --- a/tensorflow/contrib/keras/python/keras/layers/merge.py +++ b/tensorflow/contrib/keras/python/keras/layers/merge.py @@ -87,6 +87,7 @@ class _Merge(Layer): raise ValueError('A merge layer should be called ' 'on a list of at least 2 inputs. ' 'Got ' + str(len(input_shape)) + ' inputs.') + input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape] batch_sizes = [s[0] for s in input_shape if s is not None] batch_sizes = set(batch_sizes) batch_sizes -= set([None]) diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/contrib/keras/python/keras/layers/normalization.py index 9a0340aeaf..df77401aee 100644 --- a/tensorflow/contrib/keras/python/keras/layers/normalization.py +++ b/tensorflow/contrib/keras/python/keras/layers/normalization.py @@ -116,7 +116,7 @@ class BatchNormalization(Layer): if self.scale: self.gamma = self.add_weight( - shape, + shape=shape, name='gamma', initializer=self.gamma_initializer, regularizer=self.gamma_regularizer, @@ -125,7 +125,7 @@ class BatchNormalization(Layer): self.gamma = None if self.center: self.beta = self.add_weight( - shape, + shape=shape, name='beta', initializer=self.beta_initializer, regularizer=self.beta_regularizer, @@ -133,12 +133,12 @@ class BatchNormalization(Layer): else: self.beta = None self.moving_mean = self.add_weight( - shape, + shape=shape, name='moving_mean', initializer=self.moving_mean_initializer, trainable=False) self.moving_variance = self.add_weight( - shape, + shape=shape, name='moving_variance', initializer=self.moving_variance_initializer, trainable=False) diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py index 1ea1cb22d9..e608921add 100644 --- 
a/tensorflow/contrib/keras/python/keras/layers/recurrent.py +++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py @@ -493,20 +493,20 @@ class SimpleRNN(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units), + shape=(self.input_dim, self.units), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units), + shape=(self.units, self.units), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint) if self.use_bias: self.bias = self.add_weight( - (self.units,), + shape=(self.units,), name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, @@ -723,13 +723,13 @@ class GRU(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units * 3), + shape=(self.input_dim, self.units * 3), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units * 3), + shape=(self.units, self.units * 3), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, @@ -737,9 +737,9 @@ class GRU(Recurrent): if self.use_bias: self.bias = self.add_weight( - (self.units * 3,), + shape=(self.units * 3,), name='bias', - initializer='zero', + initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint) else: @@ -1039,13 +1039,13 @@ class LSTM(Recurrent): self.reset_states() self.kernel = self.add_weight( - (self.input_dim, self.units * 4), + shape=(self.input_dim, self.units * 4), name='kernel', initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint) self.recurrent_kernel = self.add_weight( - (self.units, self.units * 
4), + shape=(self.units, self.units * 4), name='recurrent_kernel', initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, @@ -1053,7 +1053,7 @@ class LSTM(Recurrent): if self.use_bias: self.bias = self.add_weight( - (self.units * 4,), + shape=(self.units * 4,), name='bias', initializer=self.bias_initializer, regularizer=self.bias_regularizer, diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index eb0996fa12..52456a4bb5 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -35,6 +35,7 @@ from tensorflow.contrib.keras.python.keras.engine.topology import Input from tensorflow.contrib.keras.python.keras.engine.topology import Layer from tensorflow.contrib.keras.python.keras.engine.training import Model from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.framework import ops # pylint: disable=g-import-not-at-top @@ -420,6 +421,14 @@ class Sequential(Model): name = prefix + str(K.get_uid(prefix)) self.name = name + # The following properties are not actually used by Keras; + # they exist for compatibility with TF's variable scoping mechanism. + self._updates = [] + self._scope = None + self._reuse = None + self._base_name = name + self._graph = ops.get_default_graph() + # Add to the model any layers passed to the constructor. if layers: for layer in layers: |