diff options
Diffstat (limited to 'tensorflow/models/rnn/rnn.py')
-rw-r--r-- | tensorflow/models/rnn/rnn.py | 128 |
1 files changed, 128 insertions, 0 deletions
# Originally added as tensorflow/models/rnn/rnn.py
# (new file, mode 100644, blob 24582bcae7).
"""RNN helpers for TensorFlow models."""

import tensorflow as tf

from tensorflow.models.rnn import rnn_cell
from tensorflow.python.ops import control_flow_ops


def rnn(cell, inputs, initial_state=None, dtype=None,
        sequence_length=None, scope=None):
    """Creates a recurrent neural network specified by RNNCell "cell".

    The simplest form of RNN network generated is:
      state = cell.zero_state(...)
      outputs = []
      states = []
      for input_ in inputs:
        output, state = cell(input_, state)
        outputs.append(output)
        states.append(state)
      return (outputs, states)

    However, a few other options are available:

    An initial state can be provided.
    If sequence_length is provided, dynamic calculation is performed.

    Dynamic calculation returns, at time t:
      (t >= max(sequence_length)
          ? (zeros(output_shape), zeros(state_shape))
          : cell(input, state))

    Thus saving computational time when unrolling past the max sequence
    length.

    Args:
      cell: An instance of RNNCell.
      inputs: A non-empty length T list of input tensors. The first
        dimension of each is taken as batch_size (presumably the full shape
        is [batch_size, cell.input_size] -- confirm against the cell used).
      initial_state: (optional) An initial state for the RNN. This must be
        a tensor of appropriate type and shape
        [batch_size x cell.state_size].
      dtype: (optional) The data type for the initial state. Required if
        initial_state is not provided.
      sequence_length: (optional) An int64 vector (tensor) size
        [batch_size]. Dynamic calculation is enabled whenever this is not
        None.
      scope: VariableScope for the created subgraph; defaults to "RNN".

    Returns:
      A pair (outputs, states) where:
        outputs is a length T list of outputs (one for each input)
        states is a length T list of states (one state following each input)

    Raises:
      TypeError: If "cell" is not an instance of RNNCell, or if "inputs" is
        not a list (this includes None).
      ValueError: If "inputs" is an empty list, or if neither
        "initial_state" nor "dtype" is provided.
    """
    if not isinstance(cell, rnn_cell.RNNCell):
        raise TypeError("cell must be an instance of RNNCell")
    if not isinstance(inputs, list):
        raise TypeError("inputs must be a list")
    if not inputs:
        raise ValueError("inputs must not be empty")

    # Test against None explicitly: sequence_length is a tensor, and taking
    # the truth value of a tensor (or of a multi-element array) is either
    # ambiguous or an error, so `if sequence_length:` is not a safe check.
    dynamic = sequence_length is not None

    outputs = []
    states = []
    with tf.variable_scope(scope or "RNN"):
        batch_size = tf.shape(inputs[0])[0]
        if initial_state is not None:
            state = initial_state
        else:
            if not dtype:
                raise ValueError(
                    "If no initial_state is provided, dtype must be.")
            state = cell.zero_state(batch_size, dtype)

        if dynamic:  # Prepare variables
            # Values emitted for time steps at or beyond the longest
            # sequence in the batch.
            zero_output_state = (
                tf.zeros(tf.pack([batch_size, cell.output_size]),
                         inputs[0].dtype),
                tf.zeros(tf.pack([batch_size, cell.state_size]),
                         state.dtype))
            max_sequence_length = tf.reduce_max(sequence_length)

        for time, input_ in enumerate(inputs):
            if time > 0:
                # Every step after the first shares the cell's variables.
                tf.get_variable_scope().reuse_variables()
            output_state = cell(input_, state)
            if dynamic:
                (output, state) = control_flow_ops.cond(
                    time >= max_sequence_length,
                    lambda: zero_output_state, lambda: output_state)
            else:
                (output, state) = output_state

            outputs.append(output)
            states.append(state)

    return (outputs, states)


def state_saving_rnn(cell, inputs, state_saver, state_name,
                     sequence_length=None, scope=None):
    """RNN that accepts a state saver for time-truncated RNN calculation.

    The initial state is read from the state saver, the network is built
    with rnn(), and the final state is written back to the state saver.

    Args:
      cell: An instance of RNNCell.
      inputs: A non-empty length T list of input tensors; see rnn().
      state_saver: A StateSaver object. Its State(state_name) result is used
        as the initial state and SaveState(state_name, state) is called with
        the final state.
      state_name: The name to use with the state_saver.
      sequence_length: (optional) An int64 vector (tensor) size [batch_size].
        See the documentation for rnn() for more details about
        sequence_length.
      scope: VariableScope for the created subgraph; defaults to "RNN".

    Returns:
      A pair (outputs, states) where:
        outputs is a length T list of outputs (one for each input)
        states is a length T list of states (one state following each input)

    Raises:
      TypeError: If "cell" is not an instance of RNNCell, or if "inputs" is
        not a list (raised by rnn()).
      ValueError: If "inputs" is an empty list (raised by rnn()).
    """
    initial_state = state_saver.State(state_name)
    (outputs, states) = rnn(cell, inputs, initial_state=initial_state,
                            sequence_length=sequence_length, scope=scope)
    save_state = state_saver.SaveState(state_name, states[-1])
    # Make the last output depend on the save op so that evaluating it also
    # stores the final state.
    with tf.control_dependencies([save_state]):
        outputs[-1] = tf.identity(outputs[-1])

    return (outputs, states)