diff options
author | James Qin <jamesqin@google.com> | 2017-08-30 12:19:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-08-30 12:23:47 -0700 |
commit | 8f1746c4b37441e5d6a080cd3c871bde913e9564 (patch) | |
tree | 4104315fcd99f0536765d61794a24a1809da2d80 | |
parent | e952f2566662fff98b2c34e22b8c8398f9b7d450 (diff) |
Exposing CudnnCompatibleRNN classes.
Also add reuse arg to LSTMBlockCell.
PiperOrigin-RevId: 167030950
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/__init__.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py | 5 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 7 |
3 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 470661a9b1..87ba834770 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """Ops for fused Cudnn RNN models. +@@CudnnCompatibleGRUCell +@@CudnnCompatibleLSTMCell @@CudnnGRU @@CudnnLSTM @@CudnnRNNRelu @@ -28,6 +30,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell +from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM @@ -36,9 +40,12 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + "CudnnCompatibleGRUCell", + "CudnnCompatibleLSTMCell", "CudnnGRU", "CudnnLSTM", "CudnnRNNRelu", diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index 7794c371e1..694bd507d9 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -63,9 +63,10 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): this cell seamlessly. """ - def __init__(self, num_units): + def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, clip_cell=False, use_peephole=False) + num_units, forget_bias=0, clip_cell=False, use_peephole=False, + reuse=reuse) self._names.update({"scope": "cudnn_compatible_lstm_cell"}) diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 48c2c5a724..f591f7c84e 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -342,7 +342,8 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): num_units, forget_bias=1.0, clip_cell=True, - use_peephole=False): + use_peephole=False, + reuse=None): """Initialize the basic LSTM cell. Args: @@ -351,10 +352,14 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): clip_cell: boolean, whether to apply cell clipping. See `_lstm_block_cell()` for details. use_peephole: Whether to use peephole connections or not. + reuse: (optional) boolean describing whether to reuse variables in an + existing scope. If not `True`, and the existing scope already has the + given variables, an error is raised. When restoring from CudnnLSTM-trained checkpoints, must use CudnnCompatibleLSTMBlockCell instead. """ + super(LSTMBlockCell, self).__init__(_reuse=reuse) self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole |