aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/ops/rnn_cell.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py | 83
1 file changed, 83 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 83e8c2777f..acba77f0e1 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
return new_attns, new_attn_states
class HighwayWrapper(core_rnn_cell.RNNCell):
  """RNNCell wrapper that adds highway connection on cell input and output.

  Based on:
    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
    arXiv preprint arXiv:1505.00387, 2015.
    https://arxiv.org/abs/1505.00387
  """

  def __init__(self, cell,
               couple_carry_transform_gates=True,
               carry_bias_init=1.0):
    """Constructs a `HighwayWrapper` for `cell`.

    Args:
      cell: An instance of `RNNCell`.
      couple_carry_transform_gates: boolean, should the Carry and Transform
        gates be coupled (transform = 1 - carry, halving the gate variables).
      carry_bias_init: float, carry gate bias initialization. A positive
        value biases the wrapper toward carrying the input through
        unchanged at the start of training.
    """
    # Fix: initialize the RNNCell base class so any base-class setup runs;
    # the original implementation skipped this call.
    super(HighwayWrapper, self).__init__()
    self._cell = cell
    self._couple_carry_transform_gates = couple_carry_transform_gates
    self._carry_bias_init = carry_bias_init

  @property
  def state_size(self):
    # Purely a pass-through: the wrapper adds no recurrent state of its own.
    return self._cell.state_size

  @property
  def output_size(self):
    # Highway mixing requires input and output sizes to match (checked in
    # __call__), so the wrapped cell's output size is reported unchanged.
    return self._cell.output_size

  def zero_state(self, batch_size, dtype):
    """Delegates zero-state construction to the wrapped cell."""
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def _highway(self, inp, out):
    """Mixes `inp` and `out` via learned carry/transform gates.

    Returns inp * carry + out * transform, where carry = sigmoid(W_c inp + b_c)
    and transform is either 1 - carry (coupled) or an independently
    parameterized sigmoid gate with a negatively initialized bias.
    """
    # Static last-dimension size; requires a rank-2 [batch, size] input.
    input_size = inp.get_shape().with_rank(2)[1].value
    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
    carry_bias = vs.get_variable(
        "carry_b", [input_size],
        initializer=init_ops.constant_initializer(
            self._carry_bias_init))
    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
    if self._couple_carry_transform_gates:
      transform = 1 - carry
    else:
      transform_weight = vs.get_variable("transform_w",
                                         [input_size, input_size])
      # Bias is the negation of the carry bias so the uncoupled gates start
      # out approximately complementary.
      transform_bias = vs.get_variable(
          "transform_b", [input_size],
          initializer=init_ops.constant_initializer(
              -self._carry_bias_init))
      transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
                                                    transform_weight,
                                                    transform_bias))
    return inp * carry + out * transform

  def __call__(self, inputs, state, scope=None):
    """Run the cell and add its inputs to its outputs.

    Args:
      inputs: cell inputs.
      state: cell state.
      scope: optional cell scope.

    Returns:
      Tuple of cell outputs and new state.

    Raises:
      TypeError: If cell inputs and outputs have different structure (type).
      ValueError: If cell inputs and outputs have different structure (value).
    """
    outputs, new_state = self._cell(inputs, state, scope=scope)
    nest.assert_same_structure(inputs, outputs)
    # Ensure shapes match: the highway mix is elementwise, so every input
    # leaf must be shape-compatible with the corresponding output leaf.
    def assert_shape_match(inp, out):
      inp.get_shape().assert_is_compatible_with(out.get_shape())
    nest.map_structure(assert_shape_match, inputs, outputs)
    res_outputs = nest.map_structure(self._highway, inputs, outputs)
    return (res_outputs, new_state)
+
+
class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
"""LSTM unit with layer normalization and recurrent dropout.