aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/cudnn_rnn
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-08-30 12:19:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 12:23:47 -0700
commit8f1746c4b37441e5d6a080cd3c871bde913e9564 (patch)
tree4104315fcd99f0536765d61794a24a1809da2d80 /tensorflow/contrib/cudnn_rnn
parente952f2566662fff98b2c34e22b8c8398f9b7d450 (diff)
Exposing CudnnCompatibleRNN classes.
Also add reuse arg to LSTMBlockCell. PiperOrigin-RevId: 167030950
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r--tensorflow/contrib/cudnn_rnn/__init__.py7
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py5
2 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py
index 470661a9b1..87ba834770 100644
--- a/tensorflow/contrib/cudnn_rnn/__init__.py
+++ b/tensorflow/contrib/cudnn_rnn/__init__.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Ops for fused Cudnn RNN models.
+@@CudnnCompatibleGRUCell
+@@CudnnCompatibleLSTMCell
@@CudnnGRU
@@CudnnLSTM
@@CudnnRNNRelu
@@ -28,6 +30,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
@@ -36,9 +40,12 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable
+
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
+ "CudnnCompatibleGRUCell",
+ "CudnnCompatibleLSTMCell",
"CudnnGRU",
"CudnnLSTM",
"CudnnRNNRelu",
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 7794c371e1..694bd507d9 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -63,9 +63,10 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
this cell seamlessly.
"""
- def __init__(self, num_units):
+ def __init__(self, num_units, reuse=None):
super(CudnnCompatibleLSTMCell, self).__init__(
- num_units, forget_bias=0, clip_cell=False, use_peephole=False)
+ num_units, forget_bias=0, clip_cell=False, use_peephole=False,
+ reuse=reuse)
self._names.update({"scope": "cudnn_compatible_lstm_cell"})