aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2018-02-08 15:01:49 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-08 15:11:31 -0800
commitf98264a1b9916e46a88089b605e962265ecde1a6 (patch)
treee2964fb10cf26960ceb0487880f3f6d245855d0d /tensorflow/contrib/rnn
parent3d432ddefe43a6527cef1ffdcdb785a4f7db4a10 (diff)
[TF contrib RNN] Expose some rnn classes and functionality in contrib.
PiperOrigin-RevId: 185057994
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/__init__.py7
-rw-r--r--tensorflow/contrib/rnn/python/ops/gru_ops.py2
-rw-r--r--tensorflow/contrib/rnn/python/ops/lstm_ops.py2
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py2
4 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index c568c6760f..67f31785b5 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -18,6 +18,7 @@ See @{$python/contrib.rnn} guide.
<!--From core-->
@@RNNCell
+@@LayerRNNCell
@@BasicRNNCell
@@BasicLSTMCell
@@GRUCell
@@ -68,6 +69,10 @@ See @{$python/contrib.rnn} guide.
@@static_bidirectional_rnn
@@stack_bidirectional_dynamic_rnn
@@stack_bidirectional_rnn
+
+<!--RNN utilities-->
+@@transpose_batch_time
+@@best_effort_input_batch_size
"""
from __future__ import absolute_import
@@ -85,6 +90,8 @@ from tensorflow.contrib.rnn.python.ops.lstm_ops import *
from tensorflow.contrib.rnn.python.ops.rnn import *
from tensorflow.contrib.rnn.python.ops.rnn_cell import *
+from tensorflow.python.ops.rnn import _best_effort_input_batch_size as best_effort_input_batch_size
+from tensorflow.python.ops.rnn import _transpose_batch_time as transpose_batch_time
from tensorflow.python.ops.rnn import static_bidirectional_rnn
from tensorflow.python.ops.rnn import static_rnn
from tensorflow.python.ops.rnn import static_state_saving_rnn
diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py
index 4c964ec201..81ca12317b 100644
--- a/tensorflow/contrib/rnn/python/ops/gru_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py
@@ -32,7 +32,7 @@ from tensorflow.python.util.deprecation import deprecated_args
_gru_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_gru_ops.so"))
-LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access
+LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name
@ops.RegisterGradient("GRUBlockCell")
diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
index 04f342cd18..f700717394 100644
--- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py
+++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py
@@ -34,7 +34,7 @@ from tensorflow.python.platform import resource_loader
_lstm_ops_so = loader.load_op_library(
resource_loader.get_path_to_datafile("_lstm_ops.so"))
-LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access
+LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name
# pylint: disable=invalid-name
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index fe07493d0f..dce71c393a 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2682,7 +2682,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell):
return m, new_state
-class SRUCell(rnn_cell_impl._LayerRNNCell):
+class SRUCell(rnn_cell_impl.LayerRNNCell):
"""SRU, Simple Recurrent Unit
Implementation based on