aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-08-07 08:50:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-07 08:54:19 -0700
commit7f666bb652063874134ed60b77edb4ddc85ec488 (patch)
tree5a48a33605fa1f8a2a69d57bb3585f1153ee8f68 /tensorflow/python/ops/rnn_cell_impl.py
parent335336aa2cdf853d380c3e22ab6694ff78cb487a (diff)
Deprecate BasicLSTMCell and push user to use LSTMCell instead.
PiperOrigin-RevId: 207722104
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 42806ba6ec..8356fbbb9d 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -48,6 +48,7 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
+from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
@@ -515,9 +516,12 @@ class LSTMStateTuple(_LSTMStateTuple):
return c.dtype
+# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """Basic LSTM recurrent network cell.
+ """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+
+ Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
@@ -531,6 +535,10 @@ class BasicLSTMCell(LayerRNNCell):
that follows.
"""
+ @deprecated(None, "This class is deprecated, please use "
+ "tf.nn.rnn_cell.LSTMCell, which supports all the feature "
+ "this cell currently has. Please replace the existing code "
+ "with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').")
def __init__(self,
num_units,
forget_bias=1.0,