| author | 2018-08-01 11:03:11 -0700 |
|---|---|
| committer | 2018-08-01 11:11:37 -0700 |
| commit | 24331e4fc0166d2adf3cf3b155844b5c77500a0c (patch) |
| tree | fb6d77af5eba48c1582d86905df06f5800613ed3 |
| parent | 79bfcad510d0c57db1a78b44192647b8284cfce8 (diff) |
Remove deprecated Recurrent class from RNN.
PiperOrigin-RevId: 206956778
| -rw-r--r-- | tensorflow/python/keras/layers/recurrent.py | 337 |

1 file changed, 0 insertions, 337 deletions
```diff
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 534c0eca08..a8bfdf25f2 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -23,7 +23,6 @@ import numbers
 import numpy as np
 
 from tensorflow.python.eager import context
-from tensorflow.python.framework import tensor_shape
 from tensorflow.python.keras import activations
 from tensorflow.python.keras import backend as K
 from tensorflow.python.keras import constraints
@@ -2231,342 +2230,6 @@ def _generate_dropout_mask(ones, rate, training=None, count=1):
   return K.in_train_phase(dropped_inputs, ones, training=training)
 
 
-class Recurrent(Layer):
-  """Deprecated abstract base class for recurrent layers.
-
-  It still exists because it is leveraged by the convolutional-recurrent layers.
-  It will be removed entirely in the future.
-  It was never part of the public API.
-  Do not use.
-
-  Arguments:
-      weights: list of Numpy arrays to set as initial weights.
-          The list should have 3 elements, of shapes:
-          `[(input_dim, output_dim), (output_dim, output_dim), (output_dim,)]`.
-      return_sequences: Boolean. Whether to return the last output
-          in the output sequence, or the full sequence.
-      return_state: Boolean. Whether to return the last state
-          in addition to the output.
-      go_backwards: Boolean (default False).
-          If True, process the input sequence backwards and return the
-          reversed sequence.
-      stateful: Boolean (default False). If True, the last state
-          for each sample at index i in a batch will be used as initial
-          state for the sample of index i in the following batch.
-      unroll: Boolean (default False).
-          If True, the network will be unrolled,
-          else a symbolic loop will be used.
-          Unrolling can speed-up a RNN,
-          although it tends to be more memory-intensive.
-          Unrolling is only suitable for short sequences.
-      implementation: one of {0, 1, or 2}.
-          If set to 0, the RNN will use
-          an implementation that uses fewer, larger matrix products,
-          thus running faster on CPU but consuming more memory.
-          If set to 1, the RNN will use more matrix products,
-          but smaller ones, thus running slower
-          (may actually be faster on GPU) while consuming less memory.
-          If set to 2 (LSTM/GRU only),
-          the RNN will combine the input gate,
-          the forget gate and the output gate into a single matrix,
-          enabling more time-efficient parallelization on the GPU.
-          Note: RNN dropout must be shared for all gates,
-          resulting in a slightly reduced regularization.
-      input_dim: dimensionality of the input (integer).
-          This argument (or alternatively, the keyword argument `input_shape`)
-          is required when using this layer as the first layer in a model.
-      input_length: Length of input sequences, to be specified
-          when it is constant.
-          This argument is required if you are going to connect
-          `Flatten` then `Dense` layers upstream
-          (without it, the shape of the dense outputs cannot be computed).
-          Note that if the recurrent layer is not the first layer
-          in your model, you would need to specify the input length
-          at the level of the first layer
-          (e.g. via the `input_shape` argument)
-
-  Input shape:
-      3D tensor with shape `(batch_size, timesteps, input_dim)`,
-      (Optional) 2D tensors with shape `(batch_size, output_dim)`.
-
-  Output shape:
-      - if `return_state`: a list of tensors. The first tensor is
-          the output. The remaining tensors are the last states,
-          each with shape `(batch_size, units)`.
-      - if `return_sequences`: 3D tensor with shape
-          `(batch_size, timesteps, units)`.
-      - else, 2D tensor with shape `(batch_size, units)`.
-
-  # Masking
-      This layer supports masking for input data with a variable number
-      of timesteps. To introduce masks to your data,
-      use an `Embedding` layer with the `mask_zero` parameter
-      set to `True`.
-
-  # Note on using statefulness in RNNs
-      You can set RNN layers to be 'stateful', which means that the states
-      computed for the samples in one batch will be reused as initial states
-      for the samples in the next batch. This assumes a one-to-one mapping
-      between samples in different successive batches.
-
-      To enable statefulness:
-          - specify `stateful=True` in the layer constructor.
-          - specify a fixed batch size for your model, by passing
-              if sequential model:
-                  `batch_input_shape=(...)` to the first layer in your model.
-              else for functional model with 1 or more Input layers:
-                  `batch_shape=(...)` to all the first layers in your model.
-              This is the expected shape of your inputs
-              *including the batch size*.
-              It should be a tuple of integers, e.g. `(32, 10, 100)`.
-          - specify `shuffle=False` when calling fit().
-
-      To reset the states of your model, call `.reset_states()` on either
-      a specific layer, or on your entire model.
-
-  # Note on specifying the initial state of RNNs
-      You can specify the initial state of RNN layers symbolically by
-      calling them with the keyword argument `initial_state`. The value of
-      `initial_state` should be a tensor or list of tensors representing
-      the initial state of the RNN layer.
-
-      You can specify the initial state of RNN layers numerically by
-      calling `reset_states` with the keyword argument `states`. The value of
-      `states` should be a numpy array or list of numpy arrays representing
-      the initial state of the RNN layer.
```
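The deleted docstring above still documents two usage patterns that survive in the public API: stateful RNNs and explicitly supplied initial states. A minimal sketch with the public `tf.keras` LSTM layer; shapes, variable names, and hyperparameters here are illustrative assumptions, not part of this commit:

```python
import numpy as np
import tensorflow as tf

batch_size, timesteps, features = 32, 10, 100  # illustrative shapes

# Statefulness, as described above: fix the batch size and set `stateful=True`.
model = tf.keras.Sequential([
    tf.keras.layers.LSTM(
        64, stateful=True,
        batch_input_shape=(batch_size, timesteps, features)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')

x = np.random.random((batch_size, timesteps, features)).astype('float32')
y = np.random.random((batch_size, 1)).astype('float32')

# `shuffle=False` preserves the one-to-one mapping between samples in
# successive batches that statefulness assumes.
model.fit(x, y, batch_size=batch_size, shuffle=False, epochs=1, verbose=0)
model.reset_states()  # clear carried-over states between independent runs

# Symbolic initial state, as described above: pass `initial_state` when
# calling the layer (functional API; an LSTM carries two state tensors).
inputs = tf.keras.Input(shape=(timesteps, features))
state_h = tf.keras.Input(shape=(64,))
state_c = tf.keras.Input(shape=(64,))
outputs = tf.keras.layers.LSTM(64)(inputs, initial_state=[state_h, state_c])
```

The removed implementation itself continues below.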
- """ - - def __init__(self, - return_sequences=False, - return_state=False, - go_backwards=False, - stateful=False, - unroll=False, - implementation=0, - **kwargs): - super(Recurrent, self).__init__(**kwargs) - self.return_sequences = return_sequences - self.return_state = return_state - self.go_backwards = go_backwards - self.stateful = stateful - self.unroll = unroll - self.implementation = implementation - self.supports_masking = True - self.input_spec = [InputSpec(ndim=3)] - self.state_spec = None - self.dropout = 0 - self.recurrent_dropout = 0 - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shape = input_shape[0] - input_shape = tensor_shape.TensorShape(input_shape).as_list() - if self.return_sequences: - output_shape = (input_shape[0], input_shape[1], self.units) - else: - output_shape = (input_shape[0], self.units) - - if self.return_state: - state_shape = [tensor_shape.TensorShape( - (input_shape[0], self.units)) for _ in self.states] - return [tensor_shape.TensorShape(output_shape)] + state_shape - return tensor_shape.TensorShape(output_shape) - - def compute_mask(self, inputs, mask): - if isinstance(mask, list): - mask = mask[0] - output_mask = mask if self.return_sequences else None - if self.return_state: - state_mask = [None for _ in self.states] - return [output_mask] + state_mask - return output_mask - - def step(self, inputs, states): - raise NotImplementedError - - def get_constants(self, inputs, training=None): - return [] - - def get_initial_state(self, inputs): - # build an all-zero tensor of shape (samples, output_dim) - initial_state = array_ops.zeros_like(inputs) - # shape of initial_state = (samples, timesteps, input_dim) - initial_state = math_ops.reduce_sum(initial_state, axis=(1, 2)) - # shape of initial_state = (samples,) - initial_state = array_ops.expand_dims(initial_state, axis=-1) - # shape of initial_state = (samples, 1) - initial_state = K.tile(initial_state, [1, - self.units]) # (samples, output_dim) - initial_state = [initial_state for _ in range(len(self.states))] - return initial_state - - def preprocess_input(self, inputs, training=None): - return inputs - - def __call__(self, inputs, initial_state=None, **kwargs): - if (isinstance(inputs, (list, tuple)) and - len(inputs) > 1 - and initial_state is None): - initial_state = inputs[1:] - inputs = inputs[0] - - # If `initial_state` is specified, - # and if it a Keras tensor, - # then add it to the inputs and temporarily - # modify the input spec to include the state. 
```diff
-    if initial_state is None:
-      return super(Recurrent, self).__call__(inputs, **kwargs)
-
-    if not isinstance(initial_state, (list, tuple)):
-      initial_state = [initial_state]
-
-    is_keras_tensor = hasattr(initial_state[0], '_keras_history')
-    for tensor in initial_state:
-      if hasattr(tensor, '_keras_history') != is_keras_tensor:
-        raise ValueError('The initial state of an RNN layer cannot be'
-                         ' specified with a mix of Keras tensors and'
-                         ' non-Keras tensors')
-
-    if is_keras_tensor:
-      # Compute the full input spec, including state
-      input_spec = self.input_spec
-      state_spec = self.state_spec
-      if not isinstance(input_spec, list):
-        input_spec = [input_spec]
-      if not isinstance(state_spec, list):
-        state_spec = [state_spec]
-      self.input_spec = input_spec + state_spec
-
-      # Compute the full inputs, including state
-      inputs = [inputs] + list(initial_state)
-
-      # Perform the call
-      output = super(Recurrent, self).__call__(inputs, **kwargs)
-
-      # Restore original input spec
-      self.input_spec = input_spec
-      return output
-    else:
-      kwargs['initial_state'] = initial_state
-      return super(Recurrent, self).__call__(inputs, **kwargs)
-
-  def call(self, inputs, mask=None, training=None, initial_state=None):
-    # input shape: `(samples, time (padded with zeros), input_dim)`
-    # note that the .build() method of subclasses MUST define
-    # self.input_spec and self.state_spec with complete input shapes.
-    if isinstance(inputs, list):
-      initial_state = inputs[1:]
-      inputs = inputs[0]
-    elif initial_state is not None:
-      pass
-    elif self.stateful:
-      initial_state = self.states
-    else:
-      initial_state = self.get_initial_state(inputs)
-
-    if isinstance(mask, list):
-      mask = mask[0]
-
-    if len(initial_state) != len(self.states):
-      raise ValueError('Layer has ' + str(len(self.states)) +
-                       ' states but was passed ' + str(len(initial_state)) +
-                       ' initial states.')
-    input_shape = K.int_shape(inputs)
-    if self.unroll and input_shape[1] is None:
-      raise ValueError('Cannot unroll a RNN if the '
-                       'time dimension is undefined. \n'
-                       '- If using a Sequential model, '
-                       'specify the time dimension by passing '
-                       'an `input_shape` or `batch_input_shape` '
-                       'argument to your first layer. If your '
-                       'first layer is an Embedding, you can '
-                       'also use the `input_length` argument.\n'
-                       '- If using the functional API, specify '
-                       'the time dimension by passing a `shape` '
-                       'or `batch_shape` argument to your Input layer.')
-    constants = self.get_constants(inputs, training=None)
-    preprocessed_input = self.preprocess_input(inputs, training=None)
-    last_output, outputs, states = K.rnn(
-        self.step,
-        preprocessed_input,
-        initial_state,
-        go_backwards=self.go_backwards,
-        mask=mask,
-        constants=constants,
-        unroll=self.unroll)
-    if self.stateful:
-      updates = []
-      for i in range(len(states)):
-        updates.append(state_ops.assign(self.states[i], states[i]))
-      self.add_update(updates, inputs)
-
-    # Properly set learning phase
-    if 0 < self.dropout + self.recurrent_dropout:
-      last_output._uses_learning_phase = True
-      outputs._uses_learning_phase = True
-
-    if not self.return_sequences:
-      outputs = last_output
-
-    if self.return_state:
-      if not isinstance(states, (list, tuple)):
-        states = [states]
-      else:
-        states = list(states)
-      return [outputs] + states
-    return outputs
-
-  def reset_states(self, states=None):
-    if not self.stateful:
-      raise AttributeError('Layer must be stateful.')
-    batch_size = self.input_spec[0].shape[0]
-    if not batch_size:
-      raise ValueError('If a RNN is stateful, it needs to know '
-                       'its batch size. Specify the batch size '
-                       'of your input tensors: \n'
-                       '- If using a Sequential model, '
-                       'specify the batch size by passing '
-                       'a `batch_input_shape` '
-                       'argument to your first layer.\n'
-                       '- If using the functional API, specify '
-                       'the time dimension by passing a '
-                       '`batch_shape` argument to your Input layer.')
-    # initialize state if None
-    if self.states[0] is None:
-      self.states = [K.zeros((batch_size, self.units)) for _ in self.states]
-    elif states is None:
-      for state in self.states:
-        K.set_value(state, np.zeros((batch_size, self.units)))
-    else:
-      if not isinstance(states, (list, tuple)):
-        states = [states]
-      if len(states) != len(self.states):
-        raise ValueError('Layer ' + self.name + ' expects ' +
-                         str(len(self.states)) + ' states, '
-                         'but it received ' + str(len(states)) +
-                         ' state values. Input received: ' + str(states))
-      for index, (value, state) in enumerate(zip(states, self.states)):
-        if value.shape != (batch_size, self.units):
-          raise ValueError('State ' + str(index) +
-                           ' is incompatible with layer ' + self.name +
-                           ': expected shape=' + str((batch_size, self.units)) +
-                           ', found shape=' + str(value.shape))
-        K.set_value(state, value)
-
-  def get_config(self):
-    config = {
-        'return_sequences': self.return_sequences,
-        'return_state': self.return_state,
-        'go_backwards': self.go_backwards,
-        'stateful': self.stateful,
-        'unroll': self.unroll,
-        'implementation': self.implementation
-    }
-    base_config = super(Recurrent, self).get_config()
-    return dict(list(base_config.items()) + list(config.items()))
-
-
 def _standardize_args(inputs, initial_state, constants, num_constants):
   """Standardizes `__call__` to a single list of tensor inputs.
```
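With `Recurrent` gone, the supported way to write a custom recurrent layer is the public `keras.layers.RNN` wrapper plus a cell object; the wrapper supplies the looping, masking, `go_backwards`, `stateful`, and `unroll` machinery that `Recurrent` used to carry. A minimal sketch following the pattern documented for the `RNN` wrapper (cell details here are illustrative, not from this commit):

```python
import tensorflow as tf

class MinimalRNNCell(tf.keras.layers.Layer):
  """A bare-bones cell for use with the public `tf.keras.layers.RNN` wrapper."""

  def __init__(self, units, **kwargs):
    super(MinimalRNNCell, self).__init__(**kwargs)
    self.units = units
    self.state_size = units  # required by the RNN wrapper

  def build(self, input_shape):
    self.kernel = self.add_weight(
        shape=(input_shape[-1], self.units), initializer='uniform',
        name='kernel')
    self.recurrent_kernel = self.add_weight(
        shape=(self.units, self.units), initializer='uniform',
        name='recurrent_kernel')
    self.built = True

  def call(self, inputs, states):
    # One timestep: output = W x_t + U h_{t-1}; new state equals the output.
    prev_output = states[0]
    h = tf.matmul(inputs, self.kernel)
    output = h + tf.matmul(prev_output, self.recurrent_kernel)
    return output, [output]

# The wrapper handles the time loop; the cell only defines one step.
layer = tf.keras.layers.RNN(MinimalRNNCell(32))
y = layer(tf.random.normal([4, 10, 8]))  # shape (4, 32)
```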