Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')

 -rw-r--r--  tensorflow/python/ops/rnn_cell_impl.py | 150
 1 file changed, 122 insertions(+), 28 deletions(-)
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 42806ba6ec..f481726d54 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -34,6 +34,9 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -48,6 +51,7 @@ from tensorflow.python.ops import variables as tf_variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training.checkpointable import base as checkpointable
 from tensorflow.python.util import nest
+from tensorflow.python.util.deprecation import deprecated
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -335,7 +339,8 @@ class BasicRNNCell(LayerRNNCell):
 
   Args:
     num_units: int, The number of units in the RNN cell.
-    activation: Nonlinearity to use.  Default: `tanh`.
+    activation: Nonlinearity to use.  Default: `tanh`. It can also be a
+      string that is the name of a Keras activation function.
     reuse: (optional) Python boolean describing whether to reuse variables
       in an existing scope.  If not `True`, and the existing scope already has
       the given variables, an error is raised.
@@ -344,6 +349,8 @@ class BasicRNNCell(LayerRNNCell):
       cases.
     dtype: Default dtype of the layer (default of `None` means use the type
       of the first input). Required when `build` is called before `call`.
+    **kwargs: Dict, keyword named properties for common layer attributes, like
+      `trainable`, when constructing the cell from the configs of get_config().
   """
 
   def __init__(self,
@@ -351,14 +358,19 @@ class BasicRNNCell(LayerRNNCell):
                num_units,
                activation=None,
                reuse=None,
                name=None,
-               dtype=None):
-    super(BasicRNNCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+               dtype=None,
+               **kwargs):
+    super(BasicRNNCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype, **kwargs)
 
     # Inputs must be 2-dimensional.
     self.input_spec = base_layer.InputSpec(ndim=2)
 
     self._num_units = num_units
-    self._activation = activation or math_ops.tanh
+    if activation:
+      self._activation = activations.get(activation)
+    else:
+      self._activation = math_ops.tanh
 
   @property
   def state_size(self):
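The hunk above routes `activation` through `activations.get`, so the constructor now accepts Keras activation names as well as callables. A minimal usage sketch, assuming a TF 1.x build that includes this change (not part of the commit itself):

    import tensorflow as tf

    # Equivalent cells: one built from a callable, one from a Keras
    # activation name resolved by activations.get().
    cell_fn = tf.nn.rnn_cell.BasicRNNCell(num_units=8, activation=tf.nn.relu)
    cell_str = tf.nn.rnn_cell.BasicRNNCell(num_units=8, activation="relu")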
@@ -368,12 +380,13 @@ class BasicRNNCell(LayerRNNCell):
   def output_size(self):
     return self._num_units
 
+  @tf_utils.shape_type_conversion
   def build(self, inputs_shape):
-    if inputs_shape[1].value is None:
+    if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                        % inputs_shape)
 
-    input_depth = inputs_shape[1].value
+    input_depth = inputs_shape[-1]
     self._kernel = self.add_variable(
         _WEIGHTS_VARIABLE_NAME,
         shape=[input_depth + self._num_units, self._num_units])
@@ -393,6 +406,15 @@ class BasicRNNCell(LayerRNNCell):
     output = self._activation(gate_inputs)
     return output, output
 
+  def get_config(self):
+    config = {
+        "num_units": self._num_units,
+        "activation": activations.serialize(self._activation),
+        "reuse": self._reuse,
+    }
+    base_config = super(BasicRNNCell, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
 
 @tf_export("nn.rnn_cell.GRUCell")
 class GRUCell(LayerRNNCell):
@@ -412,6 +434,8 @@ class GRUCell(LayerRNNCell):
       cases.
     dtype: Default dtype of the layer (default of `None` means use the type
       of the first input). Required when `build` is called before `call`.
+    **kwargs: Dict, keyword named properties for common layer attributes, like
+      `trainable`, when constructing the cell from the configs of get_config().
   """
 
   def __init__(self,
@@ -421,16 +445,21 @@ class GRUCell(LayerRNNCell):
                num_units,
                activation=None,
                reuse=None,
                kernel_initializer=None,
                bias_initializer=None,
                name=None,
-               dtype=None):
-    super(GRUCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+               dtype=None,
+               **kwargs):
+    super(GRUCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype, **kwargs)
 
     # Inputs must be 2-dimensional.
     self.input_spec = base_layer.InputSpec(ndim=2)
 
     self._num_units = num_units
-    self._activation = activation or math_ops.tanh
-    self._kernel_initializer = kernel_initializer
-    self._bias_initializer = bias_initializer
+    if activation:
+      self._activation = activations.get(activation)
+    else:
+      self._activation = math_ops.tanh
+    self._kernel_initializer = initializers.get(kernel_initializer)
+    self._bias_initializer = initializers.get(bias_initializer)
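Because `kernel_initializer` and `bias_initializer` now pass through `initializers.get`, GRUCell likewise accepts Keras initializer names in place of initializer objects. A hedged sketch under the same TF 1.x assumption:

    import tensorflow as tf

    # String initializers are resolved via the Keras initializers registry.
    cell = tf.nn.rnn_cell.GRUCell(
        num_units=16,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros")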
@@ -440,12 +469,13 @@ class GRUCell(LayerRNNCell):
   def output_size(self):
     return self._num_units
 
+  @tf_utils.shape_type_conversion
   def build(self, inputs_shape):
-    if inputs_shape[1].value is None:
+    if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                        % inputs_shape)
 
-    input_depth = inputs_shape[1].value
+    input_depth = inputs_shape[-1]
     self._gate_kernel = self.add_variable(
         "gates/%s" % _WEIGHTS_VARIABLE_NAME,
         shape=[input_depth + self._num_units, 2 * self._num_units],
@@ -491,6 +521,17 @@ class GRUCell(LayerRNNCell):
     new_h = u * state + (1 - u) * c
     return new_h, new_h
 
+  def get_config(self):
+    config = {
+        "num_units": self._num_units,
+        "kernel_initializer": initializers.serialize(self._kernel_initializer),
+        "bias_initializer": initializers.serialize(self._bias_initializer),
+        "activation": activations.serialize(self._activation),
+        "reuse": self._reuse,
+    }
+    base_config = super(GRUCell, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
 
 _LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))
 
@@ -515,9 +556,12 @@ class LSTMStateTuple(_LSTMStateTuple):
     return c.dtype
 
 
+# TODO(scottzhu): Stop exporting this class in TF 2.0.
 @tf_export("nn.rnn_cell.BasicLSTMCell")
 class BasicLSTMCell(LayerRNNCell):
-  """Basic LSTM recurrent network cell.
+  """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+
+  Basic LSTM recurrent network cell.
 
   The implementation is based on: http://arxiv.org/abs/1409.2329.
 
@@ -527,10 +571,14 @@ class BasicLSTMCell(LayerRNNCell):
   It does not allow cell clipping, a projection layer, and does not
   use peep-hole connections: it is the basic baseline.
 
-  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
+  For advanced models, please use the full `tf.nn.rnn_cell.LSTMCell`
   that follows.
   """
 
+  @deprecated(None, "This class is deprecated, please use "
+              "tf.nn.rnn_cell.LSTMCell, which supports all the features "
+              "this cell currently has. Please replace the existing code "
+              "with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').")
   def __init__(self,
                num_units,
                forget_bias=1.0,
@@ -538,7 +586,8 @@ class BasicLSTMCell(LayerRNNCell):
                state_is_tuple=True,
                activation=None,
                reuse=None,
                name=None,
-               dtype=None):
+               dtype=None,
+               **kwargs):
     """Initialize the basic LSTM cell.
 
     Args:
@@ -549,7 +598,8 @@ class BasicLSTMCell(LayerRNNCell):
       state_is_tuple: If True, accepted and returned states are 2-tuples of
         the `c_state` and `m_state`.  If False, they are concatenated
         along the column axis.  The latter behavior will soon be deprecated.
-      activation: Activation function of the inner states.  Default: `tanh`.
+      activation: Activation function of the inner states.  Default: `tanh`. It
+        can also be a string that is the name of a Keras activation function.
       reuse: (optional) Python boolean describing whether to reuse variables
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
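The deprecation message above spells out the migration path. As a sketch, the swap looks like this; passing `name='basic_lstm_cell'` keeps the variable scope, and hence the checkpoint variable names, compatible with the old cell, which is why the message suggests it:

    import tensorflow as tf

    # Before: emits a deprecation warning after this change.
    old_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)

    # After: the replacement named in the deprecation message.
    new_cell = tf.nn.rnn_cell.LSTMCell(128, name="basic_lstm_cell")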
@@ -558,11 +608,14 @@ class BasicLSTMCell(LayerRNNCell):
       cases.
       dtype: Default dtype of the layer (default of `None` means use the type
         of the first input). Required when `build` is called before `call`.
+      **kwargs: Dict, keyword named properties for common layer attributes, like
+        `trainable`, when constructing the cell from the configs of get_config().
 
       When restoring from CudnnLSTM-trained checkpoints, must use
       `CudnnCompatibleLSTMCell` instead.
     """
-    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+    super(BasicLSTMCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype, **kwargs)
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
                    "deprecated.  Use state_is_tuple=True.", self)
@@ -573,7 +626,10 @@ class BasicLSTMCell(LayerRNNCell):
     self._num_units = num_units
     self._forget_bias = forget_bias
     self._state_is_tuple = state_is_tuple
-    self._activation = activation or math_ops.tanh
+    if activation:
+      self._activation = activations.get(activation)
+    else:
+      self._activation = math_ops.tanh
 
   @property
   def state_size(self):
@@ -584,12 +640,13 @@ class BasicLSTMCell(LayerRNNCell):
   def output_size(self):
     return self._num_units
 
+  @tf_utils.shape_type_conversion
   def build(self, inputs_shape):
-    if inputs_shape[1].value is None:
+    if inputs_shape[-1] is None:
       raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                        % inputs_shape)
 
-    input_depth = inputs_shape[1].value
+    input_depth = inputs_shape[-1]
     h_depth = self._num_units
     self._kernel = self.add_variable(
         _WEIGHTS_VARIABLE_NAME,
@@ -647,6 +704,17 @@ class BasicLSTMCell(LayerRNNCell):
       new_state = array_ops.concat([new_c, new_h], 1)
     return new_h, new_state
 
+  def get_config(self):
+    config = {
+        "num_units": self._num_units,
+        "forget_bias": self._forget_bias,
+        "state_is_tuple": self._state_is_tuple,
+        "activation": activations.serialize(self._activation),
+        "reuse": self._reuse,
+    }
+    base_config = super(BasicLSTMCell, self).get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
 
 @tf_export("nn.rnn_cell.LSTMCell")
 class LSTMCell(LayerRNNCell):
@@ -676,7 +744,7 @@ class LSTMCell(LayerRNNCell):
                initializer=None, num_proj=None, proj_clip=None,
                num_unit_shards=None, num_proj_shards=None,
                forget_bias=1.0, state_is_tuple=True,
-               activation=None, reuse=None, name=None, dtype=None):
+               activation=None, reuse=None, name=None, dtype=None, **kwargs):
     """Initialize the parameters for an LSTM cell.
 
     Args:
@@ -702,7 +770,8 @@ class LSTMCell(LayerRNNCell):
       state_is_tuple: If True, accepted and returned states are 2-tuples of
         the `c_state` and `m_state`.  If False, they are concatenated
         along the column axis.  This latter behavior will soon be deprecated.
-      activation: Activation function of the inner states.  Default: `tanh`.
+      activation: Activation function of the inner states.  Default: `tanh`. It
+        can also be a string that is the name of a Keras activation function.
       reuse: (optional) Python boolean describing whether to reuse variables
         in an existing scope.  If not `True`, and the existing scope already has
         the given variables, an error is raised.
@@ -711,11 +780,14 @@ class LSTMCell(LayerRNNCell):
       cases.
       dtype: Default dtype of the layer (default of `None` means use the type
         of the first input). Required when `build` is called before `call`.
+      **kwargs: Dict, keyword named properties for common layer attributes, like
+        `trainable`, when constructing the cell from the configs of get_config().
 
       When restoring from CudnnLSTM-trained checkpoints, use
       `CudnnCompatibleLSTMCell` instead.
     """
-    super(LSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
+    super(LSTMCell, self).__init__(
+        _reuse=reuse, name=name, dtype=dtype, **kwargs)
     if not state_is_tuple:
       logging.warn("%s: Using a concatenated state is slower and will soon be "
                    "deprecated.  Use state_is_tuple=True.", self)
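Each `build` method switches from `inputs_shape[1].value` to `inputs_shape[-1]` because `@tf_utils.shape_type_conversion` hands `build` a plain tuple rather than a `TensorShape`. A small probe of that behavior; the `Probe` class is hypothetical, `tf_utils` is a private TF 1.x-era module, and this sketch assumes the decorator's implementation of that era:

    import tensorflow as tf
    from tensorflow.python.keras.utils import tf_utils


    class Probe(object):

      @tf_utils.shape_type_conversion
      def build(self, input_shape):
        # The decorator converts the TensorShape argument to a tuple such
        # as (None, 32), so the last dimension is read by plain indexing.
        self.input_depth = input_shape[-1]


    probe = Probe()
    probe.build(tf.TensorShape([None, 32]))
    print(probe.input_depth)  # 32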
""" - super(LSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) + super(LSTMCell, self).__init__( + _reuse=reuse, name=name, dtype=dtype, **kwargs) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) @@ -731,14 +803,17 @@ class LSTMCell(LayerRNNCell): self._num_units = num_units self._use_peepholes = use_peepholes self._cell_clip = cell_clip - self._initializer = initializer + self._initializer = initializers.get(initializer) self._num_proj = num_proj self._proj_clip = proj_clip self._num_unit_shards = num_unit_shards self._num_proj_shards = num_proj_shards self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple - self._activation = activation or math_ops.tanh + if activation: + self._activation = activations.get(activation) + else: + self._activation = math_ops.tanh if num_proj: self._state_size = ( @@ -759,12 +834,13 @@ class LSTMCell(LayerRNNCell): def output_size(self): return self._output_size + @tf_utils.shape_type_conversion def build(self, inputs_shape): - if inputs_shape[1].value is None: + if inputs_shape[-1] is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) - input_depth = inputs_shape[1].value + input_depth = inputs_shape[-1] h_depth = self._num_units if self._num_proj is None else self._num_proj maybe_partitioner = ( partitioned_variables.fixed_size_partitioner(self._num_unit_shards) @@ -878,6 +954,24 @@ class LSTMCell(LayerRNNCell): array_ops.concat([c, m], 1)) return m, new_state + def get_config(self): + config = { + "num_units": self._num_units, + "use_peepholes": self._use_peepholes, + "cell_clip": self._cell_clip, + "initializer": initializers.serialize(self._initializer), + "num_proj": self._num_proj, + "proj_clip": self._proj_clip, + "num_unit_shards": self._num_unit_shards, + "num_proj_shards": self._num_proj_shards, + "forget_bias": self._forget_bias, + "state_is_tuple": self._state_is_tuple, + "activation": activations.serialize(self._activation), + "reuse": self._reuse, + } + base_config = super(LSTMCell, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + def _enumerated_map_structure_up_to(shallow_structure, map_fn, *args, **kwargs): ix = [0] |