Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
 tensorflow/python/ops/rnn_cell_impl.py | 150 ++++++++++++++++++++++------
 1 file changed, 122 insertions(+), 28 deletions(-)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 42806ba6ec..f481726d54 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -34,6 +34,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -48,6 +51,7 @@ from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import nest
+from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
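For orientation, roughly how the newly imported Keras helpers behave (a sketch, assuming TF 1.x Keras semantics; note the asymmetric handling of `None`, which is why the cells below still guard with `if activation:`):

    from tensorflow.python.keras import activations, initializers
    from tensorflow.python.ops import math_ops

    activations.get("relu")         # a string resolves to the Keras relu function
    activations.get(math_ops.tanh)  # a callable passes through unchanged
    activations.get(None)           # returns the *linear* activation, not None
    initializers.get("orthogonal")  # a string resolves to an Orthogonal instance
    initializers.get(None)          # returns None, safe to store as-is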
@@ -335,7 +339,8 @@ class BasicRNNCell(LayerRNNCell):
Args:
num_units: int, The number of units in the RNN cell.
- activation: Nonlinearity to use. Default: `tanh`.
+ activation: Nonlinearity to use. Default: `tanh`. It can also be a string
+ that names a Keras activation function.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
@@ -344,6 +349,8 @@ class BasicRNNCell(LayerRNNCell):
cases.
dtype: Default dtype of the layer (default of `None` means use the type
of the first input). Required when `build` is called before `call`.
+ **kwargs: Dict, keyword named properties for common layer attributes, like
+ `trainable`, used when constructing the cell from the configs of get_config().
"""
def __init__(self,
@@ -351,14 +358,19 @@ class BasicRNNCell(LayerRNNCell):
activation=None,
reuse=None,
name=None,
- dtype=None):
- super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+ dtype=None,
+ **kwargs):
+ super(BasicRNNCell, self).__init__(
+ _reuse=reuse, name=name, dtype=dtype, **kwargs)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
- self._activation = activation or math_ops.tanh
+ if activation:
+ self._activation = activations.get(activation)
+ else:
+ self._activation = math_ops.tanh
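With this change a cell accepts a Keras activation name at construction; a minimal usage sketch (TF 1.x graph mode, hedged):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.BasicRNNCell(num_units=4, activation="relu")
    inputs = tf.zeros([2, 3])                                # batch 2, input depth 3
    state = cell.zero_state(batch_size=2, dtype=tf.float32)
    output, new_state = cell(inputs, state)                  # both shaped [2, 4]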
@property
def state_size(self):
@@ -368,12 +380,13 @@ class BasicRNNCell(LayerRNNCell):
def output_size(self):
return self._num_units
+ @tf_utils.shape_type_conversion
def build(self, inputs_shape):
- if inputs_shape[1].value is None:
+ if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
- input_depth = inputs_shape[1].value
+ input_depth = inputs_shape[-1]
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, self._num_units])
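The `shape_type_conversion` decorator hands `build` a plain tuple of ints (or `None`) instead of a `TensorShape`, so `[-1]` indexing replaces the old `.value` access; roughly:

    import tensorflow as tf

    ts = tf.TensorShape([None, 128])
    ts[1].value               # old style: read the Dimension's value -> 128
    tuple(ts.as_list())[-1]   # what the decorated build() now sees   -> 128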
@@ -393,6 +406,15 @@ class BasicRNNCell(LayerRNNCell):
output = self._activation(gate_inputs)
return output, output
+ def get_config(self):
+ config = {
+ "num_units": self._num_units,
+ "activation": activations.serialize(self._activation),
+ "reuse": self._reuse,
+ }
+ base_config = super(BasicRNNCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
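The new `get_config` lets a cell be serialized and rebuilt; a hedged round-trip sketch (relies on the `**kwargs` forwarding added above to re-accept base-layer keys such as `name` and `trainable`):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.BasicRNNCell(8, activation="tanh")
    config = cell.get_config()   # num_units, activation, reuse + base-layer keys
    clone = tf.nn.rnn_cell.BasicRNNCell(**config)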
@tf_export("nn.rnn_cell.GRUCell")
class GRUCell(LayerRNNCell):
@@ -412,6 +434,8 @@ class GRUCell(LayerRNNCell):
cases.
dtype: Default dtype of the layer (default of `None` means use the type
of the first input). Required when `build` is called before `call`.
+ **kwargs: Dict, keyword named properties for common layer attributes, like
+ `trainable`, used when constructing the cell from the configs of get_config().
"""
def __init__(self,
@@ -421,16 +445,21 @@ class GRUCell(LayerRNNCell):
kernel_initializer=None,
bias_initializer=None,
name=None,
- dtype=None):
- super(GRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+ dtype=None,
+ **kwargs):
+ super(GRUCell, self).__init__(
+ _reuse=reuse, name=name, dtype=dtype, **kwargs)
# Inputs must be 2-dimensional.
self.input_spec = base_layer.InputSpec(ndim=2)
self._num_units = num_units
- self._activation = activation or math_ops.tanh
- self._kernel_initializer = kernel_initializer
- self._bias_initializer = bias_initializer
+ if activation:
+ self._activation = activations.get(activation)
+ else:
+ self._activation = math_ops.tanh
+ self._kernel_initializer = initializers.get(kernel_initializer)
+ self._bias_initializer = initializers.get(bias_initializer)
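Initializers likewise accept Keras names now; a hedged construction sketch:

    import tensorflow as tf

    cell = tf.nn.rnn_cell.GRUCell(
        16,
        activation="tanh",
        kernel_initializer="glorot_uniform",  # resolved via initializers.get
        bias_initializer="ones")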
@property
def state_size(self):
@@ -440,12 +469,13 @@ class GRUCell(LayerRNNCell):
def output_size(self):
return self._num_units
+ @tf_utils.shape_type_conversion
def build(self, inputs_shape):
- if inputs_shape[1].value is None:
+ if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
- input_depth = inputs_shape[1].value
+ input_depth = inputs_shape[-1]
self._gate_kernel = self.add_variable(
"gates/%s" % _WEIGHTS_VARIABLE_NAME,
shape=[input_depth + self._num_units, 2 * self._num_units],
@@ -491,6 +521,17 @@ class GRUCell(LayerRNNCell):
new_h = u * state + (1 - u) * c
return new_h, new_h
+ def get_config(self):
+ config = {
+ "num_units": self._num_units,
+ "initializer": initializers.serialize(self._initializer),
+ "kernel_initializer": initializers.serialize(self._kernel_initializer),
+ "activation": activations.serialize(self._activation),
+ "reuse": self._reuse,
+ }
+ base_config = super(GRUCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
_LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
@@ -515,9 +556,12 @@ class LSTMStateTuple(_LSTMStateTuple):
return c.dtype
+# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """Basic LSTM recurrent network cell.
+ """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+
+ Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
@@ -527,10 +571,14 @@ class BasicLSTMCell(LayerRNNCell):
It does not allow cell clipping, a projection layer, and does not
use peep-hole connections: it is the basic baseline.
- For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
+ For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell`
that follows.
"""
+ @deprecated(None, "This class is deprecated, please use "
+ "tf.nn.rnn_cell.LSTMCell, which supports all the feature "
+ "this cell currently has. Please replace the existing code "
+ "with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').")
def __init__(self,
num_units,
forget_bias=1.0,
@@ -538,7 +586,8 @@ class BasicLSTMCell(LayerRNNCell):
activation=None,
reuse=None,
name=None,
- dtype=None):
+ dtype=None,
+ **kwargs):
"""Initialize the basic LSTM cell.
Args:
@@ -549,7 +598,8 @@ class BasicLSTMCell(LayerRNNCell):
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. The latter behavior will soon be deprecated.
- activation: Activation function of the inner states. Default: `tanh`.
+ activation: Activation function of the inner states. Default: `tanh`. It
+ can also be a string that names a Keras activation function.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
@@ -558,11 +608,14 @@ class BasicLSTMCell(LayerRNNCell):
cases.
dtype: Default dtype of the layer (default of `None` means use the type
of the first input). Required when `build` is called before `call`.
+ **kwargs: Dict, keyword named properties for common layer attributes, like
+ `trainable`, used when constructing the cell from the configs of get_config().
When restoring from CudnnLSTM-trained checkpoints, must use
`CudnnCompatibleLSTMCell` instead.
"""
- super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+ super(BasicLSTMCell, self).__init__(
+ _reuse=reuse, name=name, dtype=dtype, **kwargs)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@@ -573,7 +626,10 @@ class BasicLSTMCell(LayerRNNCell):
self._num_units = num_units
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
- self._activation = activation or math_ops.tanh
+ if activation:
+ self._activation = activations.get(activation)
+ else:
+ self._activation = math_ops.tanh
@property
def state_size(self):
@@ -584,12 +640,13 @@ class BasicLSTMCell(LayerRNNCell):
def output_size(self):
return self._num_units
+ @tf_utils.shape_type_conversion
def build(self, inputs_shape):
- if inputs_shape[1].value is None:
+ if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
- input_depth = inputs_shape[1].value
+ input_depth = inputs_shape[-1]
h_depth = self._num_units
self._kernel = self.add_variable(
_WEIGHTS_VARIABLE_NAME,
@@ -647,6 +704,17 @@ class BasicLSTMCell(LayerRNNCell):
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
+ def get_config(self):
+ config = {
+ "num_units": self._num_units,
+ "forget_bias": self._forget_bias,
+ "state_is_tuple": self._state_is_tuple,
+ "activation": activations.serialize(self._activation),
+ "reuse": self._reuse,
+ }
+ base_config = super(BasicLSTMCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export("nn.rnn_cell.LSTMCell")
class LSTMCell(LayerRNNCell):
@@ -676,7 +744,7 @@ class LSTMCell(LayerRNNCell):
initializer=None, num_proj=None, proj_clip=None,
num_unit_shards=None, num_proj_shards=None,
forget_bias=1.0, state_is_tuple=True,
- activation=None, reuse=None, name=None, dtype=None):
+ activation=None, reuse=None, name=None, dtype=None, **kwargs):
"""Initialize the parameters for an LSTM cell.
Args:
@@ -702,7 +770,8 @@ class LSTMCell(LayerRNNCell):
state_is_tuple: If True, accepted and returned states are 2-tuples of
the `c_state` and `m_state`. If False, they are concatenated
along the column axis. This latter behavior will soon be deprecated.
- activation: Activation function of the inner states. Default: `tanh`.
+ activation: Activation function of the inner states. Default: `tanh`. It
+ can also be a string that names a Keras activation function.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
@@ -711,11 +780,14 @@ class LSTMCell(LayerRNNCell):
cases.
dtype: Default dtype of the layer (default of `None` means use the type
of the first input). Required when `build` is called before `call`.
+ **kwargs: Dict, keyword named properties for common layer attributes, like
+ `trainable`, used when constructing the cell from the configs of get_config().
When restoring from CudnnLSTM-trained checkpoints, use
`CudnnCompatibleLSTMCell` instead.
"""
- super(LSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+ super(LSTMCell, self).__init__(
+ _reuse=reuse, name=name, dtype=dtype, **kwargs)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@@ -731,14 +803,17 @@ class LSTMCell(LayerRNNCell):
self._num_units = num_units
self._use_peepholes = use_peepholes
self._cell_clip = cell_clip
- self._initializer = initializer
+ self._initializer = initializers.get(initializer)
self._num_proj = num_proj
self._proj_clip = proj_clip
self._num_unit_shards = num_unit_shards
self._num_proj_shards = num_proj_shards
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
- self._activation = activation or math_ops.tanh
+ if activation:
+ self._activation = activations.get(activation)
+ else:
+ self._activation = math_ops.tanh
if num_proj:
self._state_size = (
@@ -759,12 +834,13 @@ class LSTMCell(LayerRNNCell):
def output_size(self):
return self._output_size
+ @tf_utils.shape_type_conversion
def build(self, inputs_shape):
- if inputs_shape[1].value is None:
+ if inputs_shape[-1] is None:
raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
% inputs_shape)
- input_depth = inputs_shape[1].value
+ input_depth = inputs_shape[-1]
h_depth = self._num_units if self._num_proj is None else self._num_proj
maybe_partitioner = (
partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
@@ -878,6 +954,24 @@ class LSTMCell(LayerRNNCell):
array_ops.concat([c, m], 1))
return m, new_state
+ def get_config(self):
+ config = {
+ "num_units": self._num_units,
+ "use_peepholes": self._use_peepholes,
+ "cell_clip": self._cell_clip,
+ "initializer": initializers.serialize(self._initializer),
+ "num_proj": self._num_proj,
+ "proj_clip": self._proj_clip,
+ "num_unit_shards": self._num_unit_shards,
+ "num_proj_shards": self._num_proj_shards,
+ "forget_bias": self._forget_bias,
+ "state_is_tuple": self._state_is_tuple,
+ "activation": activations.serialize(self._activation),
+ "reuse": self._reuse,
+ }
+ base_config = super(LSTMCell, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
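With every constructor argument captured, an LSTMCell can be round-tripped through its config; a hedged sketch (a serialized initializer comes back as a dict, which `initializers.get` also accepts on reconstruction):

    import tensorflow as tf

    cell = tf.nn.rnn_cell.LSTMCell(
        64, use_peepholes=True, num_proj=32,
        initializer="orthogonal", activation="tanh")
    clone = tf.nn.rnn_cell.LSTMCell(**cell.get_config())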
def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs):
ix = [0]