diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py | 15 |
1 file changed, 11 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
index 755ebd048b..f44302638e 100644
--- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
+++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell_impl.py
@@ -13,7 +13,14 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Module implementing RNN Cells."""
+"""Module implementing RNN Cells.
+
+This module provides a number of basic commonly used RNN cells, such as LSTM
+(Long Short Term Memory) or GRU (Gated Recurrent Unit), and a number of
+operators that allow adding dropouts, projections, or embeddings for inputs.
+Constructing multi-layer cells is supported by the class `MultiRNNCell`, or by
+calling the `rnn` ops several times.
+"""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -146,12 +153,12 @@ class GRUCell(RNNCell):
     with _checked_scope(self, scope or "gru_cell", reuse=self._reuse):
       with vs.variable_scope("gates"):  # Reset gate and update gate.
         # We start with bias of 1.0 to not reset and not update.
+        value = sigmoid(_linear(
+          [inputs, state], 2 * self._num_units, True, 1.0))
         r, u = array_ops.split(
-            value=_linear(
-              [inputs, state], 2 * self._num_units, True, 1.0),
+            value=value,
             num_or_size_splits=2,
             axis=1)
-        r, u = sigmoid(r), sigmoid(u)
       with vs.variable_scope("candidate"):
         c = self._activation(_linear([inputs, r * state],
                                      self._num_units, True))