Diffstat (limited to 'tensorflow/models/rnn/rnn_cell.py')
-rw-r--r--  tensorflow/models/rnn/rnn_cell.py  605
1 files changed, 605 insertions, 0 deletions
diff --git a/tensorflow/models/rnn/rnn_cell.py b/tensorflow/models/rnn/rnn_cell.py
new file mode 100644
index 0000000000..55d417fc2b
--- /dev/null
+++ b/tensorflow/models/rnn/rnn_cell.py
@@ -0,0 +1,605 @@
+"""Module for constructing RNN Cells."""
+
+import math
+
+import tensorflow as tf
+
+from tensorflow.models.rnn import linear
+
+
+class RNNCell(object):
+ """Abstract object representing an RNN cell.
+
+ An RNN cell, in the most abstract setting, is anything that has
+ a state -- a vector of floats of size self.state_size -- and performs some
+ operation that takes inputs of size self.input_size. This operation
+ results in an output of size self.output_size and a new state.
+
+ This module provides a number of basic commonly used RNN cells, such as
+ LSTM (Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number
+ of operators that allow adding dropouts, projections, or embeddings for
+ inputs. Constructing multi-layer cells is supported by the class
+ MultiRNNCell, defined later. Every RNNCell must have the properties below
+ and implement __call__ with the following signature.
+ """
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this RNN cell on inputs, starting from the given state.
+
+ Args:
+ inputs: 2D Tensor with shape [batch_size x self.input_size].
+ state: 2D Tensor with shape [batch_size x self.state_size].
+ scope: VariableScope for the created subgraph; defaults to class name.
+
+ Returns:
+ A pair containing:
+ - Output: A 2D Tensor with shape [batch_size x self.output_size]
+ - New state: A 2D Tensor with shape [batch_size x self.state_size].
+ """
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def input_size(self):
+ """Integer: size of inputs accepted by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def output_size(self):
+ """Integer: size of outputs produced by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ @property
+ def state_size(self):
+ """Integer: size of state used by this cell."""
+ raise NotImplementedError("Abstract method")
+
+ def zero_state(self, batch_size, dtype):
+ """Return state tensor (shape [batch_size x state_size]) filled with 0.
+
+ Args:
+ batch_size: int, float, or unit Tensor representing the batch size.
+ dtype: the data type to use for the state.
+
+ Returns:
+ A 2D Tensor of shape [batch_size x state_size] filled with zeros.
+ """
+ zeros = tf.zeros(tf.pack([batch_size, self.state_size]), dtype=dtype)
+ # The reshape below is a no-op, but it allows shape inference of shape[1].
+ return tf.reshape(zeros, [-1, self.state_size])
+
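To make the contract concrete, here is a minimal sketch of a cell that satisfies this interface; IdentityCell is a hypothetical name, not part of this module, and it creates no variables.

class IdentityCell(RNNCell):
  """Toy cell: the output and the new state are simply the input."""

  def __init__(self, num_units):
    self._num_units = num_units

  @property
  def input_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @property
  def state_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
    # No parameters; the input is passed through unchanged as both
    # the output and the new state.
    return inputs, inputs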
+
+class BasicRNNCell(RNNCell):
+ """The most basic RNN cell."""
+
+ def __init__(self, num_units):
+ self._num_units = num_units
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Most basic RNN: output = new_state = tanh(W * input + U * state + B)."""
+ with tf.variable_scope(scope or type(self).__name__): # "BasicRNNCell"
+ output = tf.tanh(linear.linear([inputs, state], self._num_units, True))
+ return output, output
+
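A minimal single-step usage sketch, assuming the TensorFlow Python API of this era (tf.placeholder and friends); batch and unit sizes are illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

batch_size, num_units = 32, 128
cell = rnn_cell.BasicRNNCell(num_units)
inputs = tf.placeholder(tf.float32, [batch_size, cell.input_size])
state = cell.zero_state(batch_size, tf.float32)
# One step: the output and the new state are the same [32, 128] tensor.
output, state = cell(inputs, state)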
+
+class GRUCell(RNNCell):
+ """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
+
+ def __init__(self, num_units):
+ self._num_units = num_units
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Gated recurrent unit (GRU) with nunits cells."""
+ with tf.variable_scope(scope or type(self).__name__): # "GRUCell"
+ with tf.variable_scope("Gates"): # Reset gate and update gate.
+ # We start with bias of 1.0 to not reset and not update.
+ r, u = tf.split(1, 2, linear.linear([inputs, state],
+ 2 * self._num_units, True, 1.0))
+ r, u = tf.sigmoid(r), tf.sigmoid(u)
+ with tf.variable_scope("Candidate"):
+ c = tf.tanh(linear.linear([inputs, r * state], self._num_units, True))
+ new_h = u * state + (1 - u) * c
+ return new_h, new_h
+
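The gate arithmetic above has a compact closed form; this NumPy sketch mirrors it step for step (one fused weight matrix for the two gates, a separate one for the candidate) and is only a reference for the math, not graph-building code. All parameter names and shapes are assumptions for illustration.

import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h, w_gates, b_gates, w_cand, b_cand):
  """One GRU step; w_gates: [in+units, 2*units], w_cand: [in+units, units]."""
  xh = np.concatenate([x, h], axis=1)
  # Reset and update gates; b_gates would start at 1.0, as in the code above.
  r, u = np.split(sigmoid(np.dot(xh, w_gates) + b_gates), 2, axis=1)
  # The candidate state sees the reset-gated previous state.
  c = np.tanh(np.dot(np.concatenate([x, r * h], axis=1), w_cand) + b_cand)
  return u * h + (1.0 - u) * c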
+
+class BasicLSTMCell(RNNCell):
+ """Basic LSTM recurrent network cell.
+
+ The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf.
+
+ It does not allow cell clipping or a projection layer, and it does not
+ use peep-hole connections: it is the basic baseline.
+
+ Biases of the forget gate are initialized by default to 1 in order to reduce
+ the scale of forgetting in the beginning of the training.
+ """
+
+ def __init__(self, num_units, forget_bias=1.0):
+ self._num_units = num_units
+ self._forget_bias = forget_bias
+
+ @property
+ def input_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ @property
+ def state_size(self):
+ return 2 * self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Long short-term memory cell (LSTM)."""
+ with tf.variable_scope(scope or type(self).__name__): # "BasicLSTMCell"
+ # Parameters of gates are concatenated into one multiply for efficiency.
+ c, h = tf.split(1, 2, state)
+ concat = linear.linear([inputs, h], 4 * self._num_units, True)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ i, j, f, o = tf.split(1, 4, concat)
+
+ new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+ new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+ return new_h, tf.concat(1, [new_c, new_h])
+
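Since state_size is 2 * num_units, the state tensor packs the cell and hidden vectors side by side as [c, h]. A short sketch of one step under that layout, with the era API assumed and all sizes illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

batch_size, num_units = 16, 64
cell = rnn_cell.BasicLSTMCell(num_units, forget_bias=1.0)
inputs = tf.placeholder(tf.float32, [batch_size, num_units])
state = cell.zero_state(batch_size, tf.float32)   # shape [16, 128]: [c, h]
output, state = cell(inputs, state)
c, h = tf.split(1, 2, state)  # recover the cell and hidden parts if needed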
+
+class LSTMCell(RNNCell):
+ """Long short-term memory unit (LSTM) recurrent network cell.
+
+ This implementation is based on:
+
+ https://research.google.com/pubs/archive/43905.pdf
+
+ Hasim Sak, Andrew Senior, and Francoise Beaufays.
+ "Long short-term memory recurrent neural network architectures for
+ large scale acoustic modeling." INTERSPEECH, 2014.
+
+ It uses peep-hole connections, optional cell clipping, and an optional
+ projection layer.
+ """
+
+ def __init__(self, num_units, input_size,
+ use_peepholes=False, cell_clip=None,
+ initializer=None, num_proj=None,
+ num_unit_shards=1, num_proj_shards=1):
+ """Initialize the parameters for an LSTM cell.
+
+ Args:
+ num_units: int, The number of units in the LSTM cell
+ input_size: int, The dimensionality of the inputs into the LSTM cell
+ use_peepholes: bool, set True to enable diagonal/peephole connections.
+ cell_clip: (optional) A float value; if provided, the cell state is
+ clipped by this value prior to the cell output activation.
+ initializer: (optional) The initializer to use for the weight and
+ projection matrices.
+ num_proj: (optional) int, The output dimensionality for the projection
+ matrices. If None, no projection is performed.
+ num_unit_shards: How to split the weight matrix. If >1, the weight
+ matrix is stored across num_unit_shards.
+ Note that num_unit_shards must evenly divide num_units * 4.
+ num_proj_shards: How to split the projection matrix. If >1, the
+ projection matrix is stored across num_proj_shards.
+ Note that num_proj_shards must evenly divide num_proj
+ (if num_proj is not None).
+
+ Raises:
+ ValueError: if num_unit_shards doesn't divide 4 * num_units or
+ num_proj_shards doesn't divide num_proj
+ """
+ self._num_units = num_units
+ self._input_size = input_size
+ self._use_peepholes = use_peepholes
+ self._cell_clip = cell_clip
+ self._initializer = initializer
+ self._num_proj = num_proj
+ self._num_unit_shards = num_unit_shards
+ self._num_proj_shards = num_proj_shards
+
+ if (num_units * 4) % num_unit_shards != 0:
+ raise ValueError("num_unit_shards must evently divide 4 * num_units")
+ if num_proj and num_proj % num_proj_shards != 0:
+ raise ValueError("num_proj_shards must evently divide num_proj")
+
+ if num_proj:
+ self._state_size = num_units + num_proj
+ self._output_size = num_proj
+ else:
+ self._state_size = 2 * num_units
+ self._output_size = num_units
+
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ @property
+ def state_size(self):
+ return self._state_size
+
+ def __call__(self, input_, state, scope=None):
+ """Run one step of LSTM.
+
+ Args:
+ input_: input Tensor, 2D, batch x input_size.
+ state: state Tensor, 2D, batch x state_size.
+ scope: VariableScope for the created subgraph; defaults to "LSTMCell".
+
+ Returns:
+ A tuple containing:
+ - A 2D, batch x output_dim, Tensor representing the output of the LSTM
+ after reading "input_" when previous state was "state".
+ Here output_dim is:
+ num_proj if num_proj was set,
+ num_units otherwise.
+ - A 2D, batch x state_size, Tensor representing the new state of LSTM
+ after reading "input_" when previous state was "state".
+ """
+ num_proj = self._num_units if self._num_proj is None else self._num_proj
+
+ c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
+ m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])
+
+ dtype = input_.dtype
+
+ unit_shard_size = (4 * self._num_units) / self._num_unit_shards
+
+ with tf.variable_scope(scope or type(self).__name__): # "LSTMCell"
+ w = tf.concat(
+ 1, [tf.get_variable("W_%d" % i,
+ shape=[self.input_size + num_proj,
+ unit_shard_size],
+ initializer=self._initializer,
+ dtype=dtype)
+ for i in range(self._num_unit_shards)])
+
+ b = tf.get_variable(
+ "B", shape=[4 * self._num_units],
+ initializer=tf.zeros_initializer, dtype=dtype)
+
+ # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+ cell_inputs = tf.concat(1, [input_, m_prev])
+ i, j, f, o = tf.split(1, 4, tf.nn.bias_add(tf.matmul(cell_inputs, w), b))
+
+ # Diagonal connections
+ if self._use_peepholes:
+ w_f_diag = tf.get_variable(
+ "W_F_diag", shape=[self._num_units], dtype=dtype)
+ w_i_diag = tf.get_variable(
+ "W_I_diag", shape=[self._num_units], dtype=dtype)
+ w_o_diag = tf.get_variable(
+ "W_O_diag", shape=[self._num_units], dtype=dtype)
+
+ if self._use_peepholes:
+ c = (tf.sigmoid(f + 1 + w_f_diag * c_prev) * c_prev +
+ tf.sigmoid(i + w_i_diag * c_prev) * tf.tanh(j))
+ else:
+ c = (tf.sigmoid(f + 1) * c_prev + tf.sigmoid(i) * tf.tanh(j))
+
+ if self._cell_clip is not None:
+ c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
+
+ if self._use_peepholes:
+ m = tf.sigmoid(o + w_o_diag * c) * tf.tanh(c)
+ else:
+ m = tf.sigmoid(o) * tf.tanh(c)
+
+ if self._num_proj is not None:
+ proj_shard_size = self._num_proj / self._num_proj_shards
+ w_proj = tf.concat(
+ 1, [tf.get_variable("W_P_%d" % i,
+ shape=[self._num_units, proj_shard_size],
+ initializer=self._initializer, dtype=dtype)
+ for i in range(self._num_proj_shards)])
+ # TODO(ebrevdo), use matmulsum
+ m = tf.matmul(m, w_proj)
+
+ return m, tf.concat(1, [c, m])
+
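A construction sketch showing how num_proj changes the output and state shapes; the numbers are illustrative and the era API is assumed.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

cell = rnn_cell.LSTMCell(num_units=512, input_size=256,
                         use_peepholes=True, cell_clip=50.0,
                         num_proj=128, num_unit_shards=2)
# With num_proj set: output_size == 128 and state_size == 512 + 128 == 640.
inputs = tf.placeholder(tf.float32, [8, 256])
state = cell.zero_state(8, tf.float32)
m, new_state = cell(inputs, state)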
+
+class OutputProjectionWrapper(RNNCell):
+ """Operator adding an output projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your outputs in time,
+ do the projection on this batch-concatenated sequence, then split it
+ if needed or directly feed into a softmax.
+ """
+
+ def __init__(self, cell, output_size):
+ """Create a cell with output projection.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ output_size: integer, the size of the output after projection.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if output_size is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if output_size < 1:
+ raise ValueError("Parameter output_size must be > 0: %d." % output_size)
+ self._cell = cell
+ self._output_size = output_size
+
+ @property
+ def input_size(self):
+ return self._cell.input_size
+
+ @property
+ def output_size(self):
+ return self._output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell and output projection on inputs, starting from state."""
+ output, res_state = self._cell(inputs, state)
+ # Default scope: "OutputProjectionWrapper"
+ with tf.variable_scope(scope or type(self).__name__):
+ projected = linear.linear(output, self._output_size, True)
+ return projected, res_state
+
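The note above suggests a more efficient pattern: collect the per-step outputs, concatenate them across time, and project once. A hedged sketch of that pattern using this package's linear helper and the era's variable-reuse idiom; all sizes and the scope name are illustrative.

import tensorflow as tf
from tensorflow.models.rnn import linear, rnn_cell

num_steps, batch_size, num_units, num_classes = 5, 32, 128, 10
cell = rnn_cell.GRUCell(num_units)
inputs = [tf.placeholder(tf.float32, [batch_size, num_units])
          for _ in range(num_steps)]
state = cell.zero_state(batch_size, tf.float32)
outputs = []
with tf.variable_scope("RNN"):
  for t, x in enumerate(inputs):
    if t > 0:
      tf.get_variable_scope().reuse_variables()
    output, state = cell(x, state)
    outputs.append(output)
# One projection over all time steps instead of one per step.
stacked = tf.concat(0, outputs)       # [num_steps * batch_size, num_units]
logits = linear.linear(stacked, num_classes, True)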
+
+class InputProjectionWrapper(RNNCell):
+ """Operator adding an input projection to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the projection on this batch-concatenated sequence, then split it.
+ """
+
+ def __init__(self, cell, input_size):
+ """Create a cell with input projection.
+
+ Args:
+ cell: an RNNCell, a projection of inputs is added before it.
+ input_size: integer, the size of the inputs before projection.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if input_size is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if input_size < 1:
+ raise ValueError("Parameter input_size must be > 0: %d." % input_size)
+ self._cell = cell
+ self._input_size = input_size
+
+ @property
+ def input_size(self):
+ return self._input_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the input projection and then the cell."""
+ # Default scope: "InputProjectionWrapper"
+ with tf.variable_scope(scope or type(self).__name__):
+ projected = linear.linear(inputs, self._cell.input_size, True)
+ return self._cell(projected, state)
+
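A short sketch of wrapping a cell so it can consume inputs of a different dimensionality than the cell expects; the sizes are illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

cell = rnn_cell.InputProjectionWrapper(rnn_cell.GRUCell(128), input_size=50)
x = tf.placeholder(tf.float32, [32, 50])     # raw 50-dimensional inputs
state = cell.zero_state(32, tf.float32)
output, state = cell(x, state)               # projected to 128 dims inside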
+
+class DropoutWrapper(RNNCell):
+ """Operator adding dropout to inputs and outputs of the given cell."""
+
+ def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
+ seed=None):
+ """Create a cell with added input and/or output dropout.
+
+ Dropout is never used on the state.
+
+ Args:
+ cell: an RNNCell, a projection to output_size is added to it.
+ input_keep_prob: unit Tensor or float between 0 and 1, input keep
+ probability; if it is float and 1, no input dropout will be added.
+ output_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is float and 1, no output dropout will be added.
+ seed: (optional) integer, the randomness seed.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if keep_prob is not between 0 and 1.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not a RNNCell.")
+ if (isinstance(input_keep_prob, float) and
+ not (input_keep_prob >= 0.0 and input_keep_prob <= 1.0)):
+ raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
+ % input_keep_prob)
+ if (isinstance(output_keep_prob, float) and
+ not (output_keep_prob >= 0.0 and output_keep_prob <= 1.0)):
+ raise ValueError("Parameter input_keep_prob must be between 0 and 1: %d"
+ % output_keep_prob)
+ self._cell = cell
+ self._input_keep_prob = input_keep_prob
+ self._output_keep_prob = output_keep_prob
+ self._seed = seed
+
+ @property
+ def input_size(self):
+ return self._cell.input_size
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell with the declared dropouts."""
+ if (not isinstance(self._input_keep_prob, float) or
+ self._input_keep_prob < 1):
+ inputs = tf.nn.dropout(inputs, self._input_keep_prob, seed=self._seed)
+ output, new_state = self._cell(inputs, state)
+ if (not isinstance(self._output_keep_prob, float) or
+ self._output_keep_prob < 1):
+ output = tf.nn.dropout(output, self._output_keep_prob, seed=self._seed)
+ return output, new_state
+
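Because the keep probabilities may be unit Tensors, one common pattern is to feed them at run time so the same graph serves both training and evaluation; a hedged sketch, with names and sizes purely illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

keep_prob = tf.placeholder(tf.float32)   # feed 0.5 for training, 1.0 for eval
cell = rnn_cell.DropoutWrapper(rnn_cell.BasicLSTMCell(128),
                               input_keep_prob=keep_prob,
                               output_keep_prob=keep_prob)
x = tf.placeholder(tf.float32, [32, cell.input_size])
state = cell.zero_state(32, tf.float32)
output, state = cell(x, state)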
+
+class EmbeddingWrapper(RNNCell):
+ """Operator adding input embedding to the given cell.
+
+ Note: in many cases it may be more efficient to not use this wrapper,
+ but instead concatenate the whole sequence of your inputs in time,
+ do the embedding on this batch-concatenated sequence, then split it and
+ feed into your RNN.
+ """
+
+ def __init__(self, cell, embedding_classes=0, embedding=None,
+ initializer=None):
+ """Create a cell with an added input embedding.
+
+ Args:
+ cell: an RNNCell, an embedding will be put before its inputs.
+ embedding_classes: integer, how many symbols will be embedded.
+ embedding: Variable, the embedding to use; if None, a new embedding
+ will be created; if set, then embedding_classes is not required.
+ initializer: an initializer to use when creating the embedding;
+ if None, the initializer from variable scope or a default one is used.
+
+ Raises:
+ TypeError: if cell is not an RNNCell.
+ ValueError: if embedding_classes is not positive.
+ """
+ if not isinstance(cell, RNNCell):
+ raise TypeError("The parameter cell is not RNNCell.")
+ if embedding_classes < 1 and embedding is None:
+ raise ValueError("Pass embedding or embedding_classes must be > 0: %d."
+ % embedding_classes)
+ if embedding_classes > 0 and embedding is not None:
+ if embedding.get_shape()[0] != embedding_classes:
+ raise ValueError("You declared embedding_classes=%d but passed an "
+ "embedding for %d classes." % (embedding_classes,
+ embedding.get_shape()[0]))
+ if embedding.get_shape()[1] != cell.input_size:
+ raise ValueError("You passed an embedding with output size %d and a cell"
+ " that accepts size %d." % (embedding.get_shape()[1],
+ cell.input_size))
+ self._cell = cell
+ self._embedding_classes = embedding_classes
+ self._embedding = embedding
+ self._initializer = initializer
+
+ @property
+ def input_size(self):
+ return 1
+
+ @property
+ def output_size(self):
+ return self._cell.output_size
+
+ @property
+ def state_size(self):
+ return self._cell.state_size
+
+ def __call__(self, inputs, state, scope=None):
+ """Run the cell on embedded inputs."""
+ with tf.variable_scope(scope or type(self).__name__): # "EmbeddingWrapper"
+ with tf.device("/cpu:0"):
+ if self._embedding is not None:
+ embedding = self._embedding
+ else:
+ if self._initializer:
+ initializer = self._initializer
+ elif tf.get_variable_scope().initializer:
+ initializer = tf.get_variable_scope().initializer
+ else:
+ # Default initializer for embeddings should have variance=1.
+ sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
+ initializer = tf.random_uniform_initializer(-sqrt3, sqrt3)
+ embedding = tf.get_variable("embedding", [self._embedding_classes,
+ self._cell.input_size],
+ initializer=initializer)
+ embedded = tf.nn.embedding_lookup(embedding, tf.reshape(inputs, [-1]))
+ return self._cell(embedded, state)
+
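With this wrapper the cell consumes integer symbol ids rather than dense vectors; a sketch, with the vocabulary size and shapes purely illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

cell = rnn_cell.EmbeddingWrapper(rnn_cell.GRUCell(128),
                                 embedding_classes=10000)
symbols = tf.placeholder(tf.int32, [32, 1])   # one symbol id per batch element
state = cell.zero_state(32, tf.float32)
output, state = cell(symbols, state)          # embedded to 128 dims internally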
+
+class MultiRNNCell(RNNCell):
+ """RNN cell composed sequentially of multiple simple cells."""
+
+ def __init__(self, cells):
+ """Create a RNN cell composed sequentially of a number of RNNCells.
+
+ Args:
+ cells: list of RNNCells that will be composed in this order.
+
+ Raises:
+ ValueError: if cells is empty (not allowed) or if their sizes don't match.
+ """
+ if not cells:
+ raise ValueError("Must specify at least one cell for MultiRNNCell.")
+ for i in xrange(len(cells) - 1):
+ if cells[i + 1].input_size != cells[i].output_size:
+ raise ValueError("In MultiRNNCell, the input size of each next"
+ " cell must match the output size of the previous one."
+ " Mismatched output size in cell %d." % i)
+ self._cells = cells
+
+ @property
+ def input_size(self):
+ return self._cells[0].input_size
+
+ @property
+ def output_size(self):
+ return self._cells[-1].output_size
+
+ @property
+ def state_size(self):
+ return sum([cell.state_size for cell in self._cells])
+
+ def __call__(self, inputs, state, scope=None):
+ """Run this multi-layer cell on inputs, starting from state."""
+ with tf.variable_scope(scope or type(self).__name__): # "MultiRNNCell"
+ cur_state_pos = 0
+ cur_inp = inputs
+ new_states = []
+ for i, cell in enumerate(self._cells):
+ with tf.variable_scope("Cell%d" % i):
+ cur_state = tf.slice(state, [0, cur_state_pos], [-1, cell.state_size])
+ cur_state_pos += cell.state_size
+ cur_inp, new_state = cell(cur_inp, cur_state)
+ new_states.append(new_state)
+ return cur_inp, tf.concat(1, new_states)
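A stacking sketch: two GRU layers composed into a single cell whose state is the per-layer states concatenated along dimension 1; the sizes are illustrative.

import tensorflow as tf
from tensorflow.models.rnn import rnn_cell

cell = rnn_cell.MultiRNNCell([rnn_cell.GRUCell(128), rnn_cell.GRUCell(128)])
# state_size == 256: the two layer states live side by side in one 2D tensor.
x = tf.placeholder(tf.float32, [32, cell.input_size])
state = cell.zero_state(32, tf.float32)
output, state = cell(x, state)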