diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/ops/rnn_cell.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 83 |
1 files changed, 83 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 83e8c2777f..acba77f0e1 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): return new_attns, new_attn_states +class HighwayWrapper(core_rnn_cell.RNNCell): + """RNNCell wrapper that adds highway connection on cell input and output. + + Based on: + R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks", + arXiv preprint arXiv:1505.00387, 2015. + https://arxiv.org/abs/1505.00387 + """ + + def __init__(self, cell, + couple_carry_transform_gates=True, + carry_bias_init=1.0): + """Constructs a `HighwayWrapper` for `cell`. + + Args: + cell: An instance of `RNNCell`. + couple_carry_transform_gates: boolean, should the Carry and Transform gate + be coupled. + carry_bias_init: float, carry gates bias initialization. + """ + self._cell = cell + self._couple_carry_transform_gates = couple_carry_transform_gates + self._carry_bias_init = carry_bias_init + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self._cell.zero_state(batch_size, dtype) + + def _highway(self, inp, out): + input_size = inp.get_shape().with_rank(2)[1].value + carry_weight = vs.get_variable("carry_w", [input_size, input_size]) + carry_bias = vs.get_variable( + "carry_b", [input_size], + initializer=init_ops.constant_initializer( + self._carry_bias_init)) + carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias)) + if self._couple_carry_transform_gates: + transform = 1 - carry + else: + transform_weight = vs.get_variable("transform_w", + [input_size, input_size]) + transform_bias = vs.get_variable( + "transform_b", [input_size], + initializer=init_ops.constant_initializer( + -self._carry_bias_init)) + transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp, + transform_weight, + transform_bias)) + return inp * carry + out * transform + + def __call__(self, inputs, state, scope=None): + """Run the cell and add its inputs to its outputs. + + Args: + inputs: cell inputs. + state: cell state. + scope: optional cell scope. + + Returns: + Tuple of cell outputs and new state. + + Raises: + TypeError: If cell inputs and outputs have different structure (type). + ValueError: If cell inputs and outputs have different structure (value). + """ + outputs, new_state = self._cell(inputs, state, scope=scope) + nest.assert_same_structure(inputs, outputs) + # Ensure shapes match + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + nest.map_structure(assert_shape_match, inputs, outputs) + res_outputs = nest.map_structure(self._highway, inputs, outputs) + return (res_outputs, new_state) + + class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): """LSTM unit with layer normalization and recurrent dropout. |