diff options
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r-- | tensorflow/python/ops/rnn_cell_impl.py | 18 |
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index f7854e86c0..304b6ae665 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -786,13 +786,18 @@ class DropoutWrapper(RNNCell): class ResidualWrapper(RNNCell): """RNNCell wrapper that ensures cell inputs are added to the outputs.""" - def __init__(self, cell): + def __init__(self, cell, residual_fn=None): """Constructs a `ResidualWrapper` for `cell`. Args: cell: An instance of `RNNCell`. + residual_fn: (Optional) The function to map raw cell inputs and raw cell + outputs to the actual cell outputs of the residual network. + Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs + and outputs. """ self._cell = cell + self._residual_fn = residual_fn @property def state_size(self): @@ -807,7 +812,7 @@ class ResidualWrapper(RNNCell): return self._cell.zero_state(batch_size, dtype) def __call__(self, inputs, state, scope=None): - """Run the cell and add its inputs to its outputs. + """Run the cell and then apply the residual_fn on its inputs to its outputs. Args: inputs: cell inputs. @@ -822,13 +827,14 @@ class ResidualWrapper(RNNCell): ValueError: If cell inputs and outputs have different structure (value). """ outputs, new_state = self._cell(inputs, state, scope=scope) - nest.assert_same_structure(inputs, outputs) # Ensure shapes match def assert_shape_match(inp, out): inp.get_shape().assert_is_compatible_with(out.get_shape()) - nest.map_structure(assert_shape_match, inputs, outputs) - res_outputs = nest.map_structure( - lambda inp, out: inp + out, inputs, outputs) + def default_residual_fn(inputs, outputs): + nest.assert_same_structure(inputs, outputs) + nest.map_structure(assert_shape_match, inputs, outputs) + return nest.map_structure(lambda inp, out: inp + out, inputs, outputs) + res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs) return (res_outputs, new_state) |