diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-02-08 15:01:49 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-08 15:11:31 -0800 |
commit | f98264a1b9916e46a88089b605e962265ecde1a6 (patch) | |
tree | e2964fb10cf26960ceb0487880f3f6d245855d0d /tensorflow/contrib/rnn | |
parent | 3d432ddefe43a6527cef1ffdcdb785a4f7db4a10 (diff) |
[TF contrib RNN] Expose some rnn classes and functionality in contrib.
PiperOrigin-RevId: 185057994
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r-- | tensorflow/contrib/rnn/__init__.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/gru_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/lstm_ops.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 2 |
4 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index c568c6760f..67f31785b5 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -18,6 +18,7 @@ See @{$python/contrib.rnn} guide. <!--From core--> @@RNNCell +@@LayerRNNCell @@BasicRNNCell @@BasicLSTMCell @@GRUCell @@ -68,6 +69,10 @@ See @{$python/contrib.rnn} guide. @@static_bidirectional_rnn @@stack_bidirectional_dynamic_rnn @@stack_bidirectional_rnn + +<!--RNN utilities--> +@@transpose_batch_time +@@best_effort_input_batch_size """ from __future__ import absolute_import @@ -85,6 +90,8 @@ from tensorflow.contrib.rnn.python.ops.lstm_ops import * from tensorflow.contrib.rnn.python.ops.rnn import * from tensorflow.contrib.rnn.python.ops.rnn_cell import * +from tensorflow.python.ops.rnn import _best_effort_input_batch_size as best_effort_input_batch_size +from tensorflow.python.ops.rnn import _transpose_batch_time as transpose_batch_time from tensorflow.python.ops.rnn import static_bidirectional_rnn from tensorflow.python.ops.rnn import static_rnn from tensorflow.python.ops.rnn import static_state_saving_rnn diff --git a/tensorflow/contrib/rnn/python/ops/gru_ops.py b/tensorflow/contrib/rnn/python/ops/gru_ops.py index 4c964ec201..81ca12317b 100644 --- a/tensorflow/contrib/rnn/python/ops/gru_ops.py +++ b/tensorflow/contrib/rnn/python/ops/gru_ops.py @@ -32,7 +32,7 @@ from tensorflow.python.util.deprecation import deprecated_args _gru_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_gru_ops.so")) -LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access +LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name @ops.RegisterGradient("GRUBlockCell") diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index 04f342cd18..f700717394 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -34,7 +34,7 @@ from tensorflow.python.platform import resource_loader _lstm_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_lstm_ops.so")) -LayerRNNCell = rnn_cell_impl._LayerRNNCell # pylint: disable=invalid-name,protected-access +LayerRNNCell = rnn_cell_impl.LayerRNNCell # pylint: disable=invalid-name # pylint: disable=invalid-name diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index fe07493d0f..dce71c393a 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -2682,7 +2682,7 @@ class LayerNormLSTMCell(rnn_cell_impl.RNNCell): return m, new_state -class SRUCell(rnn_cell_impl._LayerRNNCell): +class SRUCell(rnn_cell_impl.LayerRNNCell): """SRU, Simple Recurrent Unit Implementation based on |