aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/recurrent.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent.py')
-rw-r--r--tensorflow/python/keras/layers/recurrent.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 32d25c5a65..534c0eca08 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -235,7 +236,8 @@ class RNN(Layer):
"""Base class for recurrent layers.
Arguments:
- cell: A RNN cell instance. A RNN cell is a class that has:
+ cell: A RNN cell instance or a list of RNN cell instances.
+ A RNN cell is a class that has:
- a `call(input_at_t, states_at_t)` method, returning
`(output_at_t, states_at_t_plus_1)`. The call method of the
cell can also take the optional argument `constants`, see
@@ -248,9 +250,9 @@ class RNN(Layer):
(one size per state). In this case, the first entry
(`state_size[0]`) should be the same as
the size of the cell output.
- It is also possible for `cell` to be a list of RNN cell instances,
- in which cases the cells get stacked on after the other in the RNN,
- implementing an efficient stacked RNN.
+ In the case that `cell` is a list of RNN cell instances, the cells
+ will be stacked on after the other in the RNN, implementing an
+ efficient stacked RNN.
return_sequences: Boolean. Whether to return the last output
in the output sequence, or the full sequence.
return_state: Boolean. Whether to return the last state
@@ -402,6 +404,8 @@ class RNN(Layer):
'one integer per RNN state).')
super(RNN, self).__init__(**kwargs)
self.cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self.cell, name='cell')
self.return_sequences = return_sequences
self.return_state = return_state
self.go_backwards = go_backwards