diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-10 18:28:24 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 18:30:54 -0700 |
commit | 56b46370ba08c76200711f4a8d25194af1235fd5 (patch) | |
tree | e9d4a9822884671935a2047059624eaab5d0580a /tensorflow/python/ops/rnn_cell_impl.py | |
parent | 5cef54072782a9a893eda69bec30fcf79cd0086b (diff) |
Checkpointable: Have RNN wrappers add their cells as dependencies
Also marks _SlimRNNCell as not checkpointable, and adds a more convenient way to tag such classes.
Ideally adding a wrapper around a cell wouldn't break a checkpoint. This could look like RNN cell wrappers inheriting the dependencies of the cell they're wrapping. Possible to add that later if there's demand, or users can just add a dependency on wrapper._cell in addition to/instead of the wrapper when modifying programs.
Fixes #19208.
PiperOrigin-RevId: 196202366
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index 67f753485b..68d22794d3 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -1005,6 +1005,8 @@ class DropoutWrapper(RNNCell): # Set cell, variational_recurrent, seed before running the code below self._cell = cell + if isinstance(cell, checkpointable.CheckpointableBase): + self._track_checkpointable(self._cell, name="cell") self._variational_recurrent = variational_recurrent self._seed = seed @@ -1152,6 +1154,8 @@ class ResidualWrapper(RNNCell): and outputs. """ self._cell = cell + if isinstance(cell, checkpointable.CheckpointableBase): + self._track_checkpointable(self._cell, name="cell") self._residual_fn = residual_fn @property @@ -1207,6 +1211,8 @@ class DeviceWrapper(RNNCell): device: A device string or function, for passing to `tf.device`. """ self._cell = cell + if isinstance(cell, checkpointable.CheckpointableBase): + self._track_checkpointable(self._cell, name="cell") self._device = device @property @@ -1322,7 +1328,7 @@ class MultiRNNCell(RNNCell): return cur_inp, new_states -class _SlimRNNCell(RNNCell): +class _SlimRNNCell(RNNCell, checkpointable.NotCheckpointable): """A simple wrapper for slim.rnn_cells.""" def __init__(self, cell_fn): |