path: root/tensorflow/python/keras/_impl/keras/layers/recurrent.py
Diffstat (limited to 'tensorflow/python/keras/_impl/keras/layers/recurrent.py')
-rw-r--r--  tensorflow/python/keras/_impl/keras/layers/recurrent.py  2383
1 file changed, 574 insertions, 1809 deletions
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index 2bc74d5f80..139523403c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -1,4 +1,4 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,209 +29,99 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
-from tensorflow.python.platform import tf_logging as logging
-class StackedRNNCells(Layer):
- """Wrapper allowing a stack of RNN cells to behave as a single cell.
+# pylint: disable=access-member-before-definition
- Used to implement efficient stacked RNNs.
+
+def _time_distributed_dense(x,
+ w,
+ b=None,
+ dropout=None,
+ input_dim=None,
+ output_dim=None,
+ timesteps=None,
+ training=None):
+ """Apply `y . w + b` for every temporal slice y of x.
Arguments:
- cells: List of RNN cell instances.
+ x: input tensor.
+ w: weight matrix.
+ b: optional bias vector.
+      dropout: optional dropout rate between 0 and 1 (the same dropout
+          mask is applied to every temporal slice of the input).
+ input_dim: integer; optional dimensionality of the input.
+ output_dim: integer; optional dimensionality of the output.
+ timesteps: integer; optional number of timesteps.
+ training: training phase tensor or boolean.
+
+ Returns:
+ Output tensor.
+ """
+ if not input_dim:
+ input_dim = K.shape(x)[2]
+ if not timesteps:
+ timesteps = K.shape(x)[1]
+ if not output_dim:
+ output_dim = K.shape(w)[1]
+
+ if dropout is not None and 0. < dropout < 1.:
+ # apply the same dropout pattern at every timestep
+ ones = K.ones_like(K.reshape(x[:, 0, :], (-1, input_dim)))
+ dropout_matrix = K.dropout(ones, dropout)
+ expanded_dropout_matrix = K.repeat(dropout_matrix, timesteps)
+ x = K.in_train_phase(x * expanded_dropout_matrix, x, training=training)
+
+ # collapse time dimension and batch dimension together
+ x = K.reshape(x, (-1, input_dim))
+ x = K.dot(x, w)
+ if b is not None:
+ x = K.bias_add(x, b)
+ # reshape to 3D tensor
+ if K.backend() == 'tensorflow':
+ x = K.reshape(x, K.stack([-1, timesteps, output_dim]))
+ x.set_shape([None, None, output_dim])
+ else:
+ x = K.reshape(x, (-1, timesteps, output_dim))
+ return x
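
A minimal NumPy sketch of what `_time_distributed_dense` computes, ignoring the dropout branch; the shapes and array names below are illustrative only:

```python
import numpy as np

# Hypothetical shapes: batch=2, timesteps=3, input_dim=4, output_dim=5.
x = np.random.randn(2, 3, 4).astype('float32')
w = np.random.randn(4, 5).astype('float32')
b = np.zeros(5, dtype='float32')

# Collapse batch and time together, apply the projection, reshape back,
# mirroring the K.reshape / K.dot / K.bias_add sequence above.
flat = x.reshape(-1, 4)                  # (batch * timesteps, input_dim)
y = (flat.dot(w) + b).reshape(2, 3, 5)   # (batch, timesteps, output_dim)

# Same result as applying the dense projection timestep by timestep.
ref = np.stack([x[:, t, :].dot(w) + b for t in range(3)], axis=1)
assert np.allclose(y, ref)
```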
- Examples:
- ```python
- cells = [
- keras.layers.LSTMCell(output_dim),
- keras.layers.LSTMCell(output_dim),
- keras.layers.LSTMCell(output_dim),
- ]
-
- inputs = keras.Input((timesteps, input_dim))
- x = keras.layers.RNN(cells)(inputs)
- ```
- """
+class Recurrent(Layer):
+ """Abstract base class for recurrent layers.
- def __init__(self, cells, **kwargs):
- for cell in cells:
- if not hasattr(cell, 'call'):
- raise ValueError('All cells must have a `call` method. '
- 'received cells:', cells)
- if not hasattr(cell, 'state_size'):
- raise ValueError('All cells must have a '
- '`state_size` attribute. '
- 'received cells:', cells)
- self.cells = cells
- super(StackedRNNCells, self).__init__(**kwargs)
-
- @property
- def state_size(self):
- # States are a flat list
- # in reverse order of the cell stack.
- # This allows to preserve the requirement
- # `stack.state_size[0] == output_dim`.
- # e.g. states of a 2-layer LSTM would be
- # `[h2, c2, h1, c1]`
- # (assuming one LSTM has states [h, c])
- state_size = []
- for cell in self.cells[::-1]:
- if hasattr(cell.state_size, '__len__'):
- state_size += list(cell.state_size)
- else:
- state_size.append(cell.state_size)
- return tuple(state_size)
-
- def call(self, inputs, states, **kwargs):
- # Recover per-cell states.
- nested_states = []
- for cell in self.cells[::-1]:
- if hasattr(cell.state_size, '__len__'):
- nested_states.append(states[:len(cell.state_size)])
- states = states[len(cell.state_size):]
- else:
- nested_states.append([states[0]])
- states = states[1:]
- nested_states = nested_states[::-1]
-
- # Call the cells in order and store the returned states.
- new_nested_states = []
- for cell, states in zip(self.cells, nested_states):
- inputs, states = cell.call(inputs, states, **kwargs)
- new_nested_states.append(states)
-
- # Format the new states as a flat list
- # in reverse cell order.
- states = []
- for cell_states in new_nested_states[::-1]:
- states += cell_states
- return inputs, states
+ Do not use in a model -- it's not a valid layer!
+  Use its child classes `LSTM`, `GRU` and `SimpleRNN` instead.
- def build(self, input_shape):
- for cell in self.cells:
- if isinstance(cell, Layer):
- cell.build(input_shape)
- if hasattr(cell.state_size, '__len__'):
- output_dim = cell.state_size[0]
- else:
- output_dim = cell.state_size
- input_shape = (input_shape[0], input_shape[1], output_dim)
- self.built = True
+ All recurrent layers (`LSTM`, `GRU`, `SimpleRNN`) also
+ follow the specifications of this class and accept
+ the keyword arguments listed below.
- def get_config(self):
- cells = []
- for cell in self.cells:
- cells.append({
- 'class_name': cell.__class__.__name__,
- 'config': cell.get_config()
- })
- config = {'cells': cells}
- base_config = super(StackedRNNCells, self).get_config()
- return dict(list(base_config.items()) + list(config.items()))
+ Example:
- @classmethod
- def from_config(cls, config, custom_objects=None):
- from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
- cells = []
- for cell_config in config.pop('cells'):
- cells.append(
- deserialize_layer(cell_config, custom_objects=custom_objects))
- return cls(cells, **config)
-
- @property
- def trainable_weights(self):
- if not self.trainable:
- return []
- weights = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- weights += cell.trainable_weights
- return weights
-
- @property
- def non_trainable_weights(self):
- weights = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- weights += cell.non_trainable_weights
- if not self.trainable:
- trainable_weights = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- trainable_weights += cell.trainable_weights
- return trainable_weights + weights
- return weights
-
- def get_weights(self):
- """Retrieves the weights of the model.
-
- Returns:
- A flat list of Numpy arrays.
- """
- weights = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- weights += cell.weights
- return K.batch_get_value(weights)
-
- def set_weights(self, weights):
- """Sets the weights of the model.
-
- Arguments:
- weights: A list of Numpy arrays with shapes and types matching
- the output of `model.get_weights()`.
- """
- tuples = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- num_param = len(cell.weights)
- weights = weights[:num_param]
- for sw, w in zip(cell.weights, weights):
- tuples.append((sw, w))
- weights = weights[num_param:]
- K.batch_set_value(tuples)
-
- @property
- def losses(self):
- losses = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- cell_losses = cell.losses
- losses += cell_losses
- return losses
-
- def get_losses_for(self, inputs=None):
- losses = []
- for cell in self.cells:
- if isinstance(cell, Layer):
- cell_losses = cell.get_losses_for(inputs)
- losses += cell_losses
- return losses
-
-
-class RNN(Layer):
- """Base class for recurrent layers.
+ ```python
+ # as the first layer in a Sequential model
+ model = Sequential()
+ model.add(LSTM(32, input_shape=(10, 64)))
+ # now model.output_shape == (None, 32)
+ # note: `None` is the batch dimension.
+
+ # for subsequent layers, no need to specify the input size:
+ model.add(LSTM(16))
+
+ # to stack recurrent layers, you must use return_sequences=True
+ # on any recurrent layer that feeds into another recurrent layer.
+ # note that you only need to specify the input size on the first layer.
+ model = Sequential()
+ model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
+ model.add(LSTM(32, return_sequences=True))
+ model.add(LSTM(10))
+ ```
Arguments:
- cell: A RNN cell instance. A RNN cell is a class that has:
- - a `call(input_at_t, states_at_t)` method, returning
- `(output_at_t, states_at_t_plus_1)`. The call method of the
- cell can also take the optional argument `constants`, see
- section "Note on passing external constants" below.
- - a `state_size` attribute. This can be a single integer
- (single state) in which case it is
- the size of the recurrent state
- (which should be the same as the size of the cell output).
- This can also be a list/tuple of integers
- (one size per state). In this case, the first entry
- (`state_size[0]`) should be the same as
- the size of the cell output.
- It is also possible for `cell` to be a list of RNN cell instances,
- in which cases the cells get stacked on after the other in the RNN,
- implementing an efficient stacked RNN.
- return_sequences: Boolean. Whether to return the last output.
+ weights: list of Numpy arrays to set as initial weights.
+ The list should have 3 elements, of shapes:
+ `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
+ return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
in addition to the output.
@@ -247,9 +137,21 @@ class RNN(Layer):
Unrolling can speed-up a RNN,
although it tends to be more memory-intensive.
Unrolling is only suitable for short sequences.
+ implementation: one of {0, 1, or 2}.
+ If set to 0, the RNN will use
+ an implementation that uses fewer, larger matrix products,
+ thus running faster on CPU but consuming more memory.
+ If set to 1, the RNN will use more matrix products,
+ but smaller ones, thus running slower
+ (may actually be faster on GPU) while consuming less memory.
+ If set to 2 (LSTM/GRU only),
+ the RNN will combine the input gate,
+ the forget gate and the output gate into a single matrix,
+ enabling more time-efficient parallelization on the GPU.
+ Note: RNN dropout must be shared for all gates,
+ resulting in a slightly reduced regularization.
input_dim: dimensionality of the input (integer).
- This argument (or alternatively,
- the keyword argument `input_shape`)
+ This argument (or alternatively, the keyword argument `input_shape`)
is required when using this layer as the first layer in a model.
input_length: Length of input sequences, to be specified
when it is constant.
@@ -261,7 +163,7 @@ class RNN(Layer):
at the level of the first layer
(e.g. via the `input_shape` argument)
   Input shape:
3D tensor with shape `(batch_size, timesteps, input_dim)`,
(Optional) 2D tensors with shape `(batch_size, output_dim)`.
@@ -276,7 +178,7 @@ class RNN(Layer):
# Masking
This layer supports masking for input data with a variable number
of timesteps. To introduce masks to your data,
- use an [Embedding](embeddings.md) layer with the `mask_zero` parameter
+ use an `Embedding` layer with the `mask_zero` parameter
set to `True`.
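
A minimal sketch of the masking setup described above, using the standalone `keras` package; the vocabulary size, dimensions and sequence length are placeholders:

```python
from keras.layers import Dense, Embedding, LSTM
from keras.models import Sequential

model = Sequential()
# mask_zero=True marks timesteps whose input index is 0 as padding, so the
# LSTM below skips them instead of updating its state on padded positions.
model.add(Embedding(input_dim=1000, output_dim=64, mask_zero=True,
                    input_length=20))
model.add(LSTM(32))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy')
```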
# Note on using statefulness in RNNs
@@ -310,128 +212,42 @@ class RNN(Layer):
calling `reset_states` with the keyword argument `states`. The value of
`states` should be a numpy array or list of numpy arrays representing
the initial state of the RNN layer.
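
A hedged sketch of the stateful workflow described in this note, using the standalone `keras` package; the batch size, shapes and zero-valued states are placeholders:

```python
import numpy as np
from keras.layers import Dense, LSTM
from keras.models import Sequential

batch_size, timesteps, input_dim, units = 32, 10, 8, 16

model = Sequential()
# stateful=True requires a fixed batch size, hence batch_input_shape.
model.add(LSTM(units, stateful=True,
               batch_input_shape=(batch_size, timesteps, input_dim)))
model.add(Dense(1))
model.compile(optimizer='rmsprop', loss='mse')

# Clear all states between independent sequences...
model.reset_states()
# ...or seed the layer with explicit values; an LSTM keeps two state arrays
# (h and c), each of shape (batch_size, units).
model.layers[0].reset_states(states=[np.zeros((batch_size, units)),
                                     np.zeros((batch_size, units))])
```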
-
- # Note on passing external constants to RNNs
- You can pass "external" constants to the cell using the `constants`
- keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
- requires that the `cell.call` method accepts the same keyword argument
- `constants`. Such constants can be used to condition the cell
- transformation on additional static inputs (not changing over time),
- a.k.a. an attention mechanism.
-
- Examples:
-
- ```python
- # First, let's define a RNN Cell, as a layer subclass.
-
- class MinimalRNNCell(keras.layers.Layer):
-
- def __init__(self, units, **kwargs):
- self.units = units
- self.state_size = units
- super(MinimalRNNCell, self).__init__(**kwargs)
-
- def build(self, input_shape):
- self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
- initializer='uniform',
- name='kernel')
- self.recurrent_kernel = self.add_weight(
- shape=(self.units, self.units),
- initializer='uniform',
- name='recurrent_kernel')
- self.built = True
-
- def call(self, inputs, states):
- prev_output = states[0]
- h = K.dot(inputs, self.kernel)
- output = h + K.dot(prev_output, self.recurrent_kernel)
- return output, [output]
-
- # Let's use this cell in a RNN layer:
-
- cell = MinimalRNNCell(32)
- x = keras.Input((None, 5))
- layer = RNN(cell)
- y = layer(x)
-
- # Here's how to use the cell to build a stacked RNN:
-
- cells = [MinimalRNNCell(32), MinimalRNNCell(64)]
- x = keras.Input((None, 5))
- layer = RNN(cells)
- y = layer(x)
- ```
"""
def __init__(self,
- cell,
return_sequences=False,
return_state=False,
go_backwards=False,
stateful=False,
unroll=False,
- activity_regularizer=None,
+ implementation=0,
**kwargs):
- if isinstance(cell, (list, tuple)):
- cell = StackedRNNCells(cell)
- if not hasattr(cell, 'call'):
- raise ValueError('`cell` should have a `call` method. '
- 'The RNN was passed:', cell)
- if not hasattr(cell, 'state_size'):
- raise ValueError('The RNN cell should have '
- 'an attribute `state_size` '
- '(tuple of integers, '
- 'one integer per RNN state).')
- super(RNN, self).__init__(
- activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
- self.cell = cell
+ super(Recurrent, self).__init__(**kwargs)
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
-
+ self.implementation = implementation
self.supports_masking = True
self.input_spec = [InputSpec(ndim=3)]
self.state_spec = None
- self._states = None
- self.constants_spec = None
- self._num_constants = None
-
- @property
- def states(self):
- if self._states is None:
- if isinstance(self.cell.state_size, int):
- num_states = 1
- else:
- num_states = len(self.cell.state_size)
- return [None for _ in range(num_states)]
- return self._states
-
- @states.setter
- def states(self, states):
- self._states = states
+ self.dropout = 0
+ self.recurrent_dropout = 0
def _compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
input_shape = tensor_shape.TensorShape(input_shape).as_list()
-
- if hasattr(self.cell.state_size, '__len__'):
- output_dim = self.cell.state_size[0]
- else:
- output_dim = self.cell.state_size
-
if self.return_sequences:
- output_shape = (input_shape[0], input_shape[1], output_dim)
+ output_shape = (input_shape[0], input_shape[1], self.units)
else:
- output_shape = (input_shape[0], output_dim)
+ output_shape = (input_shape[0], self.units)
if self.return_state:
- state_shape = [(input_shape[0], output_dim) for _ in self.states]
- output_shape = [output_shape] + state_shape
- else:
- output_shape = output_shape
+ state_shape = [tensor_shape.TensorShape(
+ (input_shape[0], self.units)) for _ in self.states]
+ return [tensor_shape.TensorShape(output_shape)] + state_shape
return tensor_shape.TensorShape(output_shape)
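
To make the shape logic above concrete, a small functional-API sketch using the standalone `keras` package and `K.int_shape`; the dimensions are placeholders:

```python
from keras import backend as K
from keras.layers import Input, LSTM

x = Input(shape=(10, 8))                 # (batch, timesteps, input_dim)

print(K.int_shape(LSTM(16)(x)))                          # (None, 16)
print(K.int_shape(LSTM(16, return_sequences=True)(x)))   # (None, 10, 16)

# With return_state=True the layer also returns its final states
# (output, state_h, state_c for an LSTM).
outputs = LSTM(16, return_state=True)(x)
print([K.int_shape(o) for o in outputs])  # [(None, 16), (None, 16), (None, 16)]
```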
def compute_mask(self, inputs, mask):
@@ -441,123 +257,82 @@ class RNN(Layer):
if self.return_state:
state_mask = [None for _ in self.states]
return [output_mask] + state_mask
- else:
- return output_mask
-
- def build(self, input_shape):
- # Note input_shape will be list of shapes of initial states and
- # constants if these are passed in __call__.
- if self._num_constants is not None:
- constants_shape = input_shape[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
- else:
- constants_shape = None
-
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ return output_mask
- batch_size = input_shape[0] if self.stateful else None
- input_dim = input_shape[-1]
- self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
-
- # allow cell (if layer) to build before we set or validate state_spec
- if isinstance(self.cell, Layer):
- step_input_shape = (input_shape[0],) + input_shape[2:]
- if constants_shape is not None:
- self.cell.build([step_input_shape] + constants_shape)
- else:
- self.cell.build(step_input_shape)
+ def step(self, inputs, states):
+ raise NotImplementedError
- # set or validate state_spec
- if hasattr(self.cell.state_size, '__len__'):
- state_size = list(self.cell.state_size)
- else:
- state_size = [self.cell.state_size]
-
- if self.state_spec is not None:
- # initial_state was passed in call, check compatibility
- if [spec.shape[-1] for spec in self.state_spec] != state_size:
- raise ValueError(
- 'An initial_state was passed that is not compatible with '
- '`cell.state_size`. Received `state_spec`={}; '
- 'However `cell.state_size` is '
- '{}'.format(self.state_spec, self.cell.state_size))
- else:
- self.state_spec = [InputSpec(shape=(None, dim)) for dim in state_size]
- if self.stateful:
- self.reset_states()
+ def get_constants(self, inputs, training=None):
+ return []
def get_initial_state(self, inputs):
# build an all-zero tensor of shape (samples, output_dim)
initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
initial_state = K.expand_dims(initial_state) # (samples, 1)
- if hasattr(self.cell.state_size, '__len__'):
- return [K.tile(initial_state, [1, dim]) for dim in self.cell.state_size]
- else:
- return [K.tile(initial_state, [1, self.cell.state_size])]
+ initial_state = K.tile(initial_state, [1,
+ self.units]) # (samples, output_dim)
+ initial_state = [initial_state for _ in range(len(self.states))]
+ return initial_state
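
The `zeros_like` / `sum` / `expand_dims` / `tile` sequence above exists only to build a zero state whose batch dimension matches the symbolically known batch size of `inputs`. A NumPy sketch of the resulting shapes, with placeholder sizes:

```python
import numpy as np

inputs = np.random.randn(2, 5, 4)      # (samples, timesteps, input_dim)
units = 16

state = np.zeros_like(inputs)          # keeps the batch dimension of inputs
state = state.sum(axis=(1, 2))         # (samples,)
state = np.expand_dims(state, -1)      # (samples, 1)
state = np.tile(state, (1, units))     # (samples, units), all zeros
assert state.shape == (2, 16)
```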
- def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
- inputs, initial_state, constants = self._standardize_args(
- inputs, initial_state, constants)
+ def preprocess_input(self, inputs, training=None):
+ return inputs
- if initial_state is None and constants is None:
- return super(RNN, self).__call__(inputs, **kwargs)
+ def __call__(self, inputs, initial_state=None, **kwargs):
+ if (isinstance(inputs, (list, tuple)) and
+ len(inputs) > 1
+ and initial_state is None):
+ initial_state = inputs[1:]
+ inputs = inputs[0]
- # If any of `initial_state` or `constants` are specified and are Keras
- # tensors, then add them to the inputs and temporarily modify the
- # input_spec to include them.
+ # If `initial_state` is specified,
+ # and if it a Keras tensor,
+ # then add it to the inputs and temporarily
+ # modify the input spec to include the state.
+ if initial_state is None:
+ return super(Recurrent, self).__call__(inputs, **kwargs)
- additional_inputs = []
- additional_specs = []
- if initial_state is not None:
- kwargs['initial_state'] = initial_state
- additional_inputs += initial_state
- self.state_spec = [
- InputSpec(shape=K.int_shape(state)) for state in initial_state
- ]
- additional_specs += self.state_spec
- if constants is not None:
- kwargs['constants'] = constants
- additional_inputs += constants
- self.constants_spec = [
- InputSpec(shape=K.int_shape(constant)) for constant in constants
- ]
- self._num_constants = len(constants)
- additional_specs += self.constants_spec
- # at this point additional_inputs cannot be empty
- is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
- for tensor in additional_inputs:
+ if not isinstance(initial_state, (list, tuple)):
+ initial_state = [initial_state]
+
+ is_keras_tensor = hasattr(initial_state[0], '_keras_history')
+ for tensor in initial_state:
if hasattr(tensor, '_keras_history') != is_keras_tensor:
- raise ValueError('The initial state or constants of an RNN'
- ' layer cannot be specified with a mix of'
- ' Keras tensors and non-Keras tensors')
+ raise ValueError('The initial state of an RNN layer cannot be'
+ ' specified with a mix of Keras tensors and'
+ ' non-Keras tensors')
if is_keras_tensor:
- # Compute the full input spec, including state and constants
- full_input = [inputs] + additional_inputs
- full_input_spec = self.input_spec + additional_specs
- # Perform the call with temporarily replaced input_spec
- original_input_spec = self.input_spec
- self.input_spec = full_input_spec
- output = super(RNN, self).__call__(full_input, **kwargs)
- self.input_spec = original_input_spec
+ # Compute the full input spec, including state
+ input_spec = self.input_spec
+ state_spec = self.state_spec
+ if not isinstance(input_spec, list):
+ input_spec = [input_spec]
+ if not isinstance(state_spec, list):
+ state_spec = [state_spec]
+ self.input_spec = input_spec + state_spec
+
+ # Compute the full inputs, including state
+ inputs = [inputs] + list(initial_state)
+
+ # Perform the call
+ output = super(Recurrent, self).__call__(inputs, **kwargs)
+
+ # Restore original input spec
+ self.input_spec = input_spec
return output
else:
- return super(RNN, self).__call__(inputs, **kwargs)
-
- def call(self,
- inputs,
- mask=None,
- training=None,
- initial_state=None,
- constants=None):
+ kwargs['initial_state'] = initial_state
+ return super(Recurrent, self).__call__(inputs, **kwargs)
+
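
The Keras-tensor branch above is what allows the final states of one recurrent layer to be wired into another as its initial state. A hedged functional-API sketch using the standalone `keras` package; the layer sizes are placeholders:

```python
from keras.layers import Input, LSTM

encoder_inputs = Input(shape=(None, 8))
# return_state=True exposes the final hidden and cell states.
_, state_h, state_c = LSTM(16, return_state=True)(encoder_inputs)

decoder_inputs = Input(shape=(None, 8))
# The states are passed as initial_state and routed through the
# Keras-tensor branch of __call__ above.
decoder_outputs = LSTM(16, return_sequences=True)(
    decoder_inputs, initial_state=[state_h, state_c])
```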
+ def call(self, inputs, mask=None, training=None, initial_state=None):
# input shape: `(samples, time (padded with zeros), input_dim)`
# note that the .build() method of subclasses MUST define
# self.input_spec and self.state_spec with complete input shapes.
if isinstance(inputs, list):
+ initial_state = inputs[1:]
inputs = inputs[0]
- if initial_state is not None:
+ elif initial_state is not None:
pass
elif self.stateful:
initial_state = self.states
@@ -568,14 +343,13 @@ class RNN(Layer):
mask = mask[0]
if len(initial_state) != len(self.states):
- raise ValueError(
- 'Layer has ' + str(len(self.states)) + ' states but was passed ' +
- str(len(initial_state)) + ' initial states.')
+ raise ValueError('Layer has ' + str(len(self.states)) +
+ ' states but was passed ' + str(len(initial_state)) +
+ ' initial states.')
input_shape = K.int_shape(inputs)
- timesteps = input_shape[1]
- if self.unroll and timesteps in [None, 1]:
+ if self.unroll and input_shape[1] is None:
raise ValueError('Cannot unroll a RNN if the '
- 'time dimension is undefined or equal to 1. \n'
+ 'time dimension is undefined. \n'
'- If using a Sequential model, '
'specify the time dimension by passing '
'an `input_shape` or `batch_input_shape` '
@@ -585,31 +359,15 @@ class RNN(Layer):
'- If using the functional API, specify '
'the time dimension by passing a `shape` '
'or `batch_shape` argument to your Input layer.')
-
- kwargs = {}
- if has_arg(self.cell.call, 'training'):
- kwargs['training'] = training
-
- if constants:
- if not has_arg(self.cell.call, 'constants'):
- raise ValueError('RNN cell does not support constants')
-
- def step(inputs, states):
- constants = states[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
- states = states[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
- return self.cell.call(inputs, states, constants=constants, **kwargs)
- else:
-
- def step(inputs, states):
- return self.cell.call(inputs, states, **kwargs)
-
+ constants = self.get_constants(inputs, training=None)
+ preprocessed_input = self.preprocess_input(inputs, training=None)
last_output, outputs, states = K.rnn(
- step,
- inputs,
+ self.step,
+ preprocessed_input,
initial_state,
- constants=constants,
go_backwards=self.go_backwards,
mask=mask,
+ constants=constants,
unroll=self.unroll)
if self.stateful:
updates = []
@@ -617,63 +375,21 @@ class RNN(Layer):
updates.append((self.states[i], states[i]))
self.add_update(updates, inputs)
- if self.return_sequences:
- output = outputs
- else:
- output = last_output
-
# Properly set learning phase
- if getattr(last_output, '_uses_learning_phase', False):
- output._uses_learning_phase = True
+ if 0 < self.dropout + self.recurrent_dropout:
+ last_output._uses_learning_phase = True
+ outputs._uses_learning_phase = True
+
+ if not self.return_sequences:
+ outputs = last_output
if self.return_state:
if not isinstance(states, (list, tuple)):
states = [states]
else:
states = list(states)
- return [output] + states
- else:
- return output
-
- def _standardize_args(self, inputs, initial_state, constants):
- """Standardize `__call__` arguments to a single list of tensor inputs.
-
- When running a model loaded from file, the input tensors
- `initial_state` and `constants` can be passed to `RNN.__call__` as part
- of `inputs` instead of by the dedicated keyword arguments. This method
- makes sure the arguments are separated and that `initial_state` and
- `constants` are lists of tensors (or None).
-
- Arguments:
- inputs: tensor or list/tuple of tensors
- initial_state: tensor or list of tensors or None
- constants: tensor or list of tensors or None
-
- Returns:
- inputs: tensor
- initial_state: list of tensors or None
- constants: list of tensors or None
- """
- if isinstance(inputs, list):
- assert initial_state is None and constants is None
- if self._num_constants is not None:
- constants = inputs[-self._num_constants:] # pylint: disable=invalid-unary-operand-type
- inputs = inputs[:-self._num_constants] # pylint: disable=invalid-unary-operand-type
- if len(inputs) > 1:
- initial_state = inputs[1:]
- inputs = inputs[0]
-
- def to_list_or_none(x):
- if x is None or isinstance(x, list):
- return x
- if isinstance(x, tuple):
- return list(x)
- return [x]
-
- initial_state = to_list_or_none(initial_state)
- constants = to_list_or_none(constants)
-
- return inputs, initial_state, constants
+ return [outputs] + states
+ return outputs
def reset_states(self, states=None):
if not self.stateful:
@@ -692,19 +408,10 @@ class RNN(Layer):
'`batch_shape` argument to your Input layer.')
# initialize state if None
if self.states[0] is None:
- if hasattr(self.cell.state_size, '__len__'):
- self.states = [
- K.zeros((batch_size, dim)) for dim in self.cell.state_size
- ]
- else:
- self.states = [K.zeros((batch_size, self.cell.state_size))]
+ self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
elif states is None:
- if hasattr(self.cell.state_size, '__len__'):
- for state, dim in zip(self.states, self.cell.state_size):
- K.set_value(state, np.zeros((batch_size, dim)))
- else:
- K.set_value(self.states[0], np.zeros((batch_size,
- self.cell.state_size)))
+ for state in self.states:
+ K.set_value(state, np.zeros((batch_size, self.units)))
else:
if not isinstance(states, (list, tuple)):
states = [states]
@@ -714,16 +421,11 @@ class RNN(Layer):
'but it received ' + str(len(states)) +
' state values. Input received: ' + str(states))
for index, (value, state) in enumerate(zip(states, self.states)):
- if hasattr(self.cell.state_size, '__len__'):
- dim = self.cell.state_size[index]
- else:
- dim = self.cell.state_size
- if value.shape != (batch_size, dim):
- raise ValueError(
- 'State ' + str(index) + ' is incompatible with layer ' +
- self.name + ': expected shape=' + str(
- (batch_size, dim)) + ', found shape=' + str(value.shape))
- # TODO(fchollet): consider batch calls to `set_value`.
+ if value.shape != (batch_size, self.units):
+ raise ValueError('State ' + str(index) +
+ ' is incompatible with layer ' + self.name +
+ ': expected shape=' + str((batch_size, self.units)) +
+ ', found shape=' + str(value.shape))
K.set_value(state, value)
def get_config(self):
@@ -732,94 +434,51 @@ class RNN(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll
- }
- if self._num_constants is not None:
- config['num_constants'] = self._num_constants
-
- cell_config = self.cell.get_config()
- config['cell'] = {
- 'class_name': self.cell.__class__.__name__,
- 'config': cell_config
+ 'unroll': self.unroll,
+ 'implementation': self.implementation
}
- base_config = super(RNN, self).get_config()
+ base_config = super(Recurrent, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @classmethod
- def from_config(cls, config, custom_objects=None):
- from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
- cell = deserialize_layer(config.pop('cell'), custom_objects=custom_objects)
- num_constants = config.pop('num_constants', None)
- layer = cls(cell, **config)
- layer._num_constants = num_constants
- return layer
-
- @property
- def trainable_weights(self):
- if isinstance(self.cell, Layer):
- return self.cell.trainable_weights
- return []
-
- @property
- def non_trainable_weights(self):
- if isinstance(self.cell, Layer):
- return self.cell.non_trainable_weights
- return []
- @property
- def losses(self):
- if isinstance(self.cell, Layer):
- return self.cell.losses
- return []
-
- def get_losses_for(self, inputs=None):
- if isinstance(self.cell, Layer):
- cell_losses = self.cell.get_losses_for(inputs)
- return cell_losses + super(RNN, self).get_losses_for(inputs)
- return super(RNN, self).get_losses_for(inputs)
-
-
-class SimpleRNNCell(Layer):
- """Cell class for SimpleRNN.
+class SimpleRNN(Recurrent):
+ """Fully-connected RNN where the output is to be fed back to input.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+    activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
+      used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+      used for the linear transformation of the recurrent state.
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+      the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
+
+ References:
+ - [A Theoretically Grounded Application of Dropout in Recurrent Neural
+ Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -832,13 +491,15 @@ class SimpleRNNCell(Layer):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
+ activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
**kwargs):
- super(SimpleRNNCell, self).__init__(**kwargs)
+ super(SimpleRNN, self).__init__(
+ activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
self.units = units
self.activation = activations.get(activation)
self.use_bias = use_bias
@@ -857,13 +518,23 @@ class SimpleRNNCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.state_size = self.units
- self._dropout_mask = None
- self._recurrent_dropout_mask = None
+ self.state_spec = InputSpec(shape=(None, self.units))
def build(self, input_shape):
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+
+ batch_size = input_shape[0] if self.stateful else None
+ self.input_dim = input_shape[2]
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
+
+ self.states = [None]
+ if self.stateful:
+ self.reset_states()
+
self.kernel = self.add_weight(
- shape=(input_shape[-1], self.units),
+ shape=(self.input_dim, self.units),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -885,327 +556,146 @@ class SimpleRNNCell(Layer):
self.bias = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
+ def preprocess_input(self, inputs, training=None):
+ if self.implementation > 0:
+ return inputs
else:
- self._dropout_mask = None
+ input_shape = inputs.get_shape().as_list()
+ input_dim = input_shape[2]
+ timesteps = input_shape[1]
+ return _time_distributed_dense(
+ inputs,
+ self.kernel,
+ self.bias,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
- if 0 < self.recurrent_dropout < 1:
- ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
- ones = K.tile(ones, (1, self.units))
-
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
-
- self._recurrent_dropout_mask = K.in_train_phase(
- dropped_inputs, ones, training=training)
+ def step(self, inputs, states):
+ if self.implementation == 0:
+ h = inputs
else:
- self._recurrent_dropout_mask = None
+ if 0 < self.dropout < 1:
+ h = K.dot(inputs * states[1], self.kernel)
+ else:
+ h = K.dot(inputs, self.kernel)
+ if self.bias is not None:
+ h = K.bias_add(h, self.bias)
- def call(self, inputs, states, training=None):
prev_output = states[0]
- dp_mask = self._dropout_mask
- rec_dp_mask = self._recurrent_dropout_mask
-
- if dp_mask is not None:
- h = K.dot(inputs * dp_mask, self.kernel)
- else:
- h = K.dot(inputs, self.kernel)
- if self.bias is not None:
- h = K.bias_add(h, self.bias)
-
- if rec_dp_mask is not None:
- prev_output *= rec_dp_mask
+ if 0 < self.recurrent_dropout < 1:
+ prev_output *= states[2]
output = h + K.dot(prev_output, self.recurrent_kernel)
if self.activation is not None:
output = self.activation(output)
# Properly set learning phase on output tensor.
if 0 < self.dropout + self.recurrent_dropout:
- if training is None:
- output._uses_learning_phase = True
+ output._uses_learning_phase = True
return output, [output]
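
Ignoring the dropout masks, `step` implements the recurrence `output_t = activation(inputs_t . kernel + bias + output_{t-1} . recurrent_kernel)`. A minimal NumPy sketch of a single step, with placeholder shapes and a tanh activation:

```python
import numpy as np

units, input_dim, batch = 3, 4, 2
kernel = np.random.randn(input_dim, units)
recurrent_kernel = np.random.randn(units, units)
bias = np.zeros(units)

x_t = np.random.randn(batch, input_dim)
prev_output = np.zeros((batch, units))

# One step of the SimpleRNN recurrence.
output = np.tanh(x_t.dot(kernel) + bias + prev_output.dot(recurrent_kernel))
print(output.shape)  # (2, 3)
```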
+ def get_constants(self, inputs, training=None):
+ constants = []
+ if self.implementation != 0 and 0 < self.dropout < 1:
+ input_shape = K.int_shape(inputs)
+ input_dim = input_shape[-1]
+ ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+ ones = K.tile(ones, (1, int(input_dim)))
-class SimpleRNN(RNN):
- """Fully-connected RNN where the output is to be fed back to input.
-
- Arguments:
- units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- If you pass None, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
- use_bias: Boolean, whether the layer uses a bias vector.
- kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
- recurrent_initializer: Initializer for the `recurrent_kernel`
- weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
- kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
- kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
- recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
- dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the inputs.
- recurrent_dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the recurrent state.
- return_sequences: Boolean. Whether to return the last output.
- in the output sequence, or the full sequence.
- return_state: Boolean. Whether to return the last state
- in addition to the output.
- go_backwards: Boolean (default False).
- If True, process the input sequence backwards and return the
- reversed sequence.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
- unroll: Boolean (default False).
- If True, the network will be unrolled,
- else a symbolic loop will be used.
- Unrolling can speed-up a RNN,
- although it tends to be more memory-intensive.
- Unrolling is only suitable for short sequences.
- """
-
- def __init__(self,
- units,
- activation='tanh',
- use_bias=True,
- kernel_initializer='glorot_uniform',
- recurrent_initializer='orthogonal',
- bias_initializer='zeros',
- kernel_regularizer=None,
- recurrent_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- kernel_constraint=None,
- recurrent_constraint=None,
- bias_constraint=None,
- dropout=0.,
- recurrent_dropout=0.,
- return_sequences=False,
- return_state=False,
- go_backwards=False,
- stateful=False,
- unroll=False,
- **kwargs):
- if 'implementation' in kwargs:
- kwargs.pop('implementation')
- logging.warning('The `implementation` argument '
- 'in `SimpleRNN` has been deprecated. '
- 'Please remove it from your layer call.')
- cell = SimpleRNNCell(
- units,
- activation=activation,
- use_bias=use_bias,
- kernel_initializer=kernel_initializer,
- recurrent_initializer=recurrent_initializer,
- bias_initializer=bias_initializer,
- kernel_regularizer=kernel_regularizer,
- recurrent_regularizer=recurrent_regularizer,
- bias_regularizer=bias_regularizer,
- kernel_constraint=kernel_constraint,
- recurrent_constraint=recurrent_constraint,
- bias_constraint=bias_constraint,
- dropout=dropout,
- recurrent_dropout=recurrent_dropout)
- super(SimpleRNN, self).__init__(
- cell,
- return_sequences=return_sequences,
- return_state=return_state,
- go_backwards=go_backwards,
- stateful=stateful,
- unroll=unroll,
- activity_regularizer=regularizers.get(activity_regularizer),
- **kwargs)
- # self.activity_regularizer = regularizers.get(activity_regularizer)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
- return super(SimpleRNN, self).call(
- inputs, mask=mask, training=training, initial_state=initial_state)
-
- @property
- def units(self):
- return self.cell.units
-
- @property
- def activation(self):
- return self.cell.activation
-
- @property
- def use_bias(self):
- return self.cell.use_bias
-
- @property
- def kernel_initializer(self):
- return self.cell.kernel_initializer
-
- @property
- def recurrent_initializer(self):
- return self.cell.recurrent_initializer
-
- @property
- def bias_initializer(self):
- return self.cell.bias_initializer
-
- @property
- def kernel_regularizer(self):
- return self.cell.kernel_regularizer
-
- @property
- def recurrent_regularizer(self):
- return self.cell.recurrent_regularizer
-
- @property
- def bias_regularizer(self):
- return self.cell.bias_regularizer
-
- @property
- def kernel_constraint(self):
- return self.cell.kernel_constraint
+ def dropped_inputs():
+ return K.dropout(ones, self.dropout)
- @property
- def recurrent_constraint(self):
- return self.cell.recurrent_constraint
+ dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
+ constants.append(dp_mask)
+ else:
+ constants.append(K.cast_to_floatx(1.))
- @property
- def bias_constraint(self):
- return self.cell.bias_constraint
+ if 0 < self.recurrent_dropout < 1:
+ ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+ ones = K.tile(ones, (1, self.units))
- @property
- def dropout(self):
- return self.cell.dropout
+ def dropped_inputs(): # pylint: disable=function-redefined
+ return K.dropout(ones, self.recurrent_dropout)
- @property
- def recurrent_dropout(self):
- return self.cell.recurrent_dropout
+ rec_dp_mask = K.in_train_phase(dropped_inputs, ones, training=training)
+ constants.append(rec_dp_mask)
+ else:
+ constants.append(K.cast_to_floatx(1.))
+ return constants
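
The constants built here realize the "same dropout mask for every timestep" scheme from the Gal & Ghahramani reference cited in the docstring: the mask is sampled once per sequence and reused at each step. A hedged NumPy sketch of the idea; the rate and shapes are placeholders:

```python
import numpy as np

rate, batch, timesteps, input_dim = 0.5, 2, 5, 4
x = np.random.randn(batch, timesteps, input_dim)

# One mask per sample, drawn once and rescaled (inverted dropout)...
mask = (np.random.rand(batch, input_dim) >= rate) / (1. - rate)

# ...then reused at every timestep, instead of resampling per step.
x_dropped = x * mask[:, None, :]
```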
def get_config(self):
config = {
- 'units':
- self.units,
- 'activation':
- activations.serialize(self.activation),
- 'use_bias':
- self.use_bias,
- 'kernel_initializer':
- initializers.serialize(self.kernel_initializer),
+ 'units': self.units,
+ 'activation': activations.serialize(self.activation),
+ 'use_bias': self.use_bias,
+ 'kernel_initializer': initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer':
- initializers.serialize(self.bias_initializer),
- 'kernel_regularizer':
- regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer':
- regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint':
- constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint': constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint':
- constraints.serialize(self.bias_constraint),
- 'dropout':
- self.dropout,
- 'recurrent_dropout':
- self.recurrent_dropout
+ 'bias_constraint': constraints.serialize(self.bias_constraint),
+ 'dropout': self.dropout,
+ 'recurrent_dropout': self.recurrent_dropout
}
base_config = super(SimpleRNN, self).get_config()
- del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
- @classmethod
- def from_config(cls, config):
- if 'implementation' in config:
- config.pop('implementation')
- return cls(**config)
+class GRU(Recurrent):
+ """Gated Recurrent Unit - Cho et al.
-class GRUCell(Layer):
- """Cell class for the GRU layer.
+ 2014.
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
+      used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+      used for the linear transformation of the recurrent state.
+ bias_initializer: Initializer for the bias vector.
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+      the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
- implementation: Implementation mode, either 1 or 2.
- Mode 1 will structure its operations as a larger number of
- smaller dot products and additions, whereas mode 2 will
- batch them into fewer, larger operations. These modes will
- have different performance profiles on different hardware and
- for different applications.
+
+ References:
+ - [On the Properties of Neural Machine Translation: Encoder-Decoder
+ Approaches](https://arxiv.org/abs/1409.1259)
+ - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
+ Modeling](http://arxiv.org/abs/1412.3555v1)
+ - [A Theoretically Grounded Application of Dropout in Recurrent Neural
+ Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -1219,14 +709,15 @@ class GRUCell(Layer):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
+ activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
- implementation=1,
**kwargs):
- super(GRUCell, self).__init__(**kwargs)
+ super(GRU, self).__init__(
+ activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -1246,15 +737,22 @@ class GRUCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.implementation = implementation
- self.state_size = self.units
- self._dropout_mask = None
- self._recurrent_dropout_mask = None
+ self.state_spec = InputSpec(shape=(None, self.units))
def build(self, input_shape):
- input_dim = input_shape[-1]
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ batch_size = input_shape[0] if self.stateful else None
+ self.input_dim = input_shape[2]
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
+
+ self.states = [None]
+ if self.stateful:
+ self.reset_states()
+
self.kernel = self.add_weight(
- shape=(input_dim, self.units * 3),
+ shape=(self.input_dim, self.units * 3),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -1294,83 +792,89 @@ class GRUCell(Layer):
self.bias_h = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+ def preprocess_input(self, inputs, training=None):
+ if self.implementation == 0:
+ input_shape = inputs.get_shape().as_list()
+ input_dim = input_shape[2]
+ timesteps = input_shape[1]
+
+ x_z = _time_distributed_dense(
+ inputs,
+ self.kernel_z,
+ self.bias_z,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ x_r = _time_distributed_dense(
+ inputs,
+ self.kernel_r,
+ self.bias_r,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ x_h = _time_distributed_dense(
+ inputs,
+ self.kernel_h,
+ self.bias_h,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ return K.concatenate([x_z, x_r, x_h], axis=2)
+ else:
+ return inputs
+
+ def get_constants(self, inputs, training=None):
+ constants = []
+ if self.implementation != 0 and 0 < self.dropout < 1:
+ input_shape = K.int_shape(inputs)
+ input_dim = input_shape[-1]
+ ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+ ones = K.tile(ones, (1, int(input_dim)))
def dropped_inputs():
return K.dropout(ones, self.dropout)
- self._dropout_mask = [
+ dp_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(3)
]
+ constants.append(dp_mask)
else:
- self._dropout_mask = None
+ constants.append([K.cast_to_floatx(1.) for _ in range(3)])
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
if 0 < self.recurrent_dropout < 1:
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
ones = K.tile(ones, (1, self.units))
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
+ def dropped_inputs(): # pylint: disable=function-redefined
+ return K.dropout(ones, self.recurrent_dropout)
- self._recurrent_dropout_mask = [
+ rec_dp_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(3)
]
+ constants.append(rec_dp_mask)
else:
- self._recurrent_dropout_mask = None
+ constants.append([K.cast_to_floatx(1.) for _ in range(3)])
+ return constants
- def call(self, inputs, states, training=None):
+ def step(self, inputs, states):
h_tm1 = states[0] # previous memory
+ dp_mask = states[1] # dropout matrices for recurrent units
+ rec_dp_mask = states[2]
- # dropout matrices for input units
- dp_mask = self._dropout_mask
- # dropout matrices for recurrent units
- rec_dp_mask = self._recurrent_dropout_mask
-
- if self.implementation == 1:
- if 0. < self.dropout < 1.:
- inputs_z = inputs * dp_mask[0]
- inputs_r = inputs * dp_mask[1]
- inputs_h = inputs * dp_mask[2]
- else:
- inputs_z = inputs
- inputs_r = inputs
- inputs_h = inputs
- x_z = K.dot(inputs_z, self.kernel_z)
- x_r = K.dot(inputs_r, self.kernel_r)
- x_h = K.dot(inputs_h, self.kernel_h)
- if self.use_bias:
- x_z = K.bias_add(x_z, self.bias_z)
- x_r = K.bias_add(x_r, self.bias_r)
- x_h = K.bias_add(x_h, self.bias_h)
-
- if 0. < self.recurrent_dropout < 1.:
- h_tm1_z = h_tm1 * rec_dp_mask[0]
- h_tm1_r = h_tm1 * rec_dp_mask[1]
- h_tm1_h = h_tm1 * rec_dp_mask[2]
- else:
- h_tm1_z = h_tm1
- h_tm1_r = h_tm1
- h_tm1_h = h_tm1
- z = self.recurrent_activation(
- x_z + K.dot(h_tm1_z, self.recurrent_kernel_z))
- r = self.recurrent_activation(
- x_r + K.dot(h_tm1_r, self.recurrent_kernel_r))
-
- hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h))
- else:
- if 0. < self.dropout < 1.:
- inputs *= dp_mask[0]
- matrix_x = K.dot(inputs, self.kernel)
+ if self.implementation == 2:
+ matrix_x = K.dot(inputs * dp_mask[0], self.kernel)
if self.use_bias:
matrix_x = K.bias_add(matrix_x, self.bias)
- if 0. < self.recurrent_dropout < 1.:
- h_tm1 *= rec_dp_mask[0]
- matrix_inner = K.dot(h_tm1, self.recurrent_kernel[:, :2 * self.units])
+ matrix_inner = K.dot(h_tm1 * rec_dp_mask[0],
+ self.recurrent_kernel[:, :2 * self.units])
x_z = matrix_x[:, :self.units]
x_r = matrix_x[:, self.units:2 * self.units]
@@ -1381,323 +885,116 @@ class GRUCell(Layer):
r = self.recurrent_activation(x_r + recurrent_r)
x_h = matrix_x[:, 2 * self.units:]
- recurrent_h = K.dot(r * h_tm1, self.recurrent_kernel[:, 2 * self.units:])
+ recurrent_h = K.dot(r * h_tm1 * rec_dp_mask[0],
+ self.recurrent_kernel[:, 2 * self.units:])
hh = self.activation(x_h + recurrent_h)
+ else:
+ if self.implementation == 0:
+ x_z = inputs[:, :self.units]
+ x_r = inputs[:, self.units:2 * self.units]
+ x_h = inputs[:, 2 * self.units:]
+ elif self.implementation == 1:
+ x_z = K.dot(inputs * dp_mask[0], self.kernel_z)
+ x_r = K.dot(inputs * dp_mask[1], self.kernel_r)
+ x_h = K.dot(inputs * dp_mask[2], self.kernel_h)
+ if self.use_bias:
+ x_z = K.bias_add(x_z, self.bias_z)
+ x_r = K.bias_add(x_r, self.bias_r)
+ x_h = K.bias_add(x_h, self.bias_h)
+ else:
+ raise ValueError('Unknown `implementation` mode.')
+ z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
+ self.recurrent_kernel_z))
+ r = self.recurrent_activation(x_r + K.dot(h_tm1 * rec_dp_mask[1],
+ self.recurrent_kernel_r))
+
+ hh = self.activation(x_h + K.dot(r * h_tm1 * rec_dp_mask[2],
+ self.recurrent_kernel_h))
h = z * h_tm1 + (1 - z) * hh
if 0 < self.dropout + self.recurrent_dropout:
- if training is None:
- h._uses_learning_phase = True
+ h._uses_learning_phase = True
return h, [h]
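
Dropout masks aside, both branches above compute the standard GRU update. A hedged NumPy sketch of a single step, using a plain sigmoid in place of the layer's default hard sigmoid; all shapes are placeholders:

```python
import numpy as np

def sigmoid(v):
  return 1. / (1. + np.exp(-v))

units, input_dim, batch = 3, 4, 2
kernel = np.random.randn(input_dim, 3 * units)          # [z | r | h] blocks
recurrent_kernel = np.random.randn(units, 3 * units)

x_t = np.random.randn(batch, input_dim)
h_tm1 = np.zeros((batch, units))

x_z, x_r, x_h = np.split(x_t.dot(kernel), 3, axis=-1)
u_z, u_r, u_h = np.split(recurrent_kernel, 3, axis=-1)

z = sigmoid(x_z + h_tm1.dot(u_z))           # update gate
r = sigmoid(x_r + h_tm1.dot(u_r))           # reset gate
hh = np.tanh(x_h + (r * h_tm1).dot(u_h))    # candidate state
h = z * h_tm1 + (1 - z) * hh                # new hidden state
```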
-
-class GRU(RNN):
- # pylint: disable=line-too-long
- """Gated Recurrent Unit - Cho et al.
-
- 2014.
-
- Arguments:
- units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- If you pass None, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
- recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- use_bias: Boolean, whether the layer uses a bias vector.
- kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
- recurrent_initializer: Initializer for the `recurrent_kernel`
- weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
- kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
- kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
- recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
- dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the inputs.
- recurrent_dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the recurrent state.
- implementation: Implementation mode, either 1 or 2.
- Mode 1 will structure its operations as a larger number of
- smaller dot products and additions, whereas mode 2 will
- batch them into fewer, larger operations. These modes will
- have different performance profiles on different hardware and
- for different applications.
- return_sequences: Boolean. Whether to return the last output.
- in the output sequence, or the full sequence.
- return_state: Boolean. Whether to return the last state
- in addition to the output.
- go_backwards: Boolean (default False).
- If True, process the input sequence backwards and return the
- reversed sequence.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
- unroll: Boolean (default False).
- If True, the network will be unrolled,
- else a symbolic loop will be used.
- Unrolling can speed-up a RNN,
- although it tends to be more memory-intensive.
- Unrolling is only suitable for short sequences.
-
- References:
- - [On the Properties of Neural Machine Translation: Encoder-Decoder Approaches](https://arxiv.org/abs/1409.1259)
- - [Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](http://arxiv.org/abs/1412.3555v1)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
- """
- # pylint: enable=line-too-long
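As a usage illustration of the GRU arguments documented above, here is a hedged sketch assuming the public `keras` package (all shapes and hyperparameters are placeholders):

```python
import keras

# Hypothetical task: sequences of 10 timesteps with 32 features each.
model = keras.models.Sequential([
    keras.layers.GRU(64,
                     dropout=0.2,             # dropout on the input projection
                     recurrent_dropout=0.2,   # dropout on the recurrent projection
                     return_sequences=True,   # emit the full sequence of hidden states
                     input_shape=(10, 32)),
    keras.layers.GRU(64),                     # second GRU returns only the last state
    keras.layers.Dense(1, activation='sigmoid'),
])
model.compile(optimizer='adam', loss='binary_crossentropy')
print(model.output_shape)  # (None, 1)
```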
-
- def __init__(self,
- units,
- activation='tanh',
- recurrent_activation='hard_sigmoid',
- use_bias=True,
- kernel_initializer='glorot_uniform',
- recurrent_initializer='orthogonal',
- bias_initializer='zeros',
- kernel_regularizer=None,
- recurrent_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- kernel_constraint=None,
- recurrent_constraint=None,
- bias_constraint=None,
- dropout=0.,
- recurrent_dropout=0.,
- implementation=1,
- return_sequences=False,
- return_state=False,
- go_backwards=False,
- stateful=False,
- unroll=False,
- **kwargs):
- if implementation == 0:
- logging.warning('`implementation=0` has been deprecated, '
- 'and now defaults to `implementation=1`.'
- 'Please update your layer call.')
- cell = GRUCell(
- units,
- activation=activation,
- recurrent_activation=recurrent_activation,
- use_bias=use_bias,
- kernel_initializer=kernel_initializer,
- recurrent_initializer=recurrent_initializer,
- bias_initializer=bias_initializer,
- kernel_regularizer=kernel_regularizer,
- recurrent_regularizer=recurrent_regularizer,
- bias_regularizer=bias_regularizer,
- kernel_constraint=kernel_constraint,
- recurrent_constraint=recurrent_constraint,
- bias_constraint=bias_constraint,
- dropout=dropout,
- recurrent_dropout=recurrent_dropout,
- implementation=implementation)
- super(GRU, self).__init__(
- cell,
- return_sequences=return_sequences,
- return_state=return_state,
- go_backwards=go_backwards,
- stateful=stateful,
- unroll=unroll,
- **kwargs)
- self.activity_regularizer = regularizers.get(activity_regularizer)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
- return super(GRU, self).call(
- inputs, mask=mask, training=training, initial_state=initial_state)
-
- @property
- def units(self):
- return self.cell.units
-
- @property
- def activation(self):
- return self.cell.activation
-
- @property
- def recurrent_activation(self):
- return self.cell.recurrent_activation
-
- @property
- def use_bias(self):
- return self.cell.use_bias
-
- @property
- def kernel_initializer(self):
- return self.cell.kernel_initializer
-
- @property
- def recurrent_initializer(self):
- return self.cell.recurrent_initializer
-
- @property
- def bias_initializer(self):
- return self.cell.bias_initializer
-
- @property
- def kernel_regularizer(self):
- return self.cell.kernel_regularizer
-
- @property
- def recurrent_regularizer(self):
- return self.cell.recurrent_regularizer
-
- @property
- def bias_regularizer(self):
- return self.cell.bias_regularizer
-
- @property
- def kernel_constraint(self):
- return self.cell.kernel_constraint
-
- @property
- def recurrent_constraint(self):
- return self.cell.recurrent_constraint
-
- @property
- def bias_constraint(self):
- return self.cell.bias_constraint
-
- @property
- def dropout(self):
- return self.cell.dropout
-
- @property
- def recurrent_dropout(self):
- return self.cell.recurrent_dropout
-
- @property
- def implementation(self):
- return self.cell.implementation
-
def get_config(self):
config = {
- 'units':
- self.units,
- 'activation':
- activations.serialize(self.activation),
+ 'units': self.units,
+ 'activation': activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias':
- self.use_bias,
- 'kernel_initializer':
- initializers.serialize(self.kernel_initializer),
+ 'use_bias': self.use_bias,
+ 'kernel_initializer': initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer':
- initializers.serialize(self.bias_initializer),
- 'kernel_regularizer':
- regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer':
- regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint':
- constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint': constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint':
- constraints.serialize(self.bias_constraint),
- 'dropout':
- self.dropout,
- 'recurrent_dropout':
- self.recurrent_dropout,
- 'implementation':
- self.implementation
+ 'bias_constraint': constraints.serialize(self.bias_constraint),
+ 'dropout': self.dropout,
+ 'recurrent_dropout': self.recurrent_dropout
}
base_config = super(GRU, self).get_config()
- del base_config['cell']
return dict(list(base_config.items()) + list(config.items()))
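The config produced here is just a dictionary of constructor arguments, so an equivalent layer can be rebuilt from it; a round-trip sketch, assuming the public `keras` package:

```python
import keras

layer = keras.layers.GRU(16, dropout=0.1, recurrent_dropout=0.1)
config = layer.get_config()                    # plain dict of constructor arguments
clone = keras.layers.GRU.from_config(config)   # rebuilds an equivalent, unbuilt layer
assert clone.units == 16
assert abs(clone.dropout - 0.1) < 1e-9
```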
- @classmethod
- def from_config(cls, config):
- if 'implementation' in config and config['implementation'] == 0:
- config['implementation'] = 1
- return cls(**config)
+class LSTM(Recurrent):
+ """Long-Short Term Memory unit - Hochreiter 1997.
-class LSTMCell(Layer):
- """Cell class for the LSTM layer.
+ For a step-by-step description of the algorithm, see
+ [this tutorial](http://deeplearning.net/tutorial/lstm.html).
Arguments:
units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
+ activation: Activation function to use.
If you pass None, no activation is applied
(ie. "linear" activation: `a(x) = x`).
recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
+ for the recurrent step.
use_bias: Boolean, whether the layer uses a bias vector.
kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the inputs.
recurrent_initializer: Initializer for the `recurrent_kernel`
weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
+ used for the linear transformation of the recurrent state.
+ bias_initializer: Initializer for the bias vector.
unit_forget_bias: Boolean.
If True, add 1 to the bias of the forget gate at initialization.
Setting it to true will also force `bias_initializer="zeros"`.
This is recommended in [Jozefowicz et
al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
+ the `kernel` weights matrix.
recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_regularizer: Regularizer function applied to the bias vector.
+ activity_regularizer: Regularizer function applied to
+ the output of the layer (its "activation").
kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
+ the `kernel` weights matrix.
recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
+ the `recurrent_kernel` weights matrix.
+ bias_constraint: Constraint function applied to the bias vector.
dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the inputs.
recurrent_dropout: Float between 0 and 1.
Fraction of the units to drop for
the linear transformation of the recurrent state.
- implementation: Implementation mode, either 1 or 2.
- Mode 1 will structure its operations as a larger number of
- smaller dot products and additions, whereas mode 2 will
- batch them into fewer, larger operations. These modes will
- have different performance profiles on different hardware and
- for different applications.
+
+ References:
+ - [Long short-term
+ memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
+ (original 1997 paper)
+ - [Supervised sequence labeling with recurrent neural
+ networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
+ - [A Theoretically Grounded Application of Dropout in Recurrent Neural
+ Networks](http://arxiv.org/abs/1512.05287)
"""
def __init__(self,
@@ -1712,14 +1009,15 @@ class LSTMCell(Layer):
kernel_regularizer=None,
recurrent_regularizer=None,
bias_regularizer=None,
+ activity_regularizer=None,
kernel_constraint=None,
recurrent_constraint=None,
bias_constraint=None,
dropout=0.,
recurrent_dropout=0.,
- implementation=1,
**kwargs):
- super(LSTMCell, self).__init__(**kwargs)
+ super(LSTM, self).__init__(
+ activity_regularizer=regularizers.get(activity_regularizer), **kwargs)
self.units = units
self.activation = activations.get(activation)
self.recurrent_activation = activations.get(recurrent_activation)
@@ -1740,15 +1038,25 @@ class LSTMCell(Layer):
self.dropout = min(1., max(0., dropout))
self.recurrent_dropout = min(1., max(0., recurrent_dropout))
- self.implementation = implementation
- self.state_size = (self.units, self.units)
- self._dropout_mask = None
- self._recurrent_dropout_mask = None
+ self.state_spec = [
+ InputSpec(shape=(None, self.units)),
+ InputSpec(shape=(None, self.units))
+ ]
def build(self, input_shape):
- input_dim = input_shape[-1]
+ if isinstance(input_shape, list):
+ input_shape = input_shape[0]
+ input_shape = tensor_shape.TensorShape(input_shape).as_list()
+ batch_size = input_shape[0] if self.stateful else None
+ self.input_dim = input_shape[2]
+ self.input_spec[0] = InputSpec(shape=(batch_size, None, self.input_dim))
+
+ self.states = [None, None]
+ if self.stateful:
+ self.reset_states()
+
self.kernel = self.add_weight(
- shape=(input_dim, self.units * 4),
+ shape=(self.input_dim, self.units * 4),
name='kernel',
initializer=self.kernel_initializer,
regularizer=self.kernel_regularizer,
@@ -1804,90 +1112,96 @@ class LSTMCell(Layer):
self.bias_o = None
self.built = True
- def _generate_dropout_mask(self, inputs, training=None):
- if 0 < self.dropout < 1:
- ones = K.ones_like(K.squeeze(inputs[:, 0:1, :], axis=1))
+ def preprocess_input(self, inputs, training=None):
+ if self.implementation == 0:
+ input_shape = inputs.get_shape().as_list()
+ input_dim = input_shape[2]
+ timesteps = input_shape[1]
+
+ x_i = _time_distributed_dense(
+ inputs,
+ self.kernel_i,
+ self.bias_i,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ x_f = _time_distributed_dense(
+ inputs,
+ self.kernel_f,
+ self.bias_f,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ x_c = _time_distributed_dense(
+ inputs,
+ self.kernel_c,
+ self.bias_c,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ x_o = _time_distributed_dense(
+ inputs,
+ self.kernel_o,
+ self.bias_o,
+ self.dropout,
+ input_dim,
+ self.units,
+ timesteps,
+ training=training)
+ return K.concatenate([x_i, x_f, x_c, x_o], axis=2)
+ else:
+ return inputs
+
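In `implementation=0`, the per-gate input projections are precomputed over all timesteps at once. The NumPy sketch below shows the reshape trick that `_time_distributed_dense` relies on, with illustrative shapes and no dropout or bias:

```python
import numpy as np

batch, timesteps, input_dim, units = 2, 6, 5, 3
x = np.random.randn(batch, timesteps, input_dim)
w = np.random.randn(input_dim, units)

# Collapse batch and time into one axis, apply a single matmul, then restore time.
flat = x.reshape(-1, input_dim)                         # (batch * timesteps, input_dim)
projected = (flat @ w).reshape(batch, timesteps, units)

# Equivalent to projecting each timestep separately.
reference = np.stack([x[:, t, :] @ w for t in range(timesteps)], axis=1)
assert np.allclose(projected, reference)
```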
+ def get_constants(self, inputs, training=None):
+ constants = []
+ if self.implementation != 0 and 0 < self.dropout < 1:
+ input_shape = K.int_shape(inputs)
+ input_dim = input_shape[-1]
+ ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
+ ones = K.tile(ones, (1, int(input_dim)))
def dropped_inputs():
return K.dropout(ones, self.dropout)
- self._dropout_mask = [
+ dp_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(4)
]
+ constants.append(dp_mask)
else:
- self._dropout_mask = None
+ constants.append([K.cast_to_floatx(1.) for _ in range(4)])
- def _generate_recurrent_dropout_mask(self, inputs, training=None):
if 0 < self.recurrent_dropout < 1:
ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
ones = K.tile(ones, (1, self.units))
- def dropped_inputs():
- return K.dropout(ones, self.dropout)
+ def dropped_inputs(): # pylint: disable=function-redefined
+ return K.dropout(ones, self.recurrent_dropout)
- self._recurrent_dropout_mask = [
+ rec_dp_mask = [
K.in_train_phase(dropped_inputs, ones, training=training)
for _ in range(4)
]
+ constants.append(rec_dp_mask)
else:
- self._recurrent_dropout_mask = None
-
- def call(self, inputs, states, training=None):
- # dropout matrices for input units
- dp_mask = self._dropout_mask
- # dropout matrices for recurrent units
- rec_dp_mask = self._recurrent_dropout_mask
-
- h_tm1 = states[0] # previous memory state
- c_tm1 = states[1] # previous carry state
-
- if self.implementation == 1:
- if 0 < self.dropout < 1.:
- inputs_i = inputs * dp_mask[0]
- inputs_f = inputs * dp_mask[1]
- inputs_c = inputs * dp_mask[2]
- inputs_o = inputs * dp_mask[3]
- else:
- inputs_i = inputs
- inputs_f = inputs
- inputs_c = inputs
- inputs_o = inputs
- x_i = K.dot(inputs_i, self.kernel_i)
- x_f = K.dot(inputs_f, self.kernel_f)
- x_c = K.dot(inputs_c, self.kernel_c)
- x_o = K.dot(inputs_o, self.kernel_o)
- if self.use_bias:
- x_i = K.bias_add(x_i, self.bias_i)
- x_f = K.bias_add(x_f, self.bias_f)
- x_c = K.bias_add(x_c, self.bias_c)
- x_o = K.bias_add(x_o, self.bias_o)
-
- if 0 < self.recurrent_dropout < 1.:
- h_tm1_i = h_tm1 * rec_dp_mask[0]
- h_tm1_f = h_tm1 * rec_dp_mask[1]
- h_tm1_c = h_tm1 * rec_dp_mask[2]
- h_tm1_o = h_tm1 * rec_dp_mask[3]
- else:
- h_tm1_i = h_tm1
- h_tm1_f = h_tm1
- h_tm1_c = h_tm1
- h_tm1_o = h_tm1
- i = self.recurrent_activation(
- x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))
- f = self.recurrent_activation(
- x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))
- c = f * c_tm1 + i * self.activation(
- x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))
- o = self.recurrent_activation(
- x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
- else:
- if 0. < self.dropout < 1.:
- inputs *= dp_mask[0]
- z = K.dot(inputs, self.kernel)
- if 0. < self.recurrent_dropout < 1.:
- h_tm1 *= rec_dp_mask[0]
- z += K.dot(h_tm1, self.recurrent_kernel)
+ constants.append([K.cast_to_floatx(1.) for _ in range(4)])
+ return constants
+
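The masks built here are sampled once per gate and reused at every timestep, rather than re-sampled each step. A small NumPy sketch of that idea (not the backend code itself):

```python
import numpy as np

rng = np.random.RandomState(0)
units, rate, timesteps = 4, 0.5, 3

# One inverted-dropout mask per gate, sampled once for the whole sequence.
rec_dp_mask = [(rng.rand(1, units) >= rate).astype('float32') / (1.0 - rate)
               for _ in range(4)]

h = np.ones((1, units), dtype='float32')
for t in range(timesteps):
  # The same four masks are reused at every step; only the hidden state changes.
  gated = [h * m for m in rec_dp_mask]
  print(t, gated[0])
```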
+ def step(self, inputs, states):
+ h_tm1 = states[0]
+ c_tm1 = states[1]
+ dp_mask = states[2]
+ rec_dp_mask = states[3]
+
+ if self.implementation == 2:
+ z = K.dot(inputs * dp_mask[0], self.kernel)
+ z += K.dot(h_tm1 * rec_dp_mask[0], self.recurrent_kernel)
if self.use_bias:
z = K.bias_add(z, self.bias)
@@ -1900,606 +1214,57 @@ class LSTMCell(Layer):
f = self.recurrent_activation(z1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3)
+ else:
+ if self.implementation == 0:
+ x_i = inputs[:, :self.units]
+ x_f = inputs[:, self.units:2 * self.units]
+ x_c = inputs[:, 2 * self.units:3 * self.units]
+ x_o = inputs[:, 3 * self.units:]
+ elif self.implementation == 1:
+ x_i = K.dot(inputs * dp_mask[0], self.kernel_i) + self.bias_i
+ x_f = K.dot(inputs * dp_mask[1], self.kernel_f) + self.bias_f
+ x_c = K.dot(inputs * dp_mask[2], self.kernel_c) + self.bias_c
+ x_o = K.dot(inputs * dp_mask[3], self.kernel_o) + self.bias_o
+ else:
+ raise ValueError('Unknown `implementation` mode.')
+ i = self.recurrent_activation(x_i + K.dot(h_tm1 * rec_dp_mask[0],
+ self.recurrent_kernel_i))
+ f = self.recurrent_activation(x_f + K.dot(h_tm1 * rec_dp_mask[1],
+ self.recurrent_kernel_f))
+ c = f * c_tm1 + i * self.activation(
+ x_c + K.dot(h_tm1 * rec_dp_mask[2], self.recurrent_kernel_c))
+ o = self.recurrent_activation(x_o + K.dot(h_tm1 * rec_dp_mask[3],
+ self.recurrent_kernel_o))
h = o * self.activation(c)
if 0 < self.dropout + self.recurrent_dropout:
- if training is None:
- h._uses_learning_phase = True
+ h._uses_learning_phase = True
return h, [h, c]
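The same gate algebra, written out in plain NumPy for a single step. This is a minimal sketch that assumes sigmoid/tanh activations and no dropout (the layer's default recurrent activation is actually `hard_sigmoid`) and mirrors the fused `implementation=2` branch:

```python
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_tm1, c_tm1, kernel, recurrent_kernel, bias):
  """One LSTM step with a single fused projection, as in implementation=2."""
  z = x_t @ kernel + h_tm1 @ recurrent_kernel + bias   # (batch, 4 * units)
  z0, z1, z2, z3 = np.split(z, 4, axis=-1)
  i = sigmoid(z0)                      # input gate
  f = sigmoid(z1)                      # forget gate
  c = f * c_tm1 + i * np.tanh(z2)      # new cell state
  o = sigmoid(z3)                      # output gate
  h = o * np.tanh(c)                   # new hidden state
  return h, c

batch, input_dim, units = 2, 5, 4
h, c = lstm_step(np.random.randn(batch, input_dim),
                 np.zeros((batch, units)),
                 np.zeros((batch, units)),
                 np.random.randn(input_dim, 4 * units),
                 np.random.randn(units, 4 * units),
                 np.zeros(4 * units))
print(h.shape, c.shape)  # (2, 4) (2, 4)
```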
-
-class LSTM(RNN):
- # pylint: disable=line-too-long
- """Long-Short Term Memory layer - Hochreiter 1997.
-
- Arguments:
- units: Positive integer, dimensionality of the output space.
- activation: Activation function to use
- (see [activations](../activations.md)).
- If you pass None, no activation is applied
- (ie. "linear" activation: `a(x) = x`).
- recurrent_activation: Activation function to use
- for the recurrent step
- (see [activations](../activations.md)).
- use_bias: Boolean, whether the layer uses a bias vector.
- kernel_initializer: Initializer for the `kernel` weights matrix,
- used for the linear transformation of the inputs.
- (see [initializers](../initializers.md)).
- recurrent_initializer: Initializer for the `recurrent_kernel`
- weights matrix,
- used for the linear transformation of the recurrent state.
- (see [initializers](../initializers.md)).
- bias_initializer: Initializer for the bias vector
- (see [initializers](../initializers.md)).
- unit_forget_bias: Boolean.
- If True, add 1 to the bias of the forget gate at initialization.
- Setting it to true will also force `bias_initializer="zeros"`.
- This is recommended in [Jozefowicz et
- al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)
- kernel_regularizer: Regularizer function applied to
- the `kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- recurrent_regularizer: Regularizer function applied to
- the `recurrent_kernel` weights matrix
- (see [regularizer](../regularizers.md)).
- bias_regularizer: Regularizer function applied to the bias vector
- (see [regularizer](../regularizers.md)).
- activity_regularizer: Regularizer function applied to
- the output of the layer (its "activation").
- (see [regularizer](../regularizers.md)).
- kernel_constraint: Constraint function applied to
- the `kernel` weights matrix
- (see [constraints](../constraints.md)).
- recurrent_constraint: Constraint function applied to
- the `recurrent_kernel` weights matrix
- (see [constraints](../constraints.md)).
- bias_constraint: Constraint function applied to the bias vector
- (see [constraints](../constraints.md)).
- dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the inputs.
- recurrent_dropout: Float between 0 and 1.
- Fraction of the units to drop for
- the linear transformation of the recurrent state.
- implementation: Implementation mode, either 1 or 2.
- Mode 1 will structure its operations as a larger number of
- smaller dot products and additions, whereas mode 2 will
- batch them into fewer, larger operations. These modes will
- have different performance profiles on different hardware and
- for different applications.
- return_sequences: Boolean. Whether to return the last output.
- in the output sequence, or the full sequence.
- return_state: Boolean. Whether to return the last state
- in addition to the output.
- go_backwards: Boolean (default False).
- If True, process the input sequence backwards and return the
- reversed sequence.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
- unroll: Boolean (default False).
- If True, the network will be unrolled,
- else a symbolic loop will be used.
- Unrolling can speed-up a RNN,
- although it tends to be more memory-intensive.
- Unrolling is only suitable for short sequences.
-
- References:
- - [Long short-term memory](http://www.bioinf.jku.at/publications/older/2604.pdf)
- - [Learning to forget: Continual prediction with LSTM](http://www.mitpressjournals.org/doi/pdf/10.1162/089976600300015015)
- - [Supervised sequence labeling with recurrent neural networks](http://www.cs.toronto.edu/~graves/preprint.pdf)
- - [A Theoretically Grounded Application of Dropout in Recurrent Neural Networks](http://arxiv.org/abs/1512.05287)
- """
- # pylint: enable=line-too-long
-
- def __init__(self,
- units,
- activation='tanh',
- recurrent_activation='hard_sigmoid',
- use_bias=True,
- kernel_initializer='glorot_uniform',
- recurrent_initializer='orthogonal',
- bias_initializer='zeros',
- unit_forget_bias=True,
- kernel_regularizer=None,
- recurrent_regularizer=None,
- bias_regularizer=None,
- activity_regularizer=None,
- kernel_constraint=None,
- recurrent_constraint=None,
- bias_constraint=None,
- dropout=0.,
- recurrent_dropout=0.,
- implementation=1,
- return_sequences=False,
- return_state=False,
- go_backwards=False,
- stateful=False,
- unroll=False,
- **kwargs):
- if implementation == 0:
- logging.warning('`implementation=0` has been deprecated, '
- 'and now defaults to `implementation=1`.'
- 'Please update your layer call.')
- cell = LSTMCell(
- units,
- activation=activation,
- recurrent_activation=recurrent_activation,
- use_bias=use_bias,
- kernel_initializer=kernel_initializer,
- recurrent_initializer=recurrent_initializer,
- unit_forget_bias=unit_forget_bias,
- bias_initializer=bias_initializer,
- kernel_regularizer=kernel_regularizer,
- recurrent_regularizer=recurrent_regularizer,
- bias_regularizer=bias_regularizer,
- kernel_constraint=kernel_constraint,
- recurrent_constraint=recurrent_constraint,
- bias_constraint=bias_constraint,
- dropout=dropout,
- recurrent_dropout=recurrent_dropout,
- implementation=implementation)
- super(LSTM, self).__init__(
- cell,
- return_sequences=return_sequences,
- return_state=return_state,
- go_backwards=go_backwards,
- stateful=stateful,
- unroll=unroll,
- **kwargs)
- self.activity_regularizer = regularizers.get(activity_regularizer)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
- self.cell._generate_dropout_mask(inputs, training=training)
- self.cell._generate_recurrent_dropout_mask(inputs, training=training)
- return super(LSTM, self).call(
- inputs, mask=mask, training=training, initial_state=initial_state)
-
- @property
- def units(self):
- return self.cell.units
-
- @property
- def activation(self):
- return self.cell.activation
-
- @property
- def recurrent_activation(self):
- return self.cell.recurrent_activation
-
- @property
- def use_bias(self):
- return self.cell.use_bias
-
- @property
- def kernel_initializer(self):
- return self.cell.kernel_initializer
-
- @property
- def recurrent_initializer(self):
- return self.cell.recurrent_initializer
-
- @property
- def bias_initializer(self):
- return self.cell.bias_initializer
-
- @property
- def unit_forget_bias(self):
- return self.cell.unit_forget_bias
-
- @property
- def kernel_regularizer(self):
- return self.cell.kernel_regularizer
-
- @property
- def recurrent_regularizer(self):
- return self.cell.recurrent_regularizer
-
- @property
- def bias_regularizer(self):
- return self.cell.bias_regularizer
-
- @property
- def kernel_constraint(self):
- return self.cell.kernel_constraint
-
- @property
- def recurrent_constraint(self):
- return self.cell.recurrent_constraint
-
- @property
- def bias_constraint(self):
- return self.cell.bias_constraint
-
- @property
- def dropout(self):
- return self.cell.dropout
-
- @property
- def recurrent_dropout(self):
- return self.cell.recurrent_dropout
-
- @property
- def implementation(self):
- return self.cell.implementation
-
def get_config(self):
config = {
- 'units':
- self.units,
- 'activation':
- activations.serialize(self.activation),
+ 'units': self.units,
+ 'activation': activations.serialize(self.activation),
'recurrent_activation':
activations.serialize(self.recurrent_activation),
- 'use_bias':
- self.use_bias,
- 'kernel_initializer':
- initializers.serialize(self.kernel_initializer),
+ 'use_bias': self.use_bias,
+ 'kernel_initializer': initializers.serialize(self.kernel_initializer),
'recurrent_initializer':
initializers.serialize(self.recurrent_initializer),
- 'bias_initializer':
- initializers.serialize(self.bias_initializer),
- 'unit_forget_bias':
- self.unit_forget_bias,
- 'kernel_regularizer':
- regularizers.serialize(self.kernel_regularizer),
+ 'bias_initializer': initializers.serialize(self.bias_initializer),
+ 'unit_forget_bias': self.unit_forget_bias,
+ 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
'recurrent_regularizer':
regularizers.serialize(self.recurrent_regularizer),
- 'bias_regularizer':
- regularizers.serialize(self.bias_regularizer),
+ 'bias_regularizer': regularizers.serialize(self.bias_regularizer),
'activity_regularizer':
regularizers.serialize(self.activity_regularizer),
- 'kernel_constraint':
- constraints.serialize(self.kernel_constraint),
+ 'kernel_constraint': constraints.serialize(self.kernel_constraint),
'recurrent_constraint':
constraints.serialize(self.recurrent_constraint),
- 'bias_constraint':
- constraints.serialize(self.bias_constraint),
- 'dropout':
- self.dropout,
- 'recurrent_dropout':
- self.recurrent_dropout,
- 'implementation':
- self.implementation
+ 'bias_constraint': constraints.serialize(self.bias_constraint),
+ 'dropout': self.dropout,
+ 'recurrent_dropout': self.recurrent_dropout
}
base_config = super(LSTM, self).get_config()
- del base_config['cell']
- return dict(list(base_config.items()) + list(config.items()))
-
- @classmethod
- def from_config(cls, config):
- if 'implementation' in config and config['implementation'] == 0:
- config['implementation'] = 1
- return cls(**config)
-
-
-class Recurrent(Layer):
- """Deprecated abstract base class for recurrent layers.
-
- It still exists because it is leveraged by the convolutional-recurrent layers.
- It will be removed entirely in the future.
- It was never part of the public API.
- Do not use.
-
- Arguments:
- weights: list of Numpy arrays to set as initial weights.
- The list should have 3 elements, of shapes:
- `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
- return_sequences: Boolean. Whether to return the last output
- in the output sequence, or the full sequence.
- return_state: Boolean. Whether to return the last state
- in addition to the output.
- go_backwards: Boolean (default False).
- If True, process the input sequence backwards and return the
- reversed sequence.
- stateful: Boolean (default False). If True, the last state
- for each sample at index i in a batch will be used as initial
- state for the sample of index i in the following batch.
- unroll: Boolean (default False).
- If True, the network will be unrolled,
- else a symbolic loop will be used.
- Unrolling can speed-up a RNN,
- although it tends to be more memory-intensive.
- Unrolling is only suitable for short sequences.
- implementation: one of {0, 1, or 2}.
- If set to 0, the RNN will use
- an implementation that uses fewer, larger matrix products,
- thus running faster on CPU but consuming more memory.
- If set to 1, the RNN will use more matrix products,
- but smaller ones, thus running slower
- (may actually be faster on GPU) while consuming less memory.
- If set to 2 (LSTM/GRU only),
- the RNN will combine the input gate,
- the forget gate and the output gate into a single matrix,
- enabling more time-efficient parallelization on the GPU.
- Note: RNN dropout must be shared for all gates,
- resulting in a slightly reduced regularization.
- input_dim: dimensionality of the input (integer).
- This argument (or alternatively, the keyword argument `input_shape`)
- is required when using this layer as the first layer in a model.
- input_length: Length of input sequences, to be specified
- when it is constant.
- This argument is required if you are going to connect
- `Flatten` then `Dense` layers upstream
- (without it, the shape of the dense outputs cannot be computed).
- Note that if the recurrent layer is not the first layer
- in your model, you would need to specify the input length
- at the level of the first layer
- (e.g. via the `input_shape` argument)
-
- Input shape:
- 3D tensor with shape `(batch_size, timesteps, input_dim)`,
- (Optional) 2D tensors with shape `(batch_size, output_dim)`.
-
- Output shape:
- - if `return_state`: a list of tensors. The first tensor is
- the output. The remaining tensors are the last states,
- each with shape `(batch_size, units)`.
- - if `return_sequences`: 3D tensor with shape
- `(batch_size, timesteps, units)`.
- - else, 2D tensor with shape `(batch_size, units)`.
-
- # Masking
- This layer supports masking for input data with a variable number
- of timesteps. To introduce masks to your data,
- use an `Embedding` layer with the `mask_zero` parameter
- set to `True`.
-
- # Note on using statefulness in RNNs
- You can set RNN layers to be 'stateful', which means that the states
- computed for the samples in one batch will be reused as initial states
- for the samples in the next batch. This assumes a one-to-one mapping
- between samples in different successive batches.
-
- To enable statefulness:
- - specify `stateful=True` in the layer constructor.
- - specify a fixed batch size for your model, by passing
- if sequential model:
- `batch_input_shape=(...)` to the first layer in your model.
- else for functional model with 1 or more Input layers:
- `batch_shape=(...)` to all the first layers in your model.
- This is the expected shape of your inputs
- *including the batch size*.
- It should be a tuple of integers, e.g. `(32, 10, 100)`.
- - specify `shuffle=False` when calling fit().
-
- To reset the states of your model, call `.reset_states()` on either
- a specific layer, or on your entire model.
-
- # Note on specifying the initial state of RNNs
- You can specify the initial state of RNN layers symbolically by
- calling them with the keyword argument `initial_state`. The value of
- `initial_state` should be a tensor or list of tensors representing
- the initial state of the RNN layer.
-
- You can specify the initial state of RNN layers numerically by
- calling `reset_states` with the keyword argument `states`. The value of
- `states` should be a numpy array or list of numpy arrays representing
- the initial state of the RNN layer.
- """
-
- def __init__(self,
- return_sequences=False,
- return_state=False,
- go_backwards=False,
- stateful=False,
- unroll=False,
- implementation=0,
- **kwargs):
- super(Recurrent, self).__init__(**kwargs)
- self.return_sequences = return_sequences
- self.return_state = return_state
- self.go_backwards = go_backwards
- self.stateful = stateful
- self.unroll = unroll
- self.implementation = implementation
- self.supports_masking = True
- self.input_spec = [InputSpec(ndim=3)]
- self.state_spec = None
- self.dropout = 0
- self.recurrent_dropout = 0
-
- def _compute_output_shape(self, input_shape):
- if isinstance(input_shape, list):
- input_shape = input_shape[0]
- input_shape = tensor_shape.TensorShape(input_shape).as_list()
- if self.return_sequences:
- output_shape = (input_shape[0], input_shape[1], self.units)
- else:
- output_shape = (input_shape[0], self.units)
-
- if self.return_state:
- state_shape = [tensor_shape.TensorShape(
- (input_shape[0], self.units)) for _ in self.states]
- return [tensor_shape.TensorShape(output_shape)] + state_shape
- return tensor_shape.TensorShape(output_shape)
-
- def compute_mask(self, inputs, mask):
- if isinstance(mask, list):
- mask = mask[0]
- output_mask = mask if self.return_sequences else None
- if self.return_state:
- state_mask = [None for _ in self.states]
- return [output_mask] + state_mask
- return output_mask
-
- def step(self, inputs, states):
- raise NotImplementedError
-
- def get_constants(self, inputs, training=None):
- return []
-
- def get_initial_state(self, inputs):
- # build an all-zero tensor of shape (samples, output_dim)
- initial_state = K.zeros_like(inputs) # (samples, timesteps, input_dim)
- initial_state = K.sum(initial_state, axis=(1, 2)) # (samples,)
- initial_state = K.expand_dims(initial_state) # (samples, 1)
- initial_state = K.tile(initial_state, [1,
- self.units]) # (samples, output_dim)
- initial_state = [initial_state for _ in range(len(self.states))]
- return initial_state
-
- def preprocess_input(self, inputs, training=None):
- return inputs
-
- def __call__(self, inputs, initial_state=None, **kwargs):
- if (isinstance(inputs, (list, tuple)) and
- len(inputs) > 1
- and initial_state is None):
- initial_state = inputs[1:]
- inputs = inputs[0]
-
- # If `initial_state` is specified,
- # and if it a Keras tensor,
- # then add it to the inputs and temporarily
- # modify the input spec to include the state.
- if initial_state is None:
- return super(Recurrent, self).__call__(inputs, **kwargs)
-
- if not isinstance(initial_state, (list, tuple)):
- initial_state = [initial_state]
-
- is_keras_tensor = hasattr(initial_state[0], '_keras_history')
- for tensor in initial_state:
- if hasattr(tensor, '_keras_history') != is_keras_tensor:
- raise ValueError('The initial state of an RNN layer cannot be'
- ' specified with a mix of Keras tensors and'
- ' non-Keras tensors')
-
- if is_keras_tensor:
- # Compute the full input spec, including state
- input_spec = self.input_spec
- state_spec = self.state_spec
- if not isinstance(input_spec, list):
- input_spec = [input_spec]
- if not isinstance(state_spec, list):
- state_spec = [state_spec]
- self.input_spec = input_spec + state_spec
-
- # Compute the full inputs, including state
- inputs = [inputs] + list(initial_state)
-
- # Perform the call
- output = super(Recurrent, self).__call__(inputs, **kwargs)
-
- # Restore original input spec
- self.input_spec = input_spec
- return output
- else:
- kwargs['initial_state'] = initial_state
- return super(Recurrent, self).__call__(inputs, **kwargs)
-
- def call(self, inputs, mask=None, training=None, initial_state=None):
- # input shape: `(samples, time (padded with zeros), input_dim)`
- # note that the .build() method of subclasses MUST define
- # self.input_spec and self.state_spec with complete input shapes.
- if isinstance(inputs, list):
- initial_state = inputs[1:]
- inputs = inputs[0]
- elif initial_state is not None:
- pass
- elif self.stateful:
- initial_state = self.states
- else:
- initial_state = self.get_initial_state(inputs)
-
- if isinstance(mask, list):
- mask = mask[0]
-
- if len(initial_state) != len(self.states):
- raise ValueError('Layer has ' + str(len(self.states)) +
- ' states but was passed ' + str(len(initial_state)) +
- ' initial states.')
- input_shape = K.int_shape(inputs)
- if self.unroll and input_shape[1] is None:
- raise ValueError('Cannot unroll a RNN if the '
- 'time dimension is undefined. \n'
- '- If using a Sequential model, '
- 'specify the time dimension by passing '
- 'an `input_shape` or `batch_input_shape` '
- 'argument to your first layer. If your '
- 'first layer is an Embedding, you can '
- 'also use the `input_length` argument.\n'
- '- If using the functional API, specify '
- 'the time dimension by passing a `shape` '
- 'or `batch_shape` argument to your Input layer.')
- constants = self.get_constants(inputs, training=None)
- preprocessed_input = self.preprocess_input(inputs, training=None)
- last_output, outputs, states = K.rnn(
- self.step,
- preprocessed_input,
- initial_state,
- go_backwards=self.go_backwards,
- mask=mask,
- constants=constants,
- unroll=self.unroll)
- if self.stateful:
- updates = []
- for i in range(len(states)):
- updates.append((self.states[i], states[i]))
- self.add_update(updates, inputs)
-
- # Properly set learning phase
- if 0 < self.dropout + self.recurrent_dropout:
- last_output._uses_learning_phase = True
- outputs._uses_learning_phase = True
-
- if not self.return_sequences:
- outputs = last_output
-
- if self.return_state:
- if not isinstance(states, (list, tuple)):
- states = [states]
- else:
- states = list(states)
- return [outputs] + states
- return outputs
-
- def reset_states(self, states=None):
- if not self.stateful:
- raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
- if not batch_size:
- raise ValueError('If a RNN is stateful, it needs to know '
- 'its batch size. Specify the batch size '
- 'of your input tensors: \n'
- '- If using a Sequential model, '
- 'specify the batch size by passing '
- 'a `batch_input_shape` '
- 'argument to your first layer.\n'
- '- If using the functional API, specify '
- 'the time dimension by passing a '
- '`batch_shape` argument to your Input layer.')
- # initialize state if None
- if self.states[0] is None:
- self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
- elif states is None:
- for state in self.states:
- K.set_value(state, np.zeros((batch_size, self.units)))
- else:
- if not isinstance(states, (list, tuple)):
- states = [states]
- if len(states) != len(self.states):
- raise ValueError('Layer ' + self.name + ' expects ' +
- str(len(self.states)) + ' states, '
- 'but it received ' + str(len(states)) +
- ' state values. Input received: ' + str(states))
- for index, (value, state) in enumerate(zip(states, self.states)):
- if value.shape != (batch_size, self.units):
- raise ValueError('State ' + str(index) +
- ' is incompatible with layer ' + self.name +
- ': expected shape=' + str((batch_size, self.units)) +
- ', found shape=' + str(value.shape))
- K.set_value(state, value)
-
- def get_config(self):
- config = {
- 'return_sequences': self.return_sequences,
- 'return_state': self.return_state,
- 'go_backwards': self.go_backwards,
- 'stateful': self.stateful,
- 'unroll': self.unroll,
- 'implementation': self.implementation
- }
- base_config = super(Recurrent, self).get_config()
return dict(list(base_config.items()) + list(config.items()))