aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-10 18:28:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-10 18:30:54 -0700
commit56b46370ba08c76200711f4a8d25194af1235fd5 (patch)
treee9d4a9822884671935a2047059624eaab5d0580a /tensorflow/python/ops/rnn_cell_impl.py
parent5cef54072782a9a893eda69bec30fcf79cd0086b (diff)
Checkpointable: Have RNN wrappers add their cells as dependencies
Also marks _SlimRNNCell as not checkpointable, and adds a more convenient way to tag such classes. Ideally adding a wrapper around a cell wouldn't break a checkpoint. This could look like RNN cell wrappers inheriting the dependencies of the cell they're wrapping. Possible to add that later if there's demand, or users can just add a dependency on wrapper._cell in addition to/instead of the wrapper when modifying programs. Fixes #19208. PiperOrigin-RevId: 196202366
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py8
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 67f753485b..68d22794d3 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -1005,6 +1005,8 @@ class DropoutWrapper(RNNCell):
# Set cell, variational_recurrent, seed before running the code below
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._variational_recurrent = variational_recurrent
self._seed = seed
@@ -1152,6 +1154,8 @@ class ResidualWrapper(RNNCell):
and outputs.
"""
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._residual_fn = residual_fn
@property
@@ -1207,6 +1211,8 @@ class DeviceWrapper(RNNCell):
device: A device string or function, for passing to `tf.device`.
"""
self._cell = cell
+ if isinstance(cell, checkpointable.CheckpointableBase):
+ self._track_checkpointable(self._cell, name="cell")
self._device = device
@property
@@ -1322,7 +1328,7 @@ class MultiRNNCell(RNNCell):
return cur_inp, new_states
-class _SlimRNNCell(RNNCell):
+class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable):
"""A simple wrapper for slim.rnn_cells."""
def __init__(self, cell_fn):