diff options
-rw-r--r-- | tensorflow/docs_src/tutorials/recurrent.md | 4 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 15 |
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md index 708a9620dd..346b6be06c 100644 --- a/tensorflow/docs_src/tutorials/recurrent.md +++ b/tensorflow/docs_src/tutorials/recurrent.md @@ -75,7 +75,9 @@ The basic pseudocode is as follows: words_in_dataset = tf.placeholder(tf.float32, [num_batches, batch_size, num_features]) lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size) # Initial state of the LSTM memory. -state = tf.zeros([batch_size, lstm.state_size]) +hidden_state = tf.zeros([batch_size, lstm.state_size]) +current_state = tf.zeros([batch_size, lstm.state_size]) +state = hidden_state, current_state probabilities = [] loss = 0.0 for current_batch_of_words in words_in_dataset: diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index ca69cddae2..c0d9c971a0 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -374,7 +374,20 @@ class BasicLSTMCell(RNNCell): return self._num_units def call(self, inputs, state): - """Long short-term memory cell (LSTM).""" + """Long short-term memory cell (LSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size x input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped + `[batch_size x self.state_size]`, if `state_is_tuple` has been set to + `True`. Otherwise, a `Tensor` shaped + `[batch_size x 2 * self.state_size]`. + + Returns: + A pair containing the new hidden state, and the new state (either a + `LSTMStateTuple` or a concatenated state, depending on + `state_is_tuple`). + """ sigmoid = math_ops.sigmoid # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: |