diff options
author | 2017-06-24 10:36:24 -0700 | |
---|---|---|
committer | 2017-06-24 10:39:58 -0700 | |
commit | efb91f85b8d6554fd0906ab397e6a9abe9bdca41 (patch) | |
tree | 957700b0f4596e35021788b5766639811733c872 | |
parent | 337b7c8c6262bcc2ef00f5a636d7020e632aa32d (diff) |
Made TensorFlow documentation on LSTMs slightly more accurate.
PiperOrigin-RevId: 160047054
-rw-r--r-- | tensorflow/docs_src/tutorials/recurrent.md | 4 | ||||
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 15 |
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md index 708a9620dd..346b6be06c 100644 --- a/tensorflow/docs_src/tutorials/recurrent.md +++ b/tensorflow/docs_src/tutorials/recurrent.md @@ -75,7 +75,9 @@ The basic pseudocode is as follows: words_in_dataset = tf.placeholder(tf.float32, [num_batches, batch_size, num_features]) lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size) # Initial state of the LSTM memory. -state = tf.zeros([batch_size, lstm.state_size]) +hidden_state = tf.zeros([batch_size, lstm.state_size]) +current_state = tf.zeros([batch_size, lstm.state_size]) +state = hidden_state, current_state probabilities = [] loss = 0.0 for current_batch_of_words in words_in_dataset: diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index ca69cddae2..c0d9c971a0 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -374,7 +374,20 @@ class BasicLSTMCell(RNNCell): return self._num_units def call(self, inputs, state): - """Long short-term memory cell (LSTM).""" + """Long short-term memory cell (LSTM). + + Args: + inputs: `2-D` tensor with shape `[batch_size x input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped + `[batch_size x self.state_size]`, if `state_is_tuple` has been set to + `True`. Otherwise, a `Tensor` shaped + `[batch_size x 2 * self.state_size]`. + + Returns: + A pair containing the new hidden state, and the new state (either a + `LSTMStateTuple` or a concatenated state, depending on + `state_is_tuple`). + """ sigmoid = math_ops.sigmoid # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: |