aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-03-27 09:41:48 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-27 11:11:51 -0700
commitd58fe97fdefcd968c28f4cff916ba6a26e234d4f (patch)
tree0a8552b38818b93402c56f2444924fcdb2401b53
parentf644fad2850a78effe51dae63ebb97cf47700050 (diff)
Added UGRNN and Intersection RNN cells from the paper: Capacity and Trainability in Recurrent Neural Networks
Change: 151339833
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py88
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py208
2 files changed, 296 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index ec0291cd7a..431065ef0b 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -756,6 +756,94 @@ class RNNCellTest(test.TestCase):
self.assertEqual(new_h.shape[1], num_proj)
self.assertAllClose(np.concatenate(res[1], axis=1), expected_state)
+ def testUGRNNCell(self):
+ num_units = 2
+ batch_size = 3
+ expected_state_and_output = np.array(
+ [[0.13752282, 0.13752282],
+ [0.10545051, 0.10545051],
+ [0.10074195, 0.10074195]],
+ dtype=np.float32)
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "ugrnn_cell_test",
+ initializer=init_ops.constant_initializer(0.5)):
+ cell = rnn_cell.UGRNNCell(num_units=num_units)
+ inputs = constant_op.constant(
+ np.array([[1., 1., 1., 1.],
+ [2., 2., 2., 2.],
+ [3., 3., 3., 3.]],
+ dtype=np.float32),
+ dtype=dtypes.float32)
+ init_state = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_units), dtype=np.float32),
+ dtype=dtypes.float32)
+ output, state = cell(inputs, init_state)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state])
+ # This is a smoke test: Only making sure expected values didn't change.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], expected_state_and_output)
+ self.assertAllClose(res[1], expected_state_and_output)
+
+ def testIntersectionRNNCell(self):
+ num_units = 2
+ batch_size = 3
+ expected_state = np.array(
+ [[0.13752282, 0.13752282],
+ [0.10545051, 0.10545051],
+ [0.10074195, 0.10074195]],
+ dtype=np.float32)
+ expected_output = np.array(
+ [[2.00431061, 2.00431061],
+ [4.00060606, 4.00060606],
+ [6.00008249, 6.00008249]],
+ dtype=np.float32)
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "intersection_rnn_cell_test",
+ initializer=init_ops.constant_initializer(0.5)):
+ cell = rnn_cell.IntersectionRNNCell(num_units=num_units,
+ num_in_proj=num_units)
+ inputs = constant_op.constant(
+ np.array([[1., 1., 1., 1.],
+ [2., 2., 2., 2.],
+ [3., 3., 3., 3.]],
+ dtype=np.float32),
+ dtype=dtypes.float32)
+ init_state = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_units), dtype=np.float32),
+ dtype=dtypes.float32)
+ output, state = cell(inputs, init_state)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state])
+ # This is a smoke test: Only making sure expected values didn't change.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], expected_output)
+ self.assertAllClose(res[1], expected_state)
+
+ def testIntersectionRNNCellFailure(self):
+ num_units = 2
+ batch_size = 3
+ cell = rnn_cell.IntersectionRNNCell(num_units=num_units)
+ inputs = constant_op.constant(
+ np.array([[1., 1., 1., 1.],
+ [2., 2., 2., 2.],
+ [3., 3., 3., 3.]],
+ dtype=np.float32),
+ dtype=dtypes.float32)
+ init_state = constant_op.constant(
+ 0.1 * np.ones(
+ (batch_size, num_units), dtype=np.float32),
+ dtype=dtypes.float32)
+ with self.assertRaisesRegexp(
+ ValueError, "Must have input size == output size for "
+ "Intersection RNN. To fix, num_in_proj should "
+ "be set to num_units at cell init."):
+ cell(inputs, init_state)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 5d435d8653..c447b52f66 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -1434,6 +1434,214 @@ class NASCell(core_rnn_cell.RNNCell):
return new_m, new_state
+class UGRNNCell(core_rnn_cell.RNNCell):
+ """Update Gate Recurrent Neural Network (UGRNN) cell.
+
+ Compromise between a LSTM/GRU and a vanilla RNN. There is only one
+ gate, and that is to determine whether the unit should be
+ integrating or computing instantaneously. This is the recurrent
+ idea of the feedforward Highway Network.
+
+ This implements the recurrent cell from the paper:
+
+ https://arxiv.org/abs/1611.09913
+
+ Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
+ "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
+ """
+
+ def __init__(self, num_units, initializer=None, forget_bias=1.0,
+ activation=math_ops.tanh, reuse=None):
+ """Initialize the parameters for an UGRNN cell.
+
+ Args:
+ num_units: int, The number of units in the UGRNN cell
+ initializer: (optional) The initializer to use for the weight matrices.
+ forget_bias: (optional) float, default 1.0, The initial bias of the
+ forget gate, used to reduce the scale of forgetting at the beginning
+ of the training.
+ activation: (optional) Activation function of the inner states.
+ Default is `tf.tanh`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ """
+ self._num_units = num_units
+ self._initializer = initializer
+ self._forget_bias = forget_bias
+ self._activation = activation
+ self._reuse = reuse
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Run one step of UGRNN.
+
+ Args:
+ inputs: input Tensor, 2D, batch x input size.
+ state: state Tensor, 2D, batch x num units.
+ scope: VariableScope for the created subgraph; defaults to "ugrnn_cell".
+
+ Returns:
+ new_output: batch x num units, Tensor representing the output of the UGRNN
+ after reading `inputs` when previous state was `state`. Identical to
+ `new_state`.
+ new_state: batch x num units, Tensor representing the state of the UGRNN
+ after reading `inputs` when previous state was `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ sigmoid = math_ops.sigmoid
+
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ with _checked_scope(self, scope or "ugrnn_cell",
+ initializer=self._initializer, reuse=self._reuse):
+ cell_inputs = array_ops.concat([inputs, state], 1)
+ rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True)
+
+ [g_act, c_act] = array_ops.split(
+ axis=1, num_or_size_splits=2, value=rnn_matrix)
+
+ c = self._activation(c_act)
+ g = sigmoid(g_act + self._forget_bias)
+ new_state = g * state + (1.0 - g) * c
+ new_output = new_state
+
+ return new_output, new_state
+
+
+class IntersectionRNNCell(core_rnn_cell.RNNCell):
+ """Intersection Recurrent Neural Network (+RNN) cell.
+
+ Architecture with coupled recurrent gate as well as coupled depth
+ gate, designed to improve information flow through stacked RNNs. As the
+ architecture uses depth gating, the dimensionality of the depth
+ output (y) also should not change through depth (input size == output size).
+ To achieve this, the first layer of a stacked Intersection RNN projects
+ the inputs to N (num units) dimensions. Therefore when initializing an
+ IntersectionRNNCell, one should set `num_in_proj = N` for the first layer
+ and use default settings for subsequent layers.
+
+ This implements the recurrent cell from the paper:
+
+ https://arxiv.org/abs/1611.09913
+
+ Jasmine Collins, Jascha Sohl-Dickstein, and David Sussillo.
+ "Capacity and Trainability in Recurrent Neural Networks" Proc. ICLR 2017.
+
+ The Intersection RNN is built for use in deeply stacked
+ RNNs so it may not achieve best performance with depth 1.
+ """
+
+ def __init__(self, num_units, num_in_proj=None,
+ initializer=None, forget_bias=1.0,
+ y_activation=nn_ops.relu, reuse=None):
+ """Initialize the parameters for an +RNN cell.
+
+ Args:
+ num_units: int, The number of units in the +RNN cell
+ num_in_proj: (optional) int, The input dimensionality for the RNN.
+ If creating the first layer of an +RNN, this should be set to
+ `num_units`. Otherwise, this should be set to `None` (default).
+ If `None`, dimensionality of `inputs` should be equal to `num_units`,
+ otherwise ValueError is thrown.
+ initializer: (optional) The initializer to use for the weight matrices.
+ forget_bias: (optional) float, default 1.0, The initial bias of the
+ forget gates, used to reduce the scale of forgetting at the beginning
+ of the training.
+ y_activation: (optional) Activation function of the states passed
+ through depth. Default is 'tf.nn.relu`.
+ reuse: (optional) Python boolean describing whether to reuse variables
+ in an existing scope. If not `True`, and the existing scope already has
+ the given variables, an error is raised.
+ """
+ self._num_units = num_units
+ self._initializer = initializer
+ self._forget_bias = forget_bias
+ self._num_input_proj = num_in_proj
+ self._y_activation = y_activation
+ self._reuse = reuse
+
+ @property
+ def state_size(self):
+ return self._num_units
+
+ @property
+ def output_size(self):
+ return self._num_units
+
+ def __call__(self, inputs, state, scope=None):
+ """Run one step of the Intersection RNN.
+
+ Args:
+ inputs: input Tensor, 2D, batch x input size.
+ state: state Tensor, 2D, batch x num units.
+ scope: VariableScope for the created subgraph; defaults to
+ "intersection_rnn_cell"
+
+ Returns:
+ new_y: batch x num units, Tensor representing the output of the +RNN
+ after reading `inputs` when previous state was `state`.
+ new_state: batch x num units, Tensor representing the state of the +RNN
+ after reading `inputs` when previous state was `state`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from `inputs` via
+ static shape inference.
+ ValueError: If input size != output size (these must be equal when
+ using the Intersection RNN).
+ """
+ sigmoid = math_ops.sigmoid
+ tanh = math_ops.tanh
+
+ input_size = inputs.get_shape().with_rank(2)[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ with _checked_scope(self, scope or "intersection_rnn_cell",
+ initializer=self._initializer, reuse=self._reuse):
+ # read-in projections (should be used for first layer in deep +RNN
+ # to transform size of inputs from I --> N)
+ if input_size.value != self._num_units:
+ if self._num_input_proj:
+ with vs.variable_scope("in_projection"):
+ inputs = _linear(inputs, self._num_units, True)
+ else:
+ raise ValueError("Must have input size == output size for "
+ "Intersection RNN. To fix, num_in_proj should "
+ "be set to num_units at cell init.")
+
+ n_dim = i_dim = self._num_units
+ cell_inputs = array_ops.concat([inputs, state], 1)
+ rnn_matrix = _linear(cell_inputs, 2*n_dim + 2*i_dim, True)
+
+ gh_act = rnn_matrix[:, :n_dim] # b x n
+ h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n
+ gy_act = rnn_matrix[:, 2*n_dim:2*n_dim+i_dim] # b x i
+ y_act = rnn_matrix[:, 2*n_dim+i_dim:2*n_dim+2*i_dim] # b x i
+
+ h = tanh(h_act)
+ y = self._y_activation(y_act)
+ gh = sigmoid(gh_act + self._forget_bias)
+ gy = sigmoid(gy_act + self._forget_bias)
+
+ new_state = gh * state + (1.0 - gh) * h # passed thru time
+ new_y = gy * inputs + (1.0 - gy) * y # passed thru depth
+
+ return new_y, new_state
+
+
_REGISTERED_OPS = None