author     2018-10-03 17:05:38 -0700
committer  2018-10-03 17:14:05 -0700
commit     c842d38978a0babb373fe2acbb0231960aa1c1d0
tree       578f68fa70f75d2ea2633b84777426b795d334cb /tensorflow/contrib
parent     d340eb9f7ea46012b7ead202f4c12fb6b32cc56d
Add MinimalRNN cell.
The implementation is based on: https://arxiv.org/pdf/1806.05394v2.pdf.
PiperOrigin-RevId: 215655857
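
For reference, the recurrence implemented by `MinimalRNNCell.call()` in the diff below can be written as follows, where `W_x` is the input-projection block of the packed kernel and `W_u` stands for the two gate blocks taken together (`W_u` is our notation for exposition; the code stores everything in one kernel variable):

    \Phi(x_t) = \tanh(W_x x_t)
    u_t = \sigma(W_u [\Phi(x_t); h_{t-1}] + b)
    h_t = u_t \odot h_{t-1} + (1 - u_t) \odot \Phi(x_t)

A single update gate u_t interpolates between the previous state and the projected input, which is what makes the cell "minimal" relative to GRU or LSTM.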
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py |  72
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py               | 116
2 files changed, 188 insertions(+), 0 deletions(-)
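
The cell packs the input projection and both gate blocks into one `[input_size + 2 * units, units]` kernel and splits it at call time. A minimal NumPy sketch of one step, mirroring the `call()` implementation in the diff below (all shapes and values here are made up for illustration):

    import numpy as np

    batch, input_size, units = 2, 3, 4
    kernel = np.random.rand(input_size + 2 * units, units)  # packed: W_x on top, gate blocks below
    bias = np.ones(units)                                   # ones init, the cell's default
    x = np.random.rand(batch, input_size)                   # inputs, [batch, input_size]
    h = np.zeros((batch, units))                            # previous state, [batch, units]

    w_x, w_gate = np.split(kernel, [input_size], axis=0)    # same split as call()
    z = np.tanh(x @ w_x)                                    # project input to hidden space
    u = 1.0 / (1.0 + np.exp(-(np.concatenate([z, h], 1) @ w_gate + bias)))  # update gate
    h_new = u * h + (1 - u) * z                             # weighted sum via the single gate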
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 6689664fb9..0a27200015 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -29,6 +29,9 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import random_seed
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras import utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
@@ -40,7 +43,9 @@ from tensorflow.python.ops import rnn_cell
 from tensorflow.python.ops import rnn_cell_impl
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
 from tensorflow.python.platform import test
+from tensorflow.python.training import training
 from tensorflow.python.util import nest
 
@@ -1115,6 +1120,73 @@ class RNNCellTest(test.TestCase):
           r"input size \(3\) must be divisible by number_of_groups \(2\)"):
         gcell(glstm_input, gcell_zero_state)
 
+  def testMinimalRNNCell(self):
+    with self.cached_session() as sess:
+      with variable_scope.variable_scope(
+          "root"):
+        x = array_ops.zeros([1, 2])
+        m = array_ops.zeros([1, 2])
+        cell = contrib_rnn_cell.MinimalRNNCell(
+            units=2,
+            kernel_initializer=initializers.Constant(0.5))
+        g, _ = cell(x, m)
+        sess.run([variables.global_variables_initializer()])
+        res = sess.run([g], {
+            x.name: np.array([[1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
+        # Smoke test
+        self.assertAllClose(res[0], [[0.18899589, 0.18899589]])
+      with variable_scope.variable_scope(
+          "other"):
+        # Test MinimalRNN with input_size != num_units.
+        x = array_ops.zeros([1, 3])
+        m = array_ops.zeros([1, 2])
+        cell = contrib_rnn_cell.MinimalRNNCell(
+            units=2,
+            kernel_initializer=initializers.Constant(0.5))
+        g, _ = cell(x, m)
+        sess.run([variables.global_variables_initializer()])
+        res = sess.run([g], {
+            x.name: np.array([[1., 1., 1.]]),
+            m.name: np.array([[0.1, 0.1]])
+        })
+        # Smoke test
+        self.assertAllClose(res[0], [[0.19554167, 0.19554167]])
+
+  def testMinimalRNNCellEndToEnd(self):
+    with self.cached_session() as sess:
+      input_shape = 10
+      output_shape = 5
+      timestep = 4
+      batch = 100
+      (x_train, y_train), _ = testing_utils.get_test_data(
+          train_samples=batch,
+          test_samples=0,
+          input_shape=(timestep, input_shape),
+          num_classes=output_shape)
+      y_train = utils.to_categorical(y_train)
+      cell = contrib_rnn_cell.MinimalRNNCell(output_shape)
+
+      inputs = array_ops.placeholder(
+          dtypes.float32, shape=(None, timestep, input_shape))
+      predict = array_ops.placeholder(
+          dtypes.float32, shape=(None, output_shape))
+
+      outputs, state = rnn.dynamic_rnn(
+          cell, inputs, dtype=dtypes.float32)
+      self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+      self.assertEqual(state.shape.as_list(), [None, output_shape])
+      loss = losses.softmax_cross_entropy(predict, state)
+      train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+      sess.run([variables.global_variables_initializer()])
+      _, outputs, state = sess.run(
+          [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+      self.assertEqual(len(outputs), batch)
+      self.assertEqual(len(state), batch)
+
 
 class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 06c481672c..59a61af7b3 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -28,6 +28,8 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import op_def_registry
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
 from tensorflow.python.layers import base as base_layer
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -3394,3 +3396,117 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
     new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
     return new_h, new_state
+
+
+class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
+  """MinimalRNN cell.
+
+  The implementation is based on:
+
+    https://arxiv.org/pdf/1806.05394v2.pdf
+
+  Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
+  "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
+  Propagation in Recurrent Neural Networks." ICML, 2018.
+
+  A MinimalRNN cell first projects the input to the hidden space. The new
+  hidden state is then calculated as a weighted sum of the projected input and
+  the previous hidden state, using a single update gate.
+  """
+
+  def __init__(self,
+               units,
+               activation="tanh",
+               kernel_initializer="glorot_uniform",
+               bias_initializer="ones",
+               name=None,
+               dtype=None,
+               **kwargs):
+    """Initialize the parameters for a MinimalRNN cell.
+
+    Args:
+      units: int, The number of units in the MinimalRNN cell.
+      activation: Nonlinearity to use in the feedforward network. Default:
+        `tanh`.
+      kernel_initializer: The initializer to use for the weight in the update
+        gate and feedforward network. Default: `glorot_uniform`.
+      bias_initializer: The initializer to use for the bias in the update
+        gate. Default: `ones`.
+      name: String, the name of the cell.
+      dtype: Default dtype of the cell.
+      **kwargs: Dict, keyword named properties for common cell attributes.
+    """
+    super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+    # Inputs must be 2-dimensional.
+    self.input_spec = base_layer.InputSpec(ndim=2)
+
+    self.units = units
+    self.activation = activations.get(activation)
+    self.kernel_initializer = initializers.get(kernel_initializer)
+    self.bias_initializer = initializers.get(bias_initializer)
+
+  @property
+  def state_size(self):
+    return self.units
+
+  @property
+  def output_size(self):
+    return self.units
+
+  def build(self, inputs_shape):
+    if inputs_shape[-1] is None:
+      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+                       % str(inputs_shape))
+
+    input_size = inputs_shape[-1]
+    # pylint: disable=protected-access
+    # self.kernel contains W_x, W, V
+    self.kernel = self.add_weight(
+        name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+        shape=[input_size + 2 * self.units, self.units],
+        initializer=self.kernel_initializer)
+    self.bias = self.add_weight(
+        name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+        shape=[self.units],
+        initializer=self.bias_initializer)
+    # pylint: enable=protected-access
+
+    self.built = True
+
+  def call(self, inputs, state):
+    """Run one step of MinimalRNN.
+
+    Args:
+      inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+      state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+    Returns:
+      A tuple containing:
+
+      - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+      - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+    Raises:
+      ValueError: If input size cannot be inferred from inputs via
+        static shape inference.
+    """
+    input_size = inputs.get_shape()[1]
+    if input_size.value is None:
+      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+    feedforward_weight, gate_weight = array_ops.split(
+        value=self.kernel,
+        num_or_size_splits=[input_size.value, 2 * self.units],
+        axis=0)
+
+    feedforward = math_ops.matmul(inputs, feedforward_weight)
+    feedforward = self.activation(feedforward)
+
+    gate_inputs = math_ops.matmul(
+        array_ops.concat([feedforward, state], 1), gate_weight)
+    gate_inputs = nn_ops.bias_add(gate_inputs, self.bias)
+    u = math_ops.sigmoid(gate_inputs)
+
+    new_h = u * state + (1 - u) * feedforward
+    return new_h, new_h
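
For anyone trying the new cell out, a minimal usage sketch under a TF 1.x environment with this change applied (shapes and values are illustrative, not from the commit):

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

    # Batch of 8 sequences, 4 timesteps, 10 features; 5 hidden units.
    inputs = tf.placeholder(tf.float32, shape=(None, 4, 10))
    cell = contrib_rnn_cell.MinimalRNNCell(units=5)

    # state_size == output_size == units, so outputs is [batch, time, 5]
    # and the final state is [batch, 5].
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      out, st = sess.run(
          [outputs, state],
          {inputs: np.random.rand(8, 4, 10).astype(np.float32)})
      print(out.shape, st.shape)  # (8, 4, 5) (8, 5)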