diff options
author | 2017-05-01 11:02:10 -0800 | |
---|---|---|
committer | 2017-05-01 12:27:47 -0700 | |
commit | 03327190420dd5b1c34a5ffdd0000aff40980ed5 (patch) | |
tree | 0ea0fe57d8c2932a63541760a9a4d9d854a93420 /tensorflow | |
parent | 24b049877b8a056200f1c0d125d345c0af637aa1 (diff) |
Add HighwayWrapper, an RNN cell wrapper that creates a highway skip connection between the wrapped cell's input and output activations.
Change: 154741723
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/rnn/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 24 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 83 |
3 files changed, 108 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py index ce1ed7f491..a744878124 100644 --- a/tensorflow/contrib/rnn/__init__.py +++ b/tensorflow/contrib/rnn/__init__.py @@ -43,6 +43,7 @@ See @{$python/contrib.rnn} guide. @@BidirectionalGridLSTMCell @@NASCell @@PhasedLSTMCell +@@HighwayWrapper ### RNNCell wrappers @@AttentionCellWrapper diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 8b40fc068f..55fd7e7a51 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -882,6 +882,30 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].c, expected_state_c) self.assertAllClose(res[1].h, expected_state_h) + def testHighwayWrapper(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "base_cell", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 3]) + base_cell = core_rnn_cell_impl.GRUCell(3) + g, m_new = base_cell(x, m) + with variable_scope.variable_scope( + "hw_cell", initializer=init_ops.constant_initializer(0.5)): + hw_cell = rnn_cell.HighwayWrapper( + core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0) + g_res, m_new_res = hw_cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g, g_res, m_new, m_new_res], { + x: np.array([[1., 1., 1.]]), + m: np.array([[0.1, 0.1, 0.1]]) + }) + # As carry_bias_init is very negative, the carry gate is 'closed' and the + # transform gate is 'open'. This means the output equals the base cell's + # output. 
+ self.assertAllClose(res[1], res[0]) + # States are left untouched + self.assertAllClose(res[2], res[3]) + class LayerNormBasicLSTMCellTest(test.TestCase): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 83e8c2777f..acba77f0e1 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell): return new_attns, new_attn_states +class HighwayWrapper(core_rnn_cell.RNNCell): + """RNNCell wrapper that adds highway connection on cell input and output. + + Based on: + R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks", + arXiv preprint arXiv:1505.00387, 2015. + https://arxiv.org/abs/1505.00387 + """ + + def __init__(self, cell, + couple_carry_transform_gates=True, + carry_bias_init=1.0): + """Constructs a `HighwayWrapper` for `cell`. + + Args: + cell: An instance of `RNNCell`. + couple_carry_transform_gates: boolean, should the Carry and Transform gate + be coupled. + carry_bias_init: float, carry gates bias initialization. 
+ """ + self._cell = cell + self._couple_carry_transform_gates = couple_carry_transform_gates + self._carry_bias_init = carry_bias_init + + @property + def state_size(self): + return self._cell.state_size + + @property + def output_size(self): + return self._cell.output_size + + def zero_state(self, batch_size, dtype): + with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): + return self._cell.zero_state(batch_size, dtype) + + def _highway(self, inp, out): + input_size = inp.get_shape().with_rank(2)[1].value + carry_weight = vs.get_variable("carry_w", [input_size, input_size]) + carry_bias = vs.get_variable( + "carry_b", [input_size], + initializer=init_ops.constant_initializer( + self._carry_bias_init)) + carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias)) + if self._couple_carry_transform_gates: + transform = 1 - carry + else: + transform_weight = vs.get_variable("transform_w", + [input_size, input_size]) + transform_bias = vs.get_variable( + "transform_b", [input_size], + initializer=init_ops.constant_initializer( + -self._carry_bias_init)) + transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp, + transform_weight, + transform_bias)) + return inp * carry + out * transform + + def __call__(self, inputs, state, scope=None): + """Run the cell and mix its inputs with its outputs via the highway gates. + + Args: + inputs: cell inputs. + state: cell state. + scope: optional cell scope. + + Returns: + Tuple of highway-gated cell outputs and new state. + + Raises: + TypeError: If cell inputs and outputs have different structure (type). + ValueError: If cell inputs and outputs have different structure (value). 
+ """ + outputs, new_state = self._cell(inputs, state, scope=scope) + nest.assert_same_structure(inputs, outputs) + # Ensure shapes match + def assert_shape_match(inp, out): + inp.get_shape().assert_is_compatible_with(out.get_shape()) + nest.map_structure(assert_shape_match, inputs, outputs) + res_outputs = nest.map_structure(self._highway, inputs, outputs) + return (res_outputs, new_state) + + class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell): """LSTM unit with layer normalization and recurrent dropout. |