diff options
author | 2017-08-30 12:19:30 -0700 | |
---|---|---|
committer | 2017-08-30 12:23:47 -0700 | |
commit | 8f1746c4b37441e5d6a080cd3c871bde913e9564 (patch) | |
tree | 4104315fcd99f0536765d61794a24a1809da2d80 /tensorflow/contrib/cudnn_rnn | |
parent | e952f2566662fff98b2c34e22b8c8398f9b7d450 (diff) |
Exposing CudnnCompatibleRNN classes.
Also add reuse arg to LSTMBlockCell.
PiperOrigin-RevId: 167030950
Diffstat (limited to 'tensorflow/contrib/cudnn_rnn')
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/__init__.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 5 |
2 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 470661a9b1..87ba834770 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for fused Cudnn RNN models. +@@CudnnCompatibleGRUCell +@@CudnnCompatibleLSTMCell @@CudnnGRU @@CudnnLSTM @@CudnnRNNRelu @@ -28,6 +30,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM @@ -36,9 +40,12 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + "CudnnCompatibleGRUCell", + "CudnnCompatibleLSTMCell", "CudnnGRU", "CudnnLSTM", "CudnnRNNRelu", diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 7794c371e1..694bd507d9 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -63,9 +63,10 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): this cell seamlessly. """ - def __init__(self, num_units): + def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, clip_cell=False, use_peephole=False) + num_units, forget_bias=0, clip_cell=False, use_peephole=False, + reuse=reuse) self._names.update({"scope": "cudnn_compatible_lstm_cell"}) |