aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-04-10 18:44:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 18:46:38 -0700
commit5ad9e4588874f30d0d079acc60e07f2eddc0480f (patch)
treeab800846cc505d867b2961578869aec97eeb81a3 /tensorflow/contrib/cudnn_rnn
parentfad74785d12ea7463e5d0474522cd7d754699656 (diff)
Merge changes from github.
PiperOrigin-RevId: 192388250
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py16
1 files changed, 10 insertions, 6 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 1dd490b386..c28c3a18e4 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -88,19 +88,23 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell):
Cudnn compatible GRU (from Cudnn library user guide):
```python
- r_t = sigma(x_t * W_r + h_t-1 * R_h + b_Wr + b_Rr) # reset gate
- u_t = sigma(x_t * W_u + h_t-1 * R_u + b_Wu + b_Ru) # update gate
- h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_Rh) + b_Wh) # new memory gate
- h_t = (1 - u_t) .* h'_t + u_t .* h_t-1
+ # reset gate
+ $$r_t = \sigma(x_t * W_r + h_t-1 * R_h + b_{Wr} + b_{Rr})$$
+ # update gate
+ $$u_t = \sigma(x_t * W_u + h_t-1 * R_u + b_{Wu} + b_{Ru})$$
+ # new memory gate
+ $$h'_t = tanh(x_t * W_h + r_t .* (h_t-1 * R_h + b_{Rh}) + b_{Wh})$$
+ $$h_t = (1 - u_t) .* h'_t + u_t .* h_t-1$$
```
Other GRU (see @{tf.nn.rnn_cell.GRUCell} and @{tf.contrib.rnn.GRUBlockCell}):
```python
- h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_Wh) # new memory gate
+ # new memory gate
+ \\(h'_t = tanh(x_t * W_h + (r_t .* h_t-1) * R_h + b_{Wh})\\)
```
which is not equivalent to Cudnn GRU: in addition to the extra bias term b_Rh,
```python
- r .* (h * R) != (r .* h) * R
+ \\(r .* (h * R) != (r .* h) * R\\)
```
"""