diff options
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r-- | tensorflow/python/keras/layers/recurrent.py | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py index 32d25c5a65..534c0eca08 100644 --- a/tensorflow/python/keras/layers/recurrent.py +++ b/tensorflow/python/keras/layers/recurrent.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training.checkpointable import base as checkpointable from tensorflow.python.util.tf_export import tf_export @@ -235,7 +236,8 @@ class RNN(Layer): """Base class for recurrent layers. Arguments: - cell: A RNN cell instance. A RNN cell is a class that has: + cell: A RNN cell instance or a list of RNN cell instances. + A RNN cell is a class that has: - a `call(input_at_t, states_at_t)` method, returning `(output_at_t, states_at_t_plus_1)`. The call method of the cell can also take the optional argument `constants`, see @@ -248,9 +250,9 @@ class RNN(Layer): (one size per state). In this case, the first entry (`state_size[0]`) should be the same as the size of the cell output. - It is also possible for `cell` to be a list of RNN cell instances, - in which cases the cells get stacked on after the other in the RNN, - implementing an efficient stacked RNN. + In the case that `cell` is a list of RNN cell instances, the cells + will be stacked on after the other in the RNN, implementing an + efficient stacked RNN. return_sequences: Boolean. Whether to return the last output in the output sequence, or the full sequence. return_state: Boolean. Whether to return the last state @@ -402,6 +404,8 @@ class RNN(Layer): 'one integer per RNN state).') super(RNN, self).__init__(**kwargs) self.cell = cell + if isinstance(cell, checkpointable.CheckpointableBase): + self._track_checkpointable(self.cell, name='cell') self.return_sequences = return_sequences self.return_state = return_state self.go_backwards = go_backwards |