author    A. Unique TensorFlower <gardener@tensorflow.org>    2017-05-01 11:02:10 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>     2017-05-01 12:27:47 -0700
commit    03327190420dd5b1c34a5ffdd0000aff40980ed5 (patch)
tree      0ea0fe57d8c2932a63541760a9a4d9d854a93420 /tensorflow
parent    24b049877b8a056200f1c0d125d345c0af637aa1 (diff)
Adding HighwayWrapper for rnn cell that creates a highway skip connection between the cell's input and output activations.
Change: 154741723
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/rnn/__init__.py                            |  1
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py  | 24
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py                 | 83
3 files changed, 108 insertions, 0 deletions
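
For orientation, here is a minimal usage sketch of the wrapper this change introduces. It is illustrative only: it assumes the TensorFlow 1.x graph-mode API of that era and that the class is exported as tf.contrib.rnn.HighwayWrapper (as the __init__.py hunk below indicates); the cell size, sequence length, and variable names are made up. Because the highway gates combine input and output element-wise, the wrapped cell's output size must match the input depth.

import tensorflow as tf

# Wrap a GRU cell with a highway skip connection between each step's
# input and output. num_units must equal the input depth (64 here) so
# the element-wise gating is well defined.
base_cell = tf.contrib.rnn.GRUCell(num_units=64)
hw_cell = tf.contrib.rnn.HighwayWrapper(base_cell, carry_bias_init=1.0)

# Standard TF 1.x graph-mode unrolling over a [batch, time, depth] input.
inputs = tf.placeholder(tf.float32, shape=[None, 20, 64])
outputs, final_state = tf.nn.dynamic_rnn(hw_cell, inputs, dtype=tf.float32)
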
diff --git a/tensorflow/contrib/rnn/__init__.py b/tensorflow/contrib/rnn/__init__.py
index ce1ed7f491..a744878124 100644
--- a/tensorflow/contrib/rnn/__init__.py
+++ b/tensorflow/contrib/rnn/__init__.py
@@ -43,6 +43,7 @@ See @{$python/contrib.rnn} guide.
@@BidirectionalGridLSTMCell
@@NASCell
@@PhasedLSTMCell
+@@HighwayWrapper
### RNNCell wrappers
@@AttentionCellWrapper
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 8b40fc068f..55fd7e7a51 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -882,6 +882,30 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
+  def testHighwayWrapper(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "base_cell", initializer=init_ops.constant_initializer(0.5)):
+        x = array_ops.zeros([1, 3])
+        m = array_ops.zeros([1, 3])
+        base_cell = core_rnn_cell_impl.GRUCell(3)
+        g, m_new = base_cell(x, m)
+      with variable_scope.variable_scope(
+          "hw_cell", initializer=init_ops.constant_initializer(0.5)):
+        hw_cell = rnn_cell.HighwayWrapper(
+            core_rnn_cell_impl.GRUCell(3), carry_bias_init=-100.0)
+        g_res, m_new_res = hw_cell(x, m)
+        sess.run([variables.global_variables_initializer()])
+      res = sess.run([g, g_res, m_new, m_new_res], {
+          x: np.array([[1., 1., 1.]]),
+          m: np.array([[0.1, 0.1, 0.1]])
+      })
+      # As carry_bias_init is very negative, the carry gate is 'closed' and the
+      # transform gate is 'open'; the highway output equals the cell output.
+      self.assertAllClose(res[1], res[0])
+      # States are left untouched
+      self.assertAllClose(res[2], res[3])
+
class LayerNormBasicLSTMCellTest(test.TestCase):
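
The expectation tested above follows directly from the highway gating the wrapper (defined in the next hunk) applies: with coupled gates, output = carry * input + (1 - carry) * cell_output, where carry = sigmoid(input * W_c + b_c). Below is a tiny standalone NumPy sketch of just that arithmetic, with the weights fixed at 0.5 to mirror the test's constant initializer; the sigmoid and highway helpers and the stand-in cell output h are illustrative, not part of the TensorFlow code.

import numpy as np

def sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))

def highway(inp, out, carry_w, carry_b):
  # Coupled gates, matching the wrapper's default: transform = 1 - carry.
  carry = sigmoid(inp @ carry_w + carry_b)
  return inp * carry + out * (1.0 - carry)

x = np.array([[1.0, 1.0, 1.0]])   # cell input, as in the test feed
h = np.array([[0.3, 0.3, 0.3]])   # stand-in for the GRU output
w = np.full((3, 3), 0.5)          # mirrors constant_initializer(0.5)

# carry_bias_init=-100.0: sigmoid(1.5 - 100) is ~0, the carry gate is closed,
# so the highway output collapses to the cell output h (what the test asserts).
print(highway(x, h, w, carry_b=-100.0))   # ~[[0.3, 0.3, 0.3]]
# A large positive bias would do the opposite and pass the input x through.
print(highway(x, h, w, carry_b=100.0))    # ~[[1.0, 1.0, 1.0]]
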
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 83e8c2777f..acba77f0e1 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1157,6 +1157,89 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
return new_attns, new_attn_states
+class HighwayWrapper(core_rnn_cell.RNNCell):
+  """RNNCell wrapper that adds a highway connection on cell input and output.
+
+  Based on:
+    R. K. Srivastava, K. Greff, and J. Schmidhuber, "Highway networks",
+    arXiv preprint arXiv:1505.00387, 2015.
+    https://arxiv.org/abs/1505.00387
+  """
+
+  def __init__(self, cell,
+               couple_carry_transform_gates=True,
+               carry_bias_init=1.0):
+    """Constructs a `HighwayWrapper` for `cell`.
+
+    Args:
+      cell: An instance of `RNNCell`.
+      couple_carry_transform_gates: boolean, whether the Carry and Transform
+        gates should be coupled (if True, Transform = 1 - Carry).
+      carry_bias_init: float, initial value for the carry gate bias.
+    """
+    self._cell = cell
+    self._couple_carry_transform_gates = couple_carry_transform_gates
+    self._carry_bias_init = carry_bias_init
+
+  @property
+  def state_size(self):
+    return self._cell.state_size
+
+  @property
+  def output_size(self):
+    return self._cell.output_size
+
+  def zero_state(self, batch_size, dtype):
+    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
+      return self._cell.zero_state(batch_size, dtype)
+
+  def _highway(self, inp, out):
+    input_size = inp.get_shape().with_rank(2)[1].value
+    carry_weight = vs.get_variable("carry_w", [input_size, input_size])
+    carry_bias = vs.get_variable(
+        "carry_b", [input_size],
+        initializer=init_ops.constant_initializer(
+            self._carry_bias_init))
+    carry = math_ops.sigmoid(nn_ops.xw_plus_b(inp, carry_weight, carry_bias))
+    if self._couple_carry_transform_gates:
+      transform = 1 - carry
+    else:
+      transform_weight = vs.get_variable("transform_w",
+                                         [input_size, input_size])
+      transform_bias = vs.get_variable(
+          "transform_b", [input_size],
+          initializer=init_ops.constant_initializer(
+              -self._carry_bias_init))
+      transform = math_ops.sigmoid(nn_ops.xw_plus_b(inp,
+                                                    transform_weight,
+                                                    transform_bias))
+    return inp * carry + out * transform
+
+  def __call__(self, inputs, state, scope=None):
+    """Run the cell and apply a highway connection to its inputs and outputs.
+
+    Args:
+      inputs: cell inputs.
+      state: cell state.
+      scope: optional cell scope.
+
+    Returns:
+      Tuple of cell outputs and new state.
+
+    Raises:
+      TypeError: If cell inputs and outputs have different structure (type).
+      ValueError: If cell inputs and outputs have different structure (value).
+    """
+    outputs, new_state = self._cell(inputs, state, scope=scope)
+    nest.assert_same_structure(inputs, outputs)
+    # Ensure shapes match
+    def assert_shape_match(inp, out):
+      inp.get_shape().assert_is_compatible_with(out.get_shape())
+    nest.map_structure(assert_shape_match, inputs, outputs)
+    res_outputs = nest.map_structure(self._highway, inputs, outputs)
+    return (res_outputs, new_state)
+
+
class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
"""LSTM unit with layer normalization and recurrent dropout.