aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 17:05:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 17:14:05 -0700
commitc842d38978a0babb373fe2acbb0231960aa1c1d0 (patch)
tree578f68fa70f75d2ea2633b84777426b795d334cb /tensorflow/contrib/rnn
parentd340eb9f7ea46012b7ead202f4c12fb6b32cc56d (diff)
Add MinimalRNN cell.
The implementation is based on: https://arxiv.org/pdf/1806.05394v2.pdf. PiperOrigin-RevId: 215655857
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py72
-rw-r--r--tensorflow/contrib/rnn/python/ops/rnn_cell.py116
2 files changed, 188 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 6689664fb9..0a27200015 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -29,6 +29,9 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.keras import initializers
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.keras import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -40,7 +43,9 @@ from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
+from tensorflow.python.training import training
from tensorflow.python.util import nest
@@ -1115,6 +1120,73 @@ class RNNCellTest(test.TestCase):
r"input size \(3\) must be divisible by number_of_groups \(2\)"):
gcell(glstm_input, gcell_zero_state)
+ def testMinimalRNNCell(self):
+ with self.cached_session() as sess:
+ with variable_scope.variable_scope(
+ "root"):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.MinimalRNNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.18899589, 0.18899589]])
+ with variable_scope.variable_scope(
+ "other"):
+ # Test MinimalRNN with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.MinimalRNNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.19554167, 0.19554167]])
+
+ def testMinimalRNNCellEndToEnd(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = utils.to_categorical(y_train)
+ cell = contrib_rnn_cell.MinimalRNNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
class LayerNormBasicLSTMCellTest(test.TestCase):
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 06c481672c..59a61af7b3 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -28,6 +28,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras import activations
+from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -3394,3 +3396,117 @@ class IndyLSTMCell(rnn_cell_impl.LayerRNNCell):
new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h)
return new_h, new_state
+
+
+class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
+ """MinimalRNN cell.
+
+ The implementation is based on:
+
+ https://arxiv.org/pdf/1806.05394v2.pdf
+
+ Minmin Chen, Jeffrey Pennington, Samuel S. Schoenholz.
+ "Dynamical Isometry and a Mean Field Theory of RNNs: Gating Enables Signal
+ Propagation in Recurrent Neural Networks." ICML, 2018.
+
+ A MinimalRNN cell first projects the input to the hidden space. The new
+ hidden state is then calcuated as a weighted sum of the projected input and
+ the previous hidden state, using a single update gate.
+ """
+
+ def __init__(self,
+ units,
+ activation="tanh",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="ones",
+ name=None,
+ dtype=None,
+ **kwargs):
+ """Initialize the parameters for a MinimalRNN cell.
+
+ Args:
+ units: int, The number of units in the MinimalRNN cell.
+ activation: Nonlinearity to use in the feedforward network. Default:
+ `tanh`.
+ kernel_initializer: The initializer to use for the weight in the update
+ gate and feedforward network. Default: `glorot_uniform`.
+ bias_initializer: The initializer to use for the bias in the update
+ gate. Default: `ones`.
+ name: String, the name of the cell.
+ dtype: Default dtype of the cell.
+ **kwargs: Dict, keyword named properties for common cell attributes.
+ """
+ super(MinimalRNNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self.units = units
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ @property
+ def state_size(self):
+ return self.units
+
+ @property
+ def output_size(self):
+ return self.units
+
+ def build(self, inputs_shape):
+ if inputs_shape[-1] is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % str(inputs_shape))
+
+ input_size = inputs_shape[-1]
+ # pylint: disable=protected-access
+ # self._kernel contains W_x, W, V
+ self.kernel = self.add_weight(
+ name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ shape=[input_size + 2 * self.units, self.units],
+ initializer=self.kernel_initializer)
+ self.bias = self.add_weight(
+ name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+ shape=[self.units],
+ initializer=self.bias_initializer)
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of MinimalRNN.
+
+ Args:
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+ state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+ Returns:
+ A tuple containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+ - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ input_size = inputs.get_shape()[1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ feedforward_weight, gate_weight = array_ops.split(
+ value=self.kernel,
+ num_or_size_splits=[input_size.value, 2 * self.units],
+ axis=0)
+
+ feedforward = math_ops.matmul(inputs, feedforward_weight)
+ feedforward = self.activation(feedforward)
+
+ gate_inputs = math_ops.matmul(
+ array_ops.concat([feedforward, state], 1), gate_weight)
+ gate_inputs = nn_ops.bias_add(gate_inputs, self.bias)
+ u = math_ops.sigmoid(gate_inputs)
+
+ new_h = u * state + (1 - u) * feedforward
+ return new_h, new_h