aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-04-26 16:12:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-26 17:31:01 -0700
commit1e3e5d424eaa6332314f8ad1d54089eb0f9e02e7 (patch)
tree7ab6e50810f4dbd843d087cb4fa2e4288e4eafb9 /tensorflow/contrib/keras
parent7314ba8f5d8419892b23a5c89fd809c8d86fdcb8 (diff)
Refactor Keras layers to rely on core TF layers.
API change: for users of custom Keras layers built using `tf.contrib.keras`, the method `add_weight` of the Keras base layer now has a new API (synced with the main Keras GitHub repo). Change: 154366685
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/BUILD1
-rw-r--r--tensorflow/contrib/keras/python/keras/backend.py40
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology.py348
-rw-r--r--tensorflow/contrib/keras/python/keras/engine/topology_test.py4
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/advanced_activations.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/convolutional.py14
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py6
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core.py64
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/core_test.py13
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/embeddings.py2
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/local.py8
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/merge.py1
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/normalization.py8
-rw-r--r--tensorflow/contrib/keras/python/keras/layers/recurrent.py20
-rw-r--r--tensorflow/contrib/keras/python/keras/models.py9
15 files changed, 199 insertions, 341 deletions
diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD
index 5166ba37a3..b1b8fc49b6 100644
--- a/tensorflow/contrib/keras/BUILD
+++ b/tensorflow/contrib/keras/BUILD
@@ -119,6 +119,7 @@ py_library(
"//tensorflow/python:gradients",
"//tensorflow/python:image_ops",
"//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
"//tensorflow/python:logging_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index e52b23843a..905ef13e14 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -21,7 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from collections import defaultdict
import json
import os
import warnings
@@ -245,17 +244,40 @@ def set_image_data_format(data_format):
def get_uid(prefix=''):
- global _GRAPH_UID_DICTS # pylint: disable=global-variable-not-assigned
- graph = ops.get_default_graph()
- if graph not in _GRAPH_UID_DICTS:
- _GRAPH_UID_DICTS[graph] = defaultdict(int)
- _GRAPH_UID_DICTS[graph][prefix] += 1
- return _GRAPH_UID_DICTS[graph][prefix]
+ """Associates a string prefix with an integer counter in a TensorFlow graph.
+
+ Arguments:
+ prefix: String prefix to index.
+
+ Returns:
+ Unique integer ID.
+
+ Example:
+
+ ```
+ >>> get_uid('dense')
+ 1
+ >>> get_uid('dense')
+ 2
+ ```
+ """
+ layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
+ if not layer_name_uids_collection:
+ layer_name_uids = {}
+ ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
+ else:
+ layer_name_uids = layer_name_uids_collection[0]
+ if prefix not in layer_name_uids:
+ layer_name_uids[prefix] = 1
+ else:
+ layer_name_uids[prefix] += 1
+ return layer_name_uids[prefix]
def reset_uids():
- global _GRAPH_UID_DICTS
- _GRAPH_UID_DICTS = {}
+ layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS')
+ if layer_name_uids_collection:
+ layer_name_uids_collection.pop()
def clear_session():
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py
index 7848e5982d..0336fc4bf4 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology.py
@@ -29,11 +29,12 @@ import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.contrib.keras.python.keras import backend as K
-from tensorflow.contrib.keras.python.keras import initializers
from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.util import tf_inspect
@@ -207,7 +208,7 @@ class Node(object):
}
-class Layer(object):
+class Layer(tf_base_layers.Layer):
"""Abstract base layer class.
# Properties
@@ -276,24 +277,6 @@ class Layer(object):
"""
def __init__(self, **kwargs):
- self.input_spec = None
- self.supports_masking = False
-
- # These properties will be set upon call of self.build()
- self._trainable_weights = []
- self._non_trainable_weights = []
- self._constraints = {} # dict {tensor: constraint instance}
- self._losses = []
- self._updates = []
- self._per_input_losses = {}
- self._per_input_updates = {}
- self._built = False
-
- # These lists will be filled via successive calls
- # to self._add_inbound_node().
- self.inbound_nodes = []
- self.outbound_nodes = []
-
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
# are only applicable to input layers: do not pass these keywords
@@ -306,18 +289,38 @@ class Layer(object):
'name',
'trainable',
'weights',
- 'input_dtype', # legacy
}
+ # Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
raise TypeError('Keyword argument not understood:', kwarg)
+
+ # Get layer name.
name = kwargs.get('name')
- if not name:
- prefix = self.__class__.__name__
- name = _to_snake_case(prefix) + '_' + str(K.get_uid(prefix))
- self.name = name
- self.trainable = kwargs.get('trainable', True)
+ # Get `trainable` status.
+ trainable = kwargs.get('trainable', True)
+
+ # Get `dtype`.
+ dtype = kwargs.get('dtype')
+ if dtype is None:
+ dtype = K.floatx()
+
+ # Call super, which will set all properties common to Keras layers
+ # and core TF layers.
+ super(Layer, self).__init__(name=name, dtype=dtype, trainable=trainable)
+
+ # Add properties that are Keras-only for now.
+ self.input_spec = None
+ self.supports_masking = False
+ self._constraints = {} # dict {tensor: constraint instance}
+
+ # These lists will be filled via successive calls
+ # to self._add_inbound_node().
+ self.inbound_nodes = []
+ self.outbound_nodes = []
+
+ # Manage input shape information if passed.
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
# In this case we will later create an input layer
# to insert before the current layer
@@ -331,36 +334,13 @@ class Layer(object):
batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
self.batch_input_shape = batch_input_shape
- # Set dtype.
- dtype = kwargs.get('dtype')
- if dtype is None:
- dtype = kwargs.get('input_dtype')
- if dtype is None:
- dtype = K.floatx()
- self.dtype = dtype
-
+ # Manage initial weight values if passed.
if 'weights' in kwargs:
self._initial_weights = kwargs['weights']
else:
self._initial_weights = None
@property
- def losses(self):
- return self._losses
-
- @property
- def updates(self):
- return self._updates
-
- @property
- def built(self):
- return self._built
-
- @built.setter
- def built(self, value):
- self._built = value
-
- @property
def constraints(self):
return self._constraints
@@ -368,63 +348,37 @@ class Layer(object):
def constraints(self, constraints):
self._constraints = constraints
- @property
- def trainable_weights(self):
- trainable = getattr(self, 'trainable', True)
- if trainable:
- return self._trainable_weights
- else:
- return []
-
- @trainable_weights.setter
- def trainable_weights(self, weights):
- self._trainable_weights = weights
-
- @property
- def non_trainable_weights(self):
- trainable = getattr(self, 'trainable', True)
- if not trainable:
- return self._trainable_weights + self._non_trainable_weights
- else:
- return self._non_trainable_weights
-
- @non_trainable_weights.setter
- def non_trainable_weights(self, weights):
- self._non_trainable_weights = weights
-
def add_weight(self,
+ name,
shape,
- initializer,
- name=None,
- trainable=True,
+ dtype=None,
+ initializer=None,
regularizer=None,
+ trainable=True,
constraint=None):
"""Adds a weight variable to the layer.
Arguments:
+ name: String, the name for the weight variable.
shape: The shape tuple of the weight.
+ dtype: The dtype of the weight.
initializer: An Initializer instance (callable).
- name: String, the name for the weight variable.
+ regularizer: An optional Regularizer instance.
trainable: A boolean, whether the weight should
be trained via backprop or not (assuming
that the layer itself is also trainable).
- regularizer: An optional Regularizer instance.
constraint: An optional Constraint instance.
Returns:
The created weight variable.
"""
- shape = tuple(tensor_shape.TensorShape(shape).as_list())
- initializer = initializers.get(initializer)
- weight = K.variable(initializer(shape), dtype=K.floatx(), name=name)
- if regularizer is not None:
- self.add_loss(regularizer(weight))
+ if dtype is None:
+ dtype = K.floatx()
+ weight = self.add_variable(
+ name, shape, dtype=dtype,
+ initializer=initializer, regularizer=regularizer, trainable=trainable)
if constraint is not None:
self.constraints[weight] = constraint
- if trainable:
- self._trainable_weights.append(weight)
- else:
- self._non_trainable_weights.append(weight)
return weight
def assert_input_compatibility(self, inputs):
@@ -554,66 +508,46 @@ class Layer(object):
"""
if isinstance(inputs, list):
inputs = inputs[:]
+
+ # Raise exceptions in case the input is not compatible
+ # with the input_spec set at build time.
+ # TODO(fchollet): call after the layer is built, too.
+ self.assert_input_compatibility(inputs)
+
+ # Handle mask propagation.
+ previous_mask = _collect_previous_mask(inputs)
+ user_kwargs = copy.copy(kwargs)
+ if not _is_all_none(previous_mask):
+ # The previous layer generated a mask.
+ if 'mask' in tf_inspect.getargspec(self.call).args:
+ if 'mask' not in kwargs:
+ # If mask is explicitly passed to __call__,
+ # we should override the default mask.
+ kwargs['mask'] = previous_mask
+
+ # Actually call the layer (optionally building it).
+ output = super(Layer, self).__call__(inputs, **kwargs)
+
+ # Handle mask computation.
with K.name_scope(self.name):
- # Handle laying building (weight creating, input spec locking).
- if not self.built:
- # Raise exceptions in case the input is not compatible
- # with the input_spec specified in the layer constructor.
- self.assert_input_compatibility(inputs)
-
- # Collect input shapes to build layer.
- input_shapes = []
- for x_elem in _to_list(inputs):
- input_shapes.append(K.int_shape(x_elem))
- if len(input_shapes) == 1:
- self.build(input_shapes[0])
- else:
- self.build(input_shapes)
- self.built = True
-
- # Load weights that were specified at layer instantiation.
- if self._initial_weights is not None:
- self.set_weights(self._initial_weights)
-
- # Raise exceptions in case the input is not compatible
- # with the input_spec set at build time.
- self.assert_input_compatibility(inputs)
-
- # Handle mask propagation.
- previous_mask = _collect_previous_mask(inputs)
- user_kwargs = copy.copy(kwargs)
- if not _is_all_none(previous_mask):
- # The previous layer generated a mask.
- if 'mask' in tf_inspect.getargspec(self.call).args:
- if 'mask' not in kwargs:
- # If mask is explicitly passed to __call__,
- # we should override the default mask.
- kwargs['mask'] = previous_mask
-
- # Actually call the layer, collecting output(s), mask(s), and shape(s).
- output = self.call(inputs, **kwargs)
output_mask = self.compute_mask(inputs, previous_mask)
- # Add an inbound node to the layer, so that it keeps track
- # of the call and of all new variables created during the call.
- # This also updates the layer history of the output tensor(s).
- # If the input tensor(s) had not previous Keras history,
- # this does nothing.
- self._add_inbound_node(
- input_tensors=inputs,
- output_tensors=output,
- input_masks=previous_mask,
- output_masks=output_mask,
- arguments=user_kwargs)
-
- # Apply activity regularizer if any:
- if hasattr(
- self,
- 'activity_regularizer') and self.activity_regularizer is not None:
- regularization_losses = [
- self.activity_regularizer(x) for x in _to_list(output)
- ]
- self.add_loss(regularization_losses, _to_list(inputs))
+ # Add an inbound node to the layer, so that it keeps track
+ # of the call and of all new variables created during the call.
+ # This also updates the layer history of the output tensor(s).
+ # If the input tensor(s) had no previous Keras history,
+ # this does nothing.
+ self._add_inbound_node(
+ input_tensors=inputs,
+ output_tensors=output,
+ input_masks=previous_mask,
+ output_masks=output_mask,
+ arguments=user_kwargs)
+
+ # Optionally load weight values that were specified at layer instantiation.
+ if hasattr(self, '_initial_weights') and self._initial_weights is not None:
+ self.set_weights(self._initial_weights)
+ del self._initial_weights
return output
def _add_inbound_node(self,
@@ -959,14 +893,14 @@ class Layer(object):
@property
def input_shape(self):
- """Retrieves the input shape tuple(s) of a layer.
+ """Retrieves the input shape(s) of a layer.
Only applicable if the layer has exactly one inbound node,
i.e. if it is connected to one incoming layer.
Returns:
- Input shape tuple
- (or list of input shape tuples, one tuple per input tensor).
+ Input shape, as `TensorShape`
+ (or list of `TensorShape`, one per input tensor).
Raises:
AttributeError: if the layer is connected to
@@ -997,14 +931,14 @@ class Layer(object):
@property
def output_shape(self):
- """Retrieves the output shape tuple(s) of a layer.
+ """Retrieves the output shape(s) of a layer.
Only applicable if the layer has one inbound node,
or if all inbound nodes have the same output shape.
Returns:
- Output shape tuple
- (or list of input shape tuples, one tuple per output tensor).
+ Output shape, as `TensorShape`
+ (or list of `TensorShape`, one per output tensor).
Raises:
AttributeError: if the layer is connected to
@@ -1033,94 +967,6 @@ class Layer(object):
'Use `get_output_shape_at(node_index)` '
'instead.')
- def add_loss(self, losses, inputs=None):
- """Add losses to the layer.
-
- The loss may potentially be conditional on some inputs tensors,
- for instance activity losses are conditional on the layer's inputs.
-
- Arguments:
- losses: loss tensor or list of loss tensors
- to add to the layer.
- inputs: input tensor or list of inputs tensors to mark
- the losses as conditional on these inputs.
- If None is passed, the loss is assumed unconditional
- (e.g. L2 weight regularization, which only depends
- on the layer's weights variables, not on any inputs tensors).
- """
- if losses is None or losses == []: # pylint: disable=g-explicit-bool-comparison
- return
- # Update self.losses
- losses = _to_list(losses)
- if hasattr(self, '_losses'):
- self._losses += losses
- # Update self._per_input_updates
- if inputs == []: # pylint: disable=g-explicit-bool-comparison
- inputs = None
- if inputs is not None:
- inputs_hash = _object_list_uid(inputs)
- else:
- # Updates indexed by None are unconditional
- # rather than input-dependent
- inputs_hash = None
- if inputs_hash not in self._per_input_losses:
- self._per_input_losses[inputs_hash] = []
- self._per_input_losses[inputs_hash] += losses
-
- def add_update(self, updates, inputs=None):
- """Add updates to the layer.
-
- The updates may potentially be conditional on some inputs tensors,
- for instance batch norm updates are conditional on the layer's inputs.
-
- Arguments:
- updates: update op or list of update ops
- to add to the layer.
- inputs: input tensor or list of inputs tensors to mark
- the updates as conditional on these inputs.
- If None is passed, the updates are assumed unconditional.
- """
- if updates is None or updates == []: # pylint: disable=g-explicit-bool-comparison
- return
- # Update self.updates
- updates = _to_list(updates)
- if hasattr(self, '_updates'):
- self._updates += updates
- # Update self._per_input_updates
- if inputs == []: # pylint: disable=g-explicit-bool-comparison
- inputs = None
- if inputs is not None:
- inputs_hash = _object_list_uid(inputs)
- else:
- # Updates indexed by None are unconditional
- # rather than input-dependent
- inputs_hash = None
- if inputs_hash not in self._per_input_updates:
- self._per_input_updates[inputs_hash] = []
- self._per_input_updates[inputs_hash] += updates
-
- def get_updates_for(self, inputs):
- if inputs is not None:
- inputs_hash = _object_list_uid(inputs)
- else:
- inputs_hash = None
- if inputs_hash in self._per_input_updates:
- return self._per_input_updates[inputs_hash]
- return []
-
- def get_losses_for(self, inputs):
- if inputs is not None:
- inputs_hash = _object_list_uid(inputs)
- else:
- inputs_hash = None
- if inputs_hash in self._per_input_losses:
- return self._per_input_losses[inputs_hash]
- return []
-
- @property
- def weights(self):
- return self.trainable_weights + self.non_trainable_weights
-
def set_weights(self, weights):
"""Sets the weights of the layer, from Numpy arrays.
@@ -1254,9 +1100,12 @@ class InputLayer(Layer):
if not name:
prefix = 'input'
name = prefix + '_' + str(K.get_uid(prefix))
+ if not dtype:
+ if input_tensor is None:
+ dtype = K.floatx()
+ else:
+ dtype = K.dtype(input_tensor)
super(InputLayer, self).__init__(dtype=dtype, name=name)
-
- self.trainable = False
self.built = True
self.sparse = sparse
@@ -1284,15 +1133,7 @@ class InputLayer(Layer):
batch_input_shape = (batch_size,) + tuple(input_shape)
else:
batch_input_shape = tuple(batch_input_shape)
-
- if not dtype:
- if input_tensor is None:
- dtype = K.floatx()
- else:
- dtype = K.dtype(input_tensor)
-
self.batch_input_shape = batch_input_shape
- self.dtype = dtype
if input_tensor is None:
self.is_placeholder = True
@@ -1446,12 +1287,19 @@ class Container(Layer):
prefix = self.__class__.__name__.lower()
name = prefix + '_' + str(K.get_uid(prefix))
self.name = name
-
self.supports_masking = False
self.trainable = True
self._per_input_losses = {}
self._per_input_updates = {}
+ # The following properties are not actually used by Keras;
+ # they exist for compatibility with TF.
+ self._updates = []
+ self._scope = None
+ self._reuse = None
+ self._base_name = name
+ self._graph = ops.get_default_graph()
+
# Container-specific properties.
if isinstance(inputs, (list, tuple)):
self.inputs = list(inputs) # Tensor or list of tensors.
diff --git a/tensorflow/contrib/keras/python/keras/engine/topology_test.py b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
index eb095b14a9..531ed4be3e 100644
--- a/tensorflow/contrib/keras/python/keras/engine/topology_test.py
+++ b/tensorflow/contrib/keras/python/keras/engine/topology_test.py
@@ -490,8 +490,8 @@ class TopologyConstructionTest(test.TestCase):
m, n = model([j, k])
tf_model = keras.models.Model([j, k], [m, n])
- j_tf = array_ops.placeholder(dtype=dtypes.float32)
- k_tf = array_ops.placeholder(dtype=dtypes.float32)
+ j_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32))
+ k_tf = array_ops.placeholder(dtype=dtypes.float32, shape=(None, 32))
m_tf, n_tf = tf_model([j_tf, k_tf])
self.assertListEqual(m_tf.get_shape().as_list(), [None, 64])
self.assertListEqual(n_tf.get_shape().as_list(), [None, 5])
diff --git a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py
index b3abfc29d2..2c957ece44 100644
--- a/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py
+++ b/tensorflow/contrib/keras/python/keras/layers/advanced_activations.py
@@ -120,7 +120,7 @@ class PReLU(Layer):
param_shape[i - 1] = 1
self.param_broadcast[i - 1] = True
self.alpha = self.add_weight(
- param_shape,
+ shape=param_shape,
name='alpha',
initializer=self.alpha_initializer,
regularizer=self.alpha_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional.py b/tensorflow/contrib/keras/python/keras/layers/convolutional.py
index 38b8fe66a3..16f49c3390 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional.py
@@ -140,14 +140,14 @@ class _Conv(Layer):
kernel_shape = self.kernel_size + (input_dim, self.filters)
self.kernel = self.add_weight(
- kernel_shape,
+ shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (self.filters,),
+ shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
@@ -734,14 +734,14 @@ class Conv2DTranspose(Conv2D):
kernel_shape = self.kernel_size + (self.filters, input_dim)
self.kernel = self.add_weight(
- kernel_shape,
+ shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (self.filters,),
+ shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
@@ -949,13 +949,13 @@ class SeparableConv2D(Conv2D):
self.filters)
self.depthwise_kernel = self.add_weight(
- depthwise_kernel_shape,
+ shape=depthwise_kernel_shape,
initializer=self.depthwise_initializer,
name='depthwise_kernel',
regularizer=self.depthwise_regularizer,
constraint=self.depthwise_constraint)
self.pointwise_kernel = self.add_weight(
- pointwise_kernel_shape,
+ shape=pointwise_kernel_shape,
initializer=self.pointwise_initializer,
name='pointwise_kernel',
regularizer=self.pointwise_regularizer,
@@ -963,7 +963,7 @@ class SeparableConv2D(Conv2D):
if self.use_bias:
self.bias = self.add_weight(
- (self.filters,),
+ shape=(self.filters,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
index 4d8ef44da7..30325b7148 100644
--- a/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/convolutional_recurrent.py
@@ -369,20 +369,20 @@ class ConvLSTM2D(ConvRecurrent2D):
recurrent_kernel_shape = self.kernel_size + (self.filters, self.filters * 4)
self.kernel = self.add_weight(
- kernel_shape,
+ shape=kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
- recurrent_kernel_shape,
+ shape=recurrent_kernel_shape,
initializer=self.recurrent_initializer,
name='recurrent_kernel',
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (self.filters * 4,),
+ shape=(self.filters * 4,),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/layers/core.py b/tensorflow/contrib/keras/python/keras/layers/core.py
index 32ada176a4..7a9e0d1736 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core.py
@@ -34,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.layers import core as tf_core_layers
from tensorflow.python.util import tf_inspect
@@ -643,7 +644,7 @@ class Lambda(Layer):
return cls(**config)
-class Dense(Layer):
+class Dense(tf_core_layers.Dense, Layer):
"""Just your regular densely-connected NN layer.
`Dense` implements the operation:
@@ -712,15 +713,20 @@ class Dense(Layer):
**kwargs):
if 'input_shape' not in kwargs and 'input_dim' in kwargs:
kwargs['input_shape'] = (kwargs.pop('input_dim'),)
- super(Dense, self).__init__(**kwargs)
- self.units = units
- self.activation = activations.get(activation)
- self.use_bias = use_bias
- self.kernel_initializer = initializers.get(kernel_initializer)
- self.bias_initializer = initializers.get(bias_initializer)
- self.kernel_regularizer = regularizers.get(kernel_regularizer)
- self.bias_regularizer = regularizers.get(bias_regularizer)
- self.activity_regularizer = regularizers.get(activity_regularizer)
+
+ # Inheritance call order:
+ # 1) tf.layers.Dense, 2) keras.layers.Layer, 3) tf.layers.Layer
+ super(Dense, self).__init__(
+ units,
+ activation=activations.get(activation),
+ use_bias=use_bias,
+ kernel_initializer=initializers.get(kernel_initializer),
+ bias_initializer=initializers.get(bias_initializer),
+ kernel_regularizer=regularizers.get(kernel_regularizer),
+ bias_regularizer=regularizers.get(bias_regularizer),
+ activity_regularizer=regularizers.get(activity_regularizer),
+ **kwargs)
+
self.kernel_constraint = constraints.get(kernel_constraint)
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(min_ndim=2)
@@ -729,40 +735,12 @@ class Dense(Layer):
def build(self, input_shape):
assert len(input_shape) >= 2
input_dim = input_shape[-1]
-
- self.kernel = self.add_weight(
- (input_dim, self.units),
- initializer=self.kernel_initializer,
- name='kernel',
- regularizer=self.kernel_regularizer,
- constraint=self.kernel_constraint)
- if self.use_bias:
- self.bias = self.add_weight(
- (self.units,),
- initializer=self.bias_initializer,
- name='bias',
- regularizer=self.bias_regularizer,
- constraint=self.bias_constraint)
- else:
- self.bias = None
+ super(Dense, self).build(input_shape)
self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})
- self.built = True
-
- def call(self, inputs):
- output = K.dot(inputs, self.kernel)
- if self.use_bias:
- output = K.bias_add(output, self.bias)
- if self.activation is not None:
- output = self.activation(output)
- return output
-
- def _compute_output_shape(self, input_shape):
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- assert input_shape and len(input_shape) >= 2
- assert input_shape[-1]
- output_shape = list(input_shape)
- output_shape[-1] = self.units
- return tensor_shape.TensorShape(output_shape)
+ if self.kernel_constraint:
+ self.constraints[self.kernel] = self.kernel_constraint
+ if self.use_bias and self.bias_constraint:
+ self.constraints[self.bias] = self.bias_constraint
def get_config(self):
config = {
diff --git a/tensorflow/contrib/keras/python/keras/layers/core_test.py b/tensorflow/contrib/keras/python/keras/layers/core_test.py
index d7aa8413bb..7066af0ef6 100644
--- a/tensorflow/contrib/keras/python/keras/layers/core_test.py
+++ b/tensorflow/contrib/keras/python/keras/layers/core_test.py
@@ -165,24 +165,23 @@ class CoreLayersTest(test.TestCase):
3,
kernel_regularizer=keras.regularizers.l1(0.01),
bias_regularizer='l1',
- activity_regularizer='l2')
- layer.build((None, 4))
- assert len(layer.losses) == 2
+ activity_regularizer='l2',
+ name='dense_reg')
layer(keras.backend.variable(np.ones((2, 4))))
- assert len(layer.losses) == 3
+ self.assertEqual(3, len(layer.losses))
# Test constraints
with self.test_session():
layer = keras.layers.Dense(
3, kernel_constraint='max_norm', bias_constraint='max_norm')
- layer.build((None, 4))
- assert len(layer.constraints) == 2
+ layer(keras.backend.variable(np.ones((2, 4))))
+ self.assertEqual(2, len(layer.constraints))
def test_activity_regularization(self):
with self.test_session():
layer = keras.layers.ActivityRegularization(l1=0.1)
layer(keras.backend.variable(np.ones((2, 4))))
- assert len(layer.losses) == 1
+ self.assertEqual(1, len(layer.losses))
if __name__ == '__main__':
diff --git a/tensorflow/contrib/keras/python/keras/layers/embeddings.py b/tensorflow/contrib/keras/python/keras/layers/embeddings.py
index 12a2ce39eb..bc0bae67d0 100644
--- a/tensorflow/contrib/keras/python/keras/layers/embeddings.py
+++ b/tensorflow/contrib/keras/python/keras/layers/embeddings.py
@@ -116,7 +116,7 @@ class Embedding(Layer):
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
self.embeddings = self.add_weight(
- (self.input_dim, self.output_dim),
+ shape=(self.input_dim, self.output_dim),
initializer=self.embeddings_initializer,
name='embeddings',
regularizer=self.embeddings_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/layers/local.py b/tensorflow/contrib/keras/python/keras/layers/local.py
index d96ccc4a63..863674c1cb 100644
--- a/tensorflow/contrib/keras/python/keras/layers/local.py
+++ b/tensorflow/contrib/keras/python/keras/layers/local.py
@@ -130,14 +130,14 @@ class LocallyConnected1D(Layer):
self.kernel_shape = (output_length, self.kernel_size[0] * input_dim,
self.filters)
self.kernel = self.add_weight(
- self.kernel_shape,
+ shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (output_length, self.filters),
+ shape=(output_length, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
@@ -340,14 +340,14 @@ class LocallyConnected2D(Layer):
output_row * output_col,
self.kernel_size[0] * self.kernel_size[1] * input_filter, self.filters)
self.kernel = self.add_weight(
- self.kernel_shape,
+ shape=self.kernel_shape,
initializer=self.kernel_initializer,
name='kernel',
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (output_row, output_col, self.filters),
+ shape=(output_row, output_col, self.filters),
initializer=self.bias_initializer,
name='bias',
regularizer=self.bias_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/layers/merge.py b/tensorflow/contrib/keras/python/keras/layers/merge.py
index 7c6482d0de..25921979bd 100644
--- a/tensorflow/contrib/keras/python/keras/layers/merge.py
+++ b/tensorflow/contrib/keras/python/keras/layers/merge.py
@@ -87,6 +87,7 @@ class _Merge(Layer):
raise ValueError('A merge layer should be called '
'on a list of at least 2 inputs. '
'Got ' + str(len(input_shape)) + ' inputs.')
+ input_shape = [tensor_shape.TensorShape(s).as_list() for s in input_shape]
batch_sizes = [s[0] for s in input_shape if s is not None]
batch_sizes = set(batch_sizes)
batch_sizes -= set([None])
diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization.py b/tensorflow/contrib/keras/python/keras/layers/normalization.py
index 9a0340aeaf..df77401aee 100644
--- a/tensorflow/contrib/keras/python/keras/layers/normalization.py
+++ b/tensorflow/contrib/keras/python/keras/layers/normalization.py
@@ -116,7 +116,7 @@ class BatchNormalization(Layer):
if self.scale:
self.gamma = self.add_weight(
- shape,
+ shape=shape,
name='gamma',
initializer=self.gamma_initializer,
regularizer=self.gamma_regularizer,
@@ -125,7 +125,7 @@ class BatchNormalization(Layer):
self.gamma = None
if self.center:
self.beta = self.add_weight(
- shape,
+ shape=shape,
name='beta',
initializer=self.beta_initializer,
regularizer=self.beta_regularizer,
@@ -133,12 +133,12 @@ class BatchNormalization(Layer):
else:
self.beta = None
self.moving_mean = self.add_weight(
- shape,
+ shape=shape,
name='moving_mean',
initializer=self.moving_mean_initializer,
trainable=False)
self.moving_variance = self.add_weight(
- shape,
+ shape=shape,
name='moving_variance',
initializer=self.moving_variance_initializer,
trainable=False)
diff --git a/tensorflow/contrib/keras/python/keras/layers/recurrent.py b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
index 1ea1cb22d9..e608921add 100644
--- a/tensorflow/contrib/keras/python/keras/layers/recurrent.py
+++ b/tensorflow/contrib/keras/python/keras/layers/recurrent.py
@@ -493,20 +493,20 @@ class SimpleRNN(Recurrent):
self.reset_states()
self.kernel = self.add_weight(
- (self.input_dim, self.units),
+ shape=(self.input_dim, self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
- (self.units, self.units),
+ shape=(self.units, self.units),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
constraint=self.recurrent_constraint)
if self.use_bias:
self.bias = self.add_weight(
- (self.units,),
+ shape=(self.units,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
@@ -723,13 +723,13 @@ class GRU(Recurrent):
self.reset_states()
self.kernel = self.add_weight(
- (self.input_dim, self.units * 3),
+ shape=(self.input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
- (self.units, self.units * 3),
+ shape=(self.units, self.units * 3),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
@@ -737,9 +737,9 @@ class GRU(Recurrent):
if self.use_bias:
self.bias = self.add_weight(
- (self.units * 3,),
+ shape=(self.units * 3,),
name='bias',
- initializer='zero',
+ initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
constraint=self.bias_constraint)
else:
@@ -1039,13 +1039,13 @@ class LSTM(Recurrent):
self.reset_states()
self.kernel = self.add_weight(
- (self.input_dim, self.units * 4),
+ shape=(self.input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
constraint=self.kernel_constraint)
self.recurrent_kernel = self.add_weight(
- (self.units, self.units * 4),
+ shape=(self.units, self.units * 4),
name='recurrent_kernel',
initializer=self.recurrent_initializer,
regularizer=self.recurrent_regularizer,
@@ -1053,7 +1053,7 @@ class LSTM(Recurrent):
if self.use_bias:
self.bias = self.add_weight(
- (self.units * 4,),
+ shape=(self.units * 4,),
name='bias',
initializer=self.bias_initializer,
regularizer=self.bias_regularizer,
diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py
index eb0996fa12..52456a4bb5 100644
--- a/tensorflow/contrib/keras/python/keras/models.py
+++ b/tensorflow/contrib/keras/python/keras/models.py
@@ -35,6 +35,7 @@ from tensorflow.contrib.keras.python.keras.engine.topology import Input
from tensorflow.contrib.keras.python.keras.engine.topology import Layer
from tensorflow.contrib.keras.python.keras.engine.training import Model
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
+from tensorflow.python.framework import ops
# pylint: disable=g-import-not-at-top
@@ -420,6 +421,14 @@ class Sequential(Model):
name = prefix + str(K.get_uid(prefix))
self.name = name
+ # The following properties are not actually used by Keras;
+ # they exist for compatibility with TF's variable scoping mechanism.
+ self._updates = []
+ self._scope = None
+ self._reuse = None
+ self._base_name = name
+ self._graph = ops.get_default_graph()
+
# Add to the model any layers passed to the constructor.
if layers:
for layer in layers: