aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/rnn_cell_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/rnn_cell_impl.py')
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index f7854e86c0..304b6ae665 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -786,13 +786,18 @@ class DropoutWrapper(RNNCell):
class ResidualWrapper(RNNCell):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
- def __init__(self, cell):
+ def __init__(self, cell, residual_fn=None):
"""Constructs a `ResidualWrapper` for `cell`.
Args:
cell: An instance of `RNNCell`.
+ residual_fn: (Optional) The function to map raw cell inputs and raw cell
+ outputs to the actual cell outputs of the residual network.
+ Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
+ and outputs.
"""
self._cell = cell
+ self._residual_fn = residual_fn
@property
def state_size(self):
@@ -807,7 +812,7 @@ class ResidualWrapper(RNNCell):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
- """Run the cell and add its inputs to its outputs.
+ """Run the cell and then apply the residual_fn on its inputs to its outputs.
Args:
inputs: cell inputs.
@@ -822,13 +827,14 @@ class ResidualWrapper(RNNCell):
ValueError: If cell inputs and outputs have different structure (value).
"""
outputs, new_state = self._cell(inputs, state, scope=scope)
- nest.assert_same_structure(inputs, outputs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
- nest.map_structure(assert_shape_match, inputs, outputs)
- res_outputs = nest.map_structure(
- lambda inp, out: inp + out, inputs, outputs)
+ def default_residual_fn(inputs, outputs):
+ nest.assert_same_structure(inputs, outputs)
+ nest.map_structure(assert_shape_match, inputs, outputs)
+ return nest.map_structure(lambda inp, out: inp + out, inputs, outputs)
+ res_outputs = (self._residual_fn or default_residual_fn)(inputs, outputs)
return (res_outputs, new_state)