aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar James Qin <jamesqin@google.com>2017-08-30 12:19:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-08-30 12:23:47 -0700
commit8f1746c4b37441e5d6a080cd3c871bde913e9564 (patch)
tree4104315fcd99f0536765d61794a24a1809da2d80
parente952f2566662fff98b2c34e22b8c8398f9b7d450 (diff)
Exposing CudnnCompatibleRNN classes.
Also add reuse arg to LSTMBlockCell. PiperOrigin-RevId: 167030950
-rw-r--r--tensorflow/contrib/cudnn_rnn/__init__.py7
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py5
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py7
3 files changed, 16 insertions, 3 deletions
diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py
index 470661a9b1..87ba834770 100644
--- a/tensorflow/contrib/cudnn_rnn/__init__.py
+++ b/tensorflow/contrib/cudnn_rnn/__init__.py
@@ -14,6 +14,8 @@
# ==============================================================================
"""Ops for fused Cudnn RNN models.
+@@CudnnCompatibleGRUCell
+@@CudnnCompatibleLSTMCell
@@CudnnGRU
@@CudnnLSTM
@@CudnnRNNRelu
@@ -28,6 +30,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell
+from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM
@@ -36,9 +40,12 @@ from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable
from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable
+
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
+ "CudnnCompatibleGRUCell",
+ "CudnnCompatibleLSTMCell",
"CudnnGRU",
"CudnnLSTM",
"CudnnRNNRelu",
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index 7794c371e1..694bd507d9 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -63,9 +63,10 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell):
this cell seamlessly.
"""
- def __init__(self, num_units):
+ def __init__(self, num_units, reuse=None):
super(CudnnCompatibleLSTMCell, self).__init__(
- num_units, forget_bias=0, clip_cell=False, use_peephole=False)
+ num_units, forget_bias=0, clip_cell=False, use_peephole=False,
+ reuse=reuse)
self._names.update({"scope": "cudnn_compatible_lstm_cell"})
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 48c2c5a724..f591f7c84e 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -342,7 +342,8 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
num_units,
forget_bias=1.0,
clip_cell=True,
- use_peephole=False):
+ use_peephole=False,
+ reuse=None):
"""Initialize the basic LSTM cell.
Args:
@@ -351,10 +352,14 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell):
clip_cell: boolean, whether to apply cell clipping. See
`_lstm_block_cell()` for details.
use_peephole: Whether to use peephole connections or not.
+ reuse: (optional) boolean describing whether to reuse variables in an
+ existing scope. If not `True`, and the existing scope already has the
+ given variables, an error is raised.
When restoring from CudnnLSTM-trained checkpoints, must use
CudnnCompatibleLSTMBlockCell instead.
"""
+ super(LSTMBlockCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._forget_bias = forget_bias
self._use_peephole = use_peephole