author    A. Unique TensorFlower <gardener@tensorflow.org> 2017-06-24 10:36:24 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-06-24 10:39:58 -0700
commit    efb91f85b8d6554fd0906ab397e6a9abe9bdca41 (patch)
tree      957700b0f4596e35021788b5766639811733c872
parent    337b7c8c6262bcc2ef00f5a636d7020e632aa32d (diff)
Made TensorFlow documentation on LSTMs slightly more accurate.
PiperOrigin-RevId: 160047054
-rw-r--r-- tensorflow/docs_src/tutorials/recurrent.md |  4
-rw-r--r-- tensorflow/python/ops/rnn_cell_impl.py     | 15
2 files changed, 17 insertions(+), 2 deletions(-)
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md
index 708a9620dd..346b6be06c 100644
--- a/tensorflow/docs_src/tutorials/recurrent.md
+++ b/tensorflow/docs_src/tutorials/recurrent.md
@@ -75,7 +75,9 @@ The basic pseudocode is as follows:
words_in_dataset = tf.placeholder(tf.float32, [num_batches, batch_size, num_features])
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
-state = tf.zeros([batch_size, lstm.state_size])
+hidden_state = tf.zeros([batch_size, lstm.state_size])
+current_state = tf.zeros([batch_size, lstm.state_size])
+state = hidden_state, current_state
probabilities = []
loss = 0.0
for current_batch_of_words in words_in_dataset:
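
For context, a minimal sketch of how the corrected pseudocode would continue inside the loop, now that state is a (hidden_state, current_state) pair. The loop body is not part of this hunk; softmax_w, softmax_b, target_words, and loss_function are illustrative names, not from this patch.

# Hedged sketch of the loop body that follows the hunk above.
for current_batch_of_words in words_in_dataset:
    # state is updated after each batch; with the fix above it is a
    # (hidden_state, current_state) pair rather than a single tensor.
    output, state = lstm(current_batch_of_words, state)

    # The LSTM output can be used to make next-word predictions.
    logits = tf.matmul(output, softmax_w) + softmax_b
    probabilities.append(tf.nn.softmax(logits))
    loss += loss_function(probabilities, target_words)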
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index ca69cddae2..c0d9c971a0 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -374,7 +374,20 @@ class BasicLSTMCell(RNNCell):
return self._num_units
def call(self, inputs, state):
- """Long short-term memory cell (LSTM)."""
+ """Long short-term memory cell (LSTM).
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size x input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
+ `True`. Otherwise, a `Tensor` shaped
+ `[batch_size x 2 * self.state_size]`.
+
+ Returns:
+ A pair containing the new hidden state, and the new state (either a
+ `LSTMStateTuple` or a concatenated state, depending on
+ `state_is_tuple`).
+ """
sigmoid = math_ops.sigmoid
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
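
A hedged usage sketch of the call contract documented above, under TF 1.x where tf.contrib.rnn.BasicLSTMCell is the public alias for this class; batch_size, input_size, and num_units are illustrative values, not part of this patch.

import tensorflow as tf

batch_size, input_size, num_units = 32, 128, 256
cell = tf.contrib.rnn.BasicLSTMCell(num_units, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [batch_size, input_size])
# With state_is_tuple=True, the state is an LSTMStateTuple (c, h),
# each component shaped [batch_size, num_units].
state = cell.zero_state(batch_size, tf.float32)

# Invoking the cell returns the new hidden state and the new
# LSTMStateTuple, as described in the docstring added above.
output, new_state = cell(inputs, state)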