Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 257 |
1 file changed, 256 insertions, 1 deletion
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index e4667828cd..d7ae6621db 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -38,6 +38,7 @@ from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope as vs
 from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import nn_impl
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest
 
@@ -328,7 +329,7 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell):
   def __init__(self, num_units, use_peepholes=False,
                cell_clip=None, initializer=None,
                num_unit_shards=1, forget_bias=1.0,
-               feature_size=None, frequency_skip=None,
+               feature_size=None, frequency_skip=1,
                reuse=None):
     """Initialize the parameters for an LSTM cell.
 
@@ -2723,3 +2724,257 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
 
     h = r * self._activation(c) + (1.0 - r) * inputs
     return h, c
+
+
+class WeightNormLSTMCell(rnn_cell_impl.RNNCell):
+  """Weight normalized LSTM Cell. Adapted from `rnn_cell_impl.LSTMCell`.
+
+  The weight-norm implementation is based on:
+  https://arxiv.org/abs/1602.07868
+  Tim Salimans, Diederik P. Kingma.
+  "Weight Normalization: A Simple Reparameterization to Accelerate
+  Training of Deep Neural Networks"
+
+  The default LSTM implementation is based on:
+  http://www.bioinf.jku.at/publications/older/2604.pdf
+  S. Hochreiter and J. Schmidhuber.
+  "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+  The class uses optional peephole connections, optional cell clipping
+  and an optional projection layer.
+
+  The optional peephole implementation is based on:
+  https://research.google.com/pubs/archive/43905.pdf
+  Hasim Sak, Andrew Senior, and Francoise Beaufays.
+  "Long short-term memory recurrent neural network architectures for
+  large scale acoustic modeling." INTERSPEECH, 2014.
+  """
+
+  def __init__(self, num_units, norm=True, use_peepholes=False,
+               cell_clip=None, initializer=None, num_proj=None,
+               proj_clip=None, forget_bias=1, activation=None,
+               reuse=None):
+    """Initialize the parameters of a weight-normalized LSTM cell.
+
+    Args:
+      num_units: int, The number of units in the LSTM cell.
+      norm: If `True`, apply normalization to the weight matrices. If `False`,
+        the result is identical to that obtained from `rnn_cell_impl.LSTMCell`.
+      use_peepholes: bool, set `True` to enable diagonal/peephole connections.
+      cell_clip: (optional) A float value, if provided the cell state is
+        clipped by this value prior to the cell output activation.
+      initializer: (optional) The initializer to use for the weight matrices.
+      num_proj: (optional) int, The output dimensionality for the projection
+        matrices. If None, no projection is performed.
+      proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip`
+        is provided, then the projected values are clipped elementwise to
+        within `[-proj_clip, proj_clip]`.
+      forget_bias: Biases of the forget gate are initialized by default to 1
+        in order to reduce the scale of forgetting at the beginning of
+        the training.
+      activation: Activation function of the inner states. Default: `tanh`.
+      reuse: (optional) Python boolean describing whether to reuse variables
+        in an existing scope. If not `True`, and the existing scope already
+        has the given variables, an error is raised.
+ """ + super(WeightNormLSTMCell, self).__init__(_reuse=reuse) + + self._scope = 'wn_lstm_cell' + self._num_units = num_units + self._norm = norm + self._initializer = initializer + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._num_proj = num_proj + self._proj_clip = proj_clip + self._activation = activation or math_ops.tanh + self._forget_bias = forget_bias + + self._weights_variable_name = "kernel" + self._bias_variable_name = "bias" + + if num_proj: + self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_proj) + self._output_size = num_proj + else: + self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + def _normalize(self, weight, name): + """Apply weight normalization. + + Args: + weight: a 2D tensor with known number of columns. + name: string, variable name for the normalizer. + Returns: + A tensor with the same shape as `weight`. + """ + + output_size = weight.get_shape().as_list()[1] + g = vs.get_variable(name, [output_size], dtype=weight.dtype) + return nn_impl.l2_normalize(weight, dim=0) * g + + def _linear(self, args, + output_size, + norm, + bias, + bias_initializer=None, + kernel_initializer=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + + Returns: + A 2D Tensor with shape [batch x output_size] equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. 
+    scope = vs.get_variable_scope()
+    with vs.variable_scope(scope) as outer_scope:
+      weights = vs.get_variable(
+          self._weights_variable_name, [total_arg_size, output_size],
+          dtype=dtype,
+          initializer=kernel_initializer)
+      if norm:
+        wn = []
+        st = 0
+        with ops.control_dependencies(None):
+          for i in range(len(args)):
+            en = st + shapes[i][1].value
+            wn.append(self._normalize(weights[st:en, :],
+                                      name='norm_{}'.format(i)))
+            st = en
+
+          weights = array_ops.concat(wn, axis=0)
+
+      if len(args) == 1:
+        res = math_ops.matmul(args[0], weights)
+      else:
+        res = math_ops.matmul(array_ops.concat(args, 1), weights)
+      if not bias:
+        return res
+
+      with vs.variable_scope(outer_scope) as inner_scope:
+        inner_scope.set_partitioner(None)
+        if bias_initializer is None:
+          bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
+
+        biases = vs.get_variable(
+            self._bias_variable_name, [output_size],
+            dtype=dtype,
+            initializer=bias_initializer)
+
+        return nn_ops.bias_add(res, biases)
+
+  def call(self, inputs, state):
+    """Run one step of LSTM.
+
+    Args:
+      inputs: input Tensor, 2D, batch x num_units.
+      state: A tuple of state Tensors, both `2-D`, with column sizes
+        `c_state` and `m_state`.
+
+    Returns:
+      A tuple containing:
+
+      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+        LSTM after reading `inputs` when previous state was `state`.
+        Here output_dim is:
+           num_proj if num_proj was set,
+           num_units otherwise.
+      - Tensor(s) representing the new state of LSTM after reading `inputs`
+        when the previous state was `state`. Same type and shape(s) as `state`.
+
+    Raises:
+      ValueError: If input size cannot be inferred from inputs via
+        static shape inference.
+    """
+    dtype = inputs.dtype
+    num_units = self._num_units
+    sigmoid = math_ops.sigmoid
+    c, h = state
+
+    input_size = inputs.get_shape().with_rank(2)[1]
+    if input_size.value is None:
+      raise ValueError(
+          "Could not infer input size from inputs.get_shape()[-1]")
+
+    with vs.variable_scope(self._scope, initializer=self._initializer):
+
+      concat = self._linear([inputs, h], 4 * num_units,
+                            norm=self._norm, bias=True)
+
+      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
+
+      if self._use_peepholes:
+        w_f_diag = vs.get_variable("w_f_diag", shape=[num_units], dtype=dtype)
+        w_i_diag = vs.get_variable("w_i_diag", shape=[num_units], dtype=dtype)
+        w_o_diag = vs.get_variable("w_o_diag", shape=[num_units], dtype=dtype)
+
+        new_c = (c * sigmoid(f + self._forget_bias + w_f_diag * c) +
+                 sigmoid(i + w_i_diag * c) * self._activation(j))
+      else:
+        new_c = (c * sigmoid(f + self._forget_bias) +
+                 sigmoid(i) * self._activation(j))
+
+      if self._cell_clip is not None:
+        # pylint: disable=invalid-unary-operand-type
+        new_c = clip_ops.clip_by_value(new_c, -self._cell_clip, self._cell_clip)
+        # pylint: enable=invalid-unary-operand-type
+      if self._use_peepholes:
+        new_h = sigmoid(o + w_o_diag * new_c) * self._activation(new_c)
+      else:
+        new_h = sigmoid(o) * self._activation(new_c)
+
+      if self._num_proj is not None:
+        with vs.variable_scope("projection"):
+          new_h = self._linear(new_h,
+                               self._num_proj,
+                               norm=self._norm,
+                               bias=False)
+
+        if self._proj_clip is not None:
+          # pylint: disable=invalid-unary-operand-type
+          new_h = clip_ops.clip_by_value(new_h,
+                                         -self._proj_clip,
+                                         self._proj_clip)
+          # pylint: enable=invalid-unary-operand-type
+
+      new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
+      return new_h, new_state
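
For readers unfamiliar with the reparameterization used by _normalize above: weight normalization rewrites each weight column as w = g * v / ||v||_2, so the direction of v and the scale g are trained separately. A minimal NumPy sketch of that computation, assuming illustrative array names and sizes that are not part of the patch:

    import numpy as np

    # v: unnormalized weight matrix (input_depth x output_size); g: per-column scale.
    v = np.random.randn(6, 4).astype(np.float32)
    g = np.ones(4, dtype=np.float32)

    # Column-wise L2 normalization followed by rescaling, mirroring
    # nn_impl.l2_normalize(weight, dim=0) * g in _normalize.
    w = g * v / np.linalg.norm(v, axis=0, keepdims=True)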
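
A rough usage sketch of the new cell under the TF 1.x graph API; the import path follows the file touched by this diff, while the shapes and hyperparameters below are made-up example values:

    import tensorflow as tf
    from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

    batch_size, time_steps, input_depth, num_units = 4, 10, 8, 16
    inputs = tf.placeholder(tf.float32, [batch_size, time_steps, input_depth])

    # norm=False should reproduce a plain LSTMCell; norm=True enables weight normalization.
    cell = contrib_rnn_cell.WeightNormLSTMCell(num_units, norm=True, use_peepholes=True)
    outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

With num_proj set, outputs and the h part of final_state would have the projection dimensionality instead of num_units, per the state_size/output_size logic in __init__.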