diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 344 |
1 files changed, 328 insertions, 16 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index d4691f2c27..5e85c125df 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -76,6 +77,18 @@ def _get_sharded_variable(name, shape, dtype, num_shards): return shards +def _norm(g, b, inp, scope): + shape = inp.get_shape()[-1:] + gamma_init = init_ops.constant_initializer(g) + beta_init = init_ops.constant_initializer(b) + with vs.variable_scope(scope): + # Initialize beta and gamma for use by layer_norm. + vs.get_variable("gamma", shape=shape, initializer=gamma_init) + vs.get_variable("beta", shape=shape, initializer=beta_init) + normalized = layers.layer_norm(inp, reuse=True, scope=scope) + return normalized + + class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -102,13 +115,24 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): The class uses optional peep-hole connections, and an optional projection layer. + + Layer normalization implementation is based on: + + https://arxiv.org/abs/1607.06450. + + "Layer Normalization" + Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + and is applied before the internal nonlinearities. + """ def __init__(self, num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=True, - activation=math_ops.tanh, reuse=None): + activation=math_ops.tanh, reuse=None, + layer_norm=False, norm_gain=1.0, norm_shift=0.0): """Initialize the parameters for an LSTM cell. Args: @@ -135,6 +159,13 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. + layer_norm: If `True`, layer normalization will be applied. + norm_gain: float, The layer normalization gain initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + norm_shift: float, The layer normalization shift initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + + """ super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: @@ -152,6 +183,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): self._state_is_tuple = state_is_tuple self._activation = activation self._reuse = reuse + self._layer_norm = layer_norm + self._norm_gain = norm_gain + self._norm_shift = norm_shift if num_proj: self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj) @@ -220,9 +254,20 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): # j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat([inputs, m_prev], 1) - lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) + lstm_matrix = math_ops.matmul(cell_inputs, concat_w) + + # If layer nomalization is applied, do not add bias + if not self._layer_norm: + lstm_matrix = nn_ops.bias_add(lstm_matrix, b) + j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) + # Apply layer normalization + if self._layer_norm: + j = _norm(self._norm_gain, self._norm_shift, j, "transform") + f = _norm(self._norm_gain, self._norm_shift, f, "forget") + o = _norm(self._norm_gain, self._norm_shift, o, "output") + # Diagonal connections if self._use_peepholes: w_f_diag = vs.get_variable( @@ -236,6 +281,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): f_act = sigmoid(f + self._forget_bias) c = (f_act * c_prev + (1 - f_act) * self._activation(j)) + # Apply layer normalization + if self._layer_norm: + c = _norm(self._norm_gain, self._norm_shift, c, "state") + if self._use_peepholes: m = sigmoid(o + w_o_diag * c) * self._activation(c) else: @@ -1301,8 +1350,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): self._keep_prob = dropout_keep_prob self._seed = dropout_prob_seed self._layer_norm = layer_norm - self._g = norm_gain - self._b = norm_shift + self._norm_gain = norm_gain + self._norm_shift = norm_shift self._reuse = reuse @property @@ -1313,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): def output_size(self): return self._num_units - def _norm(self, inp, scope): + def _norm(self, inp, scope, dtype=dtypes.float32): shape = inp.get_shape()[-1:] - gamma_init = init_ops.constant_initializer(self._g) - beta_init = init_ops.constant_initializer(self._b) + gamma_init = init_ops.constant_initializer(self._norm_gain) + beta_init = init_ops.constant_initializer(self._norm_shift) with vs.variable_scope(scope): # Initialize beta and gamma for use by layer_norm. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) + vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) + vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) normalized = layers.layer_norm(inp, reuse=True, scope=scope) return normalized def _linear(self, args): out_size = 4 * self._num_units proj_size = args.get_shape()[-1] - weights = vs.get_variable("kernel", [proj_size, out_size]) + dtype = args.dtype + weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) out = math_ops.matmul(args, weights) if not self._layer_norm: - bias = vs.get_variable("bias", [out_size]) + bias = vs.get_variable("bias", [out_size], dtype=dtype) out = nn_ops.bias_add(out, bias) return out @@ -1339,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): c, h = state args = array_ops.concat([inputs, h], 1) concat = self._linear(args) + dtype = args.dtype i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") + i = self._norm(i, "input", dtype=dtype) + j = self._norm(j, "transform", dtype=dtype) + f = self._norm(f, "forget", dtype=dtype) + o = self._norm(o, "output", dtype=dtype) g = self._activation(j) if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: @@ -1354,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): new_c = (c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: - new_c = self._norm(new_c, "state") + new_c = self._norm(new_c, "state", dtype=dtype) new_h = self._activation(new_c) * math_ops.sigmoid(o) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) @@ -2306,3 +2357,264 @@ class GLSTMCell(rnn_cell_impl.RNNCell): new_state = rnn_cell_impl.LSTMStateTuple(c, m) return m, new_state + + +class LayerNormLSTMCell(rnn_cell_impl.RNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + The default non-peephole implementation is based on: + + http://www.bioinf.jku.at/publications/older/2604.pdf + + S. Hochreiter and J. Schmidhuber. + "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + + Layer normalization implementation is based on: + + https://arxiv.org/abs/1607.06450. + + "Layer Normalization" + Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + and is applied before the internal nonlinearities. + + """ + + def __init__(self, num_units, + use_peepholes=False, cell_clip=None, + initializer=None, num_proj=None, proj_clip=None, + forget_bias=1.0, + activation=None, layer_norm=False, + norm_gain=1.0, norm_shift=0.0, reuse=None): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. Must set it manually to `0.0` when restoring from + CudnnLSTM trained checkpoints. + activation: Activation function of the inner states. Default: `tanh`. + layer_norm: If `True`, layer normalization will be applied. + norm_gain: float, The layer normalization gain initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + norm_shift: float, The layer normalization shift initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + + When restoring from CudnnLSTM-trained checkpoints, must use + CudnnCompatibleLSTMCell instead. + """ + super(LayerNormLSTMCell, self).__init__(_reuse=reuse) + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._proj_clip = proj_clip + self._forget_bias = forget_bias + self._activation = activation or math_ops.tanh + self._layer_norm = layer_norm + self._norm_gain = norm_gain + self._norm_shift = norm_shift + + if num_proj: + self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) + self._output_size = num_proj + else: + self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + + def _linear(self, + args, + output_size, + bias, + bias_initializer=None, + kernel_initializer=None, + layer_norm=False): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + layer_norm: boolean, whether to apply layer normalization. + + + Returns: + A 2D Tensor with shape [batch x output_size] taking value + sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + "kernel", [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) + if len(args) == 1: + res = math_ops.matmul(args[0], weights) + else: + res = math_ops.matmul(array_ops.concat(args, 1), weights) + if not bias: + return res + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) + biases = vs.get_variable( + "bias", [output_size], + dtype=dtype, + initializer=bias_initializer) + + if not layer_norm: + res = nn_ops.bias_add(res, biases) + + return res + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, batch x num_units. + state: this must be a tuple of state Tensors, + both `2-D`, with column sizes `c_state` and + `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch x output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + (c_prev, m_prev) = state + + dtype = inputs.dtype + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + scope = vs.get_variable_scope() + with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True, + bias_initializer=None, layer_norm=self._layer_norm) + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + + if self._layer_norm: + i = _norm(self._norm_gain, self._norm_shift, i, "input") + j = _norm(self._norm_gain, self._norm_shift, j, "transform") + f = _norm(self._norm_gain, self._norm_shift, f, "forget") + o = _norm(self._norm_gain, self._norm_shift, o, "output") + + # Diagonal connections + if self._use_peepholes: + with vs.variable_scope(unit_scope) as projection_scope: + w_f_diag = vs.get_variable( + "w_f_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "w_i_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "w_o_diag", shape=[self._num_units], dtype=dtype) + + if self._use_peepholes: + c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) + else: + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * + self._activation(j)) + + if self._layer_norm: + c = _norm(self._norm_gain, self._norm_shift, c, "state") + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection") as proj_scope: + m = self._linear(m, self._num_proj, bias=False) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) + return m, new_state |