aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/docs_src/tutorials/recurrent.md4
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py15
2 files changed, 17 insertions, 2 deletions
diff --git a/tensorflow/docs_src/tutorials/recurrent.md b/tensorflow/docs_src/tutorials/recurrent.md
index 708a9620dd..346b6be06c 100644
--- a/tensorflow/docs_src/tutorials/recurrent.md
+++ b/tensorflow/docs_src/tutorials/recurrent.md
@@ -75,7 +75,9 @@ The basic pseudocode is as follows:
words_in_dataset = tf.placeholder(tf.float32, [num_batches, batch_size, num_features])
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
# Initial state of the LSTM memory.
-state = tf.zeros([batch_size, lstm.state_size])
+hidden_state = tf.zeros([batch_size, lstm.state_size])
+current_state = tf.zeros([batch_size, lstm.state_size])
+state = hidden_state, current_state
probabilities = []
loss = 0.0
for current_batch_of_words in words_in_dataset:
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index ca69cddae2..c0d9c971a0 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -374,7 +374,20 @@ class BasicLSTMCell(RNNCell):
return self._num_units
def call(self, inputs, state):
- """Long short-term memory cell (LSTM)."""
+ """Long short-term memory cell (LSTM).
+
+ Args:
+ inputs: `2-D` tensor with shape `[batch_size x input_size]`.
+ state: An `LSTMStateTuple` of state tensors, each shaped
+ `[batch_size x self.state_size]`, if `state_is_tuple` has been set to
+ `True`. Otherwise, a `Tensor` shaped
+ `[batch_size x 2 * self.state_size]`.
+
+ Returns:
+ A pair containing the new hidden state, and the new state (either a
+ `LSTMStateTuple` or a concatenated state, depending on
+ `state_is_tuple`).
+ """
sigmoid = math_ops.sigmoid
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple: