Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py  344
1 file changed, 328 insertions(+), 16 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index d4691f2c27..5e85c125df 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -76,6 +77,18 @@ def _get_sharded_variable(name, shape, dtype, num_shards):
return shards
+def _norm(g, b, inp, scope):
+ shape = inp.get_shape()[-1:]
+ gamma_init = init_ops.constant_initializer(g)
+ beta_init = init_ops.constant_initializer(b)
+ with vs.variable_scope(scope):
+ # Initialize beta and gamma for use by layer_norm.
+ vs.get_variable("gamma", shape=shape, initializer=gamma_init)
+ vs.get_variable("beta", shape=shape, initializer=beta_init)
+ normalized = layers.layer_norm(inp, reuse=True, scope=scope)
+ return normalized
+
+
class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
"""Long short-term memory unit (LSTM) recurrent network cell.
@@ -102,13 +115,24 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The class uses optional peep-hole connections, and an optional projection
layer.
+
+ Layer normalization implementation is based on:
+
+ https://arxiv.org/abs/1607.06450.
+
+ "Layer Normalization"
+ Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+
+ and is applied before the internal nonlinearities.
+
"""
def __init__(self, num_units, use_peepholes=False,
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=1, num_proj_shards=1,
forget_bias=1.0, state_is_tuple=True,
- activation=math_ops.tanh, reuse=None):
+ activation=math_ops.tanh, reuse=None,
+ layer_norm=False, norm_gain=1.0, norm_shift=0.0):
"""Initialize the parameters for an LSTM cell.
Args:
@@ -135,6 +159,13 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
+ layer_norm: If `True`, layer normalization will be applied.
+ norm_gain: float, The layer normalization gain initial value. If
+ `layer_norm` has been set to `False`, this argument will be ignored.
+ norm_shift: float, The layer normalization shift initial value. If
+ `layer_norm` has been set to `False`, this argument will be ignored.
+
+
"""
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
@@ -152,6 +183,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
self._state_is_tuple = state_is_tuple
self._activation = activation
self._reuse = reuse
+ self._layer_norm = layer_norm
+ self._norm_gain = norm_gain
+ self._norm_shift = norm_shift
if num_proj:
self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)
@@ -220,9 +254,20 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
# j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat([inputs, m_prev], 1)
- lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
+ lstm_matrix = math_ops.matmul(cell_inputs, concat_w)
+
+ # If layer normalization is applied, do not add bias
+ if not self._layer_norm:
+ lstm_matrix = nn_ops.bias_add(lstm_matrix, b)
+
j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
+ # Apply layer normalization
+ if self._layer_norm:
+ j = _norm(self._norm_gain, self._norm_shift, j, "transform")
+ f = _norm(self._norm_gain, self._norm_shift, f, "forget")
+ o = _norm(self._norm_gain, self._norm_shift, o, "output")
+
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
@@ -236,6 +281,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
f_act = sigmoid(f + self._forget_bias)
c = (f_act * c_prev + (1 - f_act) * self._activation(j))
+ # Apply layer normalization
+ if self._layer_norm:
+ c = _norm(self._norm_gain, self._norm_shift, c, "state")
+
if self._use_peepholes:
m = sigmoid(o + w_o_diag * c) * self._activation(c)
else:
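Putting the hunks above together, the coupled input-forget update with layer normalization enabled behaves roughly like this sketch (NumPy, no peepholes, reusing the hypothetical layer_norm_reference helper from the sketch above; names and shapes are illustrative):

def cifg_step(x, m_prev, c_prev, W, gain, shift, forget_bias=1.0):
  sigmoid = lambda z: 1.0 / (1.0 + np.exp(-z))
  pre = np.concatenate([x, m_prev], axis=1).dot(W)  # bias is skipped when layer norm is on
  j, f, o = np.split(pre, 3, axis=1)
  j = layer_norm_reference(j, gain, shift)          # "transform"
  f = layer_norm_reference(f, gain, shift)          # "forget"
  o = layer_norm_reference(o, gain, shift)          # "output"
  f_act = sigmoid(f + forget_bias)
  c = f_act * c_prev + (1.0 - f_act) * np.tanh(j)   # input gate coupled to the forget gate
  c = layer_norm_reference(c, gain, shift)          # "state"
  m = sigmoid(o) * np.tanh(c)
  return m, c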
@@ -1301,8 +1350,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
self._keep_prob = dropout_keep_prob
self._seed = dropout_prob_seed
self._layer_norm = layer_norm
- self._g = norm_gain
- self._b = norm_shift
+ self._norm_gain = norm_gain
+ self._norm_shift = norm_shift
self._reuse = reuse
@property
@@ -1313,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
def output_size(self):
return self._num_units
- def _norm(self, inp, scope):
+ def _norm(self, inp, scope, dtype=dtypes.float32):
shape = inp.get_shape()[-1:]
- gamma_init = init_ops.constant_initializer(self._g)
- beta_init = init_ops.constant_initializer(self._b)
+ gamma_init = init_ops.constant_initializer(self._norm_gain)
+ beta_init = init_ops.constant_initializer(self._norm_shift)
with vs.variable_scope(scope):
# Initialize beta and gamma for use by layer_norm.
- vs.get_variable("gamma", shape=shape, initializer=gamma_init)
- vs.get_variable("beta", shape=shape, initializer=beta_init)
+ vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype)
+ vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype)
normalized = layers.layer_norm(inp, reuse=True, scope=scope)
return normalized
def _linear(self, args):
out_size = 4 * self._num_units
proj_size = args.get_shape()[-1]
- weights = vs.get_variable("kernel", [proj_size, out_size])
+ dtype = args.dtype
+ weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype)
out = math_ops.matmul(args, weights)
if not self._layer_norm:
- bias = vs.get_variable("bias", [out_size])
+ bias = vs.get_variable("bias", [out_size], dtype=dtype)
out = nn_ops.bias_add(out, bias)
return out
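With gamma, beta, the kernel, and the bias now created in the dtype of the inputs, LayerNormBasicLSTMCell can be driven with non-default float types. A hedged TF 1.x-style usage sketch (shapes are illustrative):

import tensorflow as tf

cell = tf.contrib.rnn.LayerNormBasicLSTMCell(num_units=64, layer_norm=True)
inputs = tf.placeholder(tf.float64, shape=[None, 20, 32])  # batch x time x features
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float64)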
@@ -1339,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
c, h = state
args = array_ops.concat([inputs, h], 1)
concat = self._linear(args)
+ dtype = args.dtype
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
- i = self._norm(i, "input")
- j = self._norm(j, "transform")
- f = self._norm(f, "forget")
- o = self._norm(o, "output")
+ i = self._norm(i, "input", dtype=dtype)
+ j = self._norm(j, "transform", dtype=dtype)
+ f = self._norm(f, "forget", dtype=dtype)
+ o = self._norm(o, "output", dtype=dtype)
g = self._activation(j)
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
@@ -1354,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell):
new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+ math_ops.sigmoid(i) * g)
if self._layer_norm:
- new_c = self._norm(new_c, "state")
+ new_c = self._norm(new_c, "state", dtype=dtype)
new_h = self._activation(new_c) * math_ops.sigmoid(o)
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
@@ -2306,3 +2357,264 @@ class GLSTMCell(rnn_cell_impl.RNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(c, m)
return m, new_state
+
+
+class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ The default non-peephole implementation is based on:
+
+ http://www.bioinf.jku.at/publications/older/2604.pdf
+
+ S. Hochreiter and J. Schmidhuber.
+ "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997.
+
+ The peephole implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ The class uses optional peep-hole connections, optional cell clipping, and
+ an optional projection layer.
+
+ Layer normalization implementation is based on:
+
+ https://arxiv.org/abs/1607.06450.
+
+ "Layer Normalization"
+ Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
+
+ and is applied before the internal nonlinearities.
+
+ """
+
+ def __init__(self, num_units,
+ use_peepholes=False, cell_clip=None,
+ initializer=None, num_proj=None, proj_clip=None,
+ forget_bias=1.0,
+ activation=None, layer_norm=False,
+ norm_gain=1.0, norm_shift=0.0, reuse=None):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value; if provided, the cell state is clipped
+ by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is
+ provided, then the projected values are clipped elementwise to within
+ `[-proj_clip, proj_clip]`.
+ forget_bias: Biases of the forget gate are initialized by default to 1
+ in order to reduce the scale of forgetting at the beginning of
+ the training. It must be set manually to `0.0` when restoring from
+ CudnnLSTM-trained checkpoints.
+ activation: Activation function of the inner states. Default: `tanh`.
+ layer_norm: If `True`, layer normalization will be applied.
+ norm_gain: float, The layer normalization gain initial value. If
+ `layer_norm` has been set to `False`, this argument will be ignored.
+ norm_shift: float, The layer normalization shift initial value. If
+ `layer_norm` has been set to `False`, this argument will be ignored.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+
+ When restoring from CudnnLSTM-trained checkpoints, CudnnCompatibleLSTMCell
+ must be used instead.
+ """
+ super(LayerNormLSTMCell, self).__init__(_reuse=reuse)
+
+ self._num_units = num_units
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._proj_clip = proj_clip
+ self._forget_bias = forget_bias
+ self._activation = activation or math_ops.tanh
+ self._layer_norm = layer_norm
+ self._norm_gain = norm_gain
+ self._norm_shift = norm_shift
+
+ if num_proj:
+ self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj))
+ self._output_size = num_proj
+ else:
+ self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units))
+ self._output_size = num_units
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+
+ def _linear(self,
+ args,
+ output_size,
+ bias,
+ bias_initializer=None,
+ kernel_initializer=None,
+ layer_norm=False):
+ """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable.
+
+ Args:
+ args: a 2D Tensor or a list of 2D, batch x n, Tensors.
+ output_size: int, second dimension of W[i].
+ bias: boolean, whether to add a bias term or not.
+ bias_initializer: starting value to initialize the bias
+ (default is all zeros).
+ kernel_initializer: starting value to initialize the weight.
+ layer_norm: boolean, whether to apply layer normalization.
+
+
+ Returns:
+ A 2D Tensor with shape [batch x output_size] taking value
+ sum_i(args[i] * W[i]), where each W[i] is a newly created Variable.
+
+ Raises:
+ ValueError: if any of the arguments has an unspecified or wrong shape.
+ """
+ if args is None or (nest.is_sequence(args) and not args):
+ raise ValueError("`args` must be specified")
+ if not nest.is_sequence(args):
+ args = [args]
+
+ # Calculate the total size of arguments on dimension 1.
+ total_arg_size = 0
+ shapes = [a.get_shape() for a in args]
+ for shape in shapes:
+ if shape.ndims != 2:
+ raise ValueError("linear is expecting 2D arguments: %s" % shapes)
+ if shape[1].value is None:
+ raise ValueError("linear expects shape[1] to be provided for shape %s, "
+ "but saw %s" % (shape, shape[1]))
+ else:
+ total_arg_size += shape[1].value
+
+ dtype = [a.dtype for a in args][0]
+
+ # Now the computation.
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope) as outer_scope:
+ weights = vs.get_variable(
+ "kernel", [total_arg_size, output_size],
+ dtype=dtype,
+ initializer=kernel_initializer)
+ if len(args) == 1:
+ res = math_ops.matmul(args[0], weights)
+ else:
+ res = math_ops.matmul(array_ops.concat(args, 1), weights)
+ if not bias:
+ return res
+ with vs.variable_scope(outer_scope) as inner_scope:
+ inner_scope.set_partitioner(None)
+ if bias_initializer is None:
+ bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype)
+ biases = vs.get_variable(
+ "bias", [output_size],
+ dtype=dtype,
+ initializer=bias_initializer)
+
+ if not layer_norm:
+ res = nn_ops.bias_add(res, biases)
+
+ return res
+
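As a quick shape check of _linear above (continuing the NumPy sketches; sizes are illustrative): with inputs of width 32, m_prev of width num_units=64, and output_size=4*num_units, the helper reduces to a single matmul over the concatenated arguments.

batch, num_units = 8, 64
inputs = np.zeros((batch, 32), np.float32)
m_prev = np.zeros((batch, num_units), np.float32)
args = np.concatenate([inputs, m_prev], axis=1)    # [8, 96]
kernel = np.zeros((args.shape[1], 4 * num_units))  # "kernel": [96, 256]
res = args.dot(kernel)                             # [8, 256]; the bias is created but only added when layer_norm is False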
+ def call(self, inputs, state):
+ """Run one step of LSTM.
+
+ Args:
+ inputs: input Tensor, 2D, batch x num_units.
+ state: this must be a tuple of state Tensors,
+ both `2-D`, with column sizes `c_state` and
+ `m_state`.
+
+ Returns:
+ A tuple containing:
+
+ - A `2-D, [batch x output_dim]`, Tensor representing the output of the
+ LSTM after reading `inputs` when previous state was `state`.
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - Tensor(s) representing the new state of LSTM after reading `inputs` when
+ the previous state was `state`. Same type and shape(s) as `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+ sigmoid = math_ops.sigmoid
+
+ (c_prev, m_prev) = state
+
+ dtype = inputs.dtype
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+ scope = vs.get_variable_scope()
+ with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True,
+ bias_initializer=None, layer_norm=self._layer_norm)
+ i, j, f, o = array_ops.split(
+ value=lstm_matrix, num_or_size_splits=4, axis=1)
+
+ if self._layer_norm:
+ i = _norm(self._norm_gain, self._norm_shift, i, "input")
+ j = _norm(self._norm_gain, self._norm_shift, j, "transform")
+ f = _norm(self._norm_gain, self._norm_shift, f, "forget")
+ o = _norm(self._norm_gain, self._norm_shift, o, "output")
+
+ # Diagonal connections
+ if self._use_peepholes:
+ with vs.variable_scope(unit_scope) as projection_scope:
+ w_f_diag = vs.get_variable(
+ "w_f_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = vs.get_variable(
+ "w_i_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = vs.get_variable(
+ "w_o_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
+ sigmoid(i + w_i_diag * c_prev) * self._activation(j))
+ else:
+ c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
+ self._activation(j))
+
+ if self._layer_norm:
+ c = _norm(self._norm_gain, self._norm_shift, c, "state")
+
+ if self._cell_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
+ # pylint: enable=invalid-unary-operand-type
+ if self._use_peepholes:
+ m = sigmoid(o + w_o_diag * c) * self._activation(c)
+ else:
+ m = sigmoid(o) * self._activation(c)
+
+ if self._num_proj is not None:
+ with vs.variable_scope("projection") as proj_scope:
+ m = self._linear(m, self._num_proj, bias=False)
+
+ if self._proj_clip is not None:
+ # pylint: disable=invalid-unary-operand-type
+ m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
+ # pylint: enable=invalid-unary-operand-type
+
+ new_state = (rnn_cell_impl.LSTMStateTuple(c, m))
+ return m, new_state
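A hedged end-to-end sketch of driving the new LayerNormLSTMCell (TF 1.x-era graph API; the cell is imported straight from this module since no tf.contrib.rnn alias appears in this diff; shapes are illustrative):

import tensorflow as tf
from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

cell = contrib_rnn_cell.LayerNormLSTMCell(
    num_units=128, use_peepholes=True, num_proj=64,
    layer_norm=True, norm_gain=1.0, norm_shift=0.0)

inputs = tf.placeholder(tf.float32, shape=[None, 50, 30])  # batch x time x features
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
# outputs: [batch, 50, 64] (num_proj); final_state: LSTMStateTuple(c [batch, 128], h [batch, 64])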