author    A. Unique TensorFlower <gardener@tensorflow.org> 2018-10-04 16:29:47 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-04 16:35:34 -0700
commit    4a00f2fc6514ad5ee60ab0a9645863fdf263499f (patch)
tree      e3ef5a887c31e9da97dd09339c34bcaa2cea75c1 /tensorflow/contrib
parent    863f61412fcc654840c6b67473b742ea4e5e964e (diff)
Add Chaos Free Network (CFN) cell.
The implementation is based on: https://openreview.net/pdf?id=S1dIzvclg

PiperOrigin-RevId: 215824867
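For context, the CFN recurrence from the paper, written in the notation the implementation below uses in its comments (note that the code uses (1 - \eta_t) as the input gate rather than the paper's \eta_t; see the comment in `call`):

    h_t = \theta_t \odot \tanh(h_{t-1}) + (1 - \eta_t) \odot \tanh(W x_t)
    \theta_t = \sigma(U_{\theta} h_{t-1} + V_{\theta} x_t + b_{\theta})
    \eta_t = \sigma(U_{\eta} h_{t-1} + V_{\eta} x_t + b_{\eta})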
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--  tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py |  65
-rw-r--r--  tensorflow/contrib/rnn/python/ops/rnn_cell.py               | 129
2 files changed, 194 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 0a27200015..aa1d7d2b01 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1120,6 +1120,71 @@ class RNNCellTest(test.TestCase):
r"input size \(3\) must be divisible by number_of_groups \(2\)"):
gcell(glstm_input, gcell_zero_state)
+ def testCFNCell(self):
+ with self.cached_session() as sess:
+ with variable_scope.variable_scope("root"):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.17188203, 0.17188203]])
+ with variable_scope.variable_scope("other"):
+ # Test CFN with input_size != num_units.
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ cell = contrib_rnn_cell.CFNCell(
+ units=2,
+ kernel_initializer=initializers.Constant(0.5))
+ g, _ = cell(x, m)
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.15535763, 0.15535763]])
+
+ def testCFNCellEndToEnd(self):
+ with self.cached_session() as sess:
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 100
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ y_train = utils.to_categorical(y_train)
+ cell = contrib_rnn_cell.CFNCell(output_shape)
+
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ predict = array_ops.placeholder(
+ dtypes.float32, shape=(None, output_shape))
+
+ outputs, state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape])
+ self.assertEqual(state.shape.as_list(), [None, output_shape])
+ loss = losses.softmax_cross_entropy(predict, state)
+ train_op = training.GradientDescentOptimizer(0.001).minimize(loss)
+
+ sess.run([variables.global_variables_initializer()])
+ _, outputs, state = sess.run(
+ [train_op, outputs, state], {inputs: x_train, predict: y_train})
+
+ self.assertEqual(len(outputs), batch)
+ self.assertEqual(len(state), batch)
+
def testMinimalRNNCell(self):
with self.cached_session() as sess:
with variable_scope.variable_scope(
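Outside the test harness, the cell plugs into the standard TF 1.x RNN machinery the same way the end-to-end test above does. A minimal usage sketch (shapes are illustrative; the import path is the one the test module uses):

    import tensorflow as tf
    from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell

    # Batch of sequences: [batch, time, features]; shapes are illustrative.
    inputs = tf.placeholder(tf.float32, shape=(None, 4, 10))
    cell = contrib_rnn_cell.CFNCell(units=5)
    # outputs: [batch, 4, 5]; state (final hidden state): [batch, 5].
    outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)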
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 59a61af7b3..78cea8feb4 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -3510,3 +3510,132 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell):
new_h = u * state + (1 - u) * feedforward
return new_h, new_h
+
+
+class CFNCell(rnn_cell_impl.LayerRNNCell):
+ """Chaos Free Network cell.
+
+ The implementation is based on:
+
+ https://openreview.net/pdf?id=S1dIzvclg
+
+ Thomas Laurent, James von Brecht.
+ "A recurrent neural network without chaos." ICLR, 2017.
+
+ A CFN cell first projects the input to the hidden space. The hidden state
+ goes through a contractive mapping. The new hidden state is then calculated
+ as a linear combination of the projected input and the contracted previous
+ hidden state, using decoupled input and forget gates.
+ """
+
+ def __init__(self,
+ units,
+ activation="tanh",
+ kernel_initializer="glorot_uniform",
+ bias_initializer="ones",
+ name=None,
+ dtype=None,
+ **kwargs):
+ """Initialize the parameters for a CFN cell.
+
+ Args:
+ units: int, The number of units in the CFN cell.
+ activation: Nonlinearity to use. Default: `tanh`.
+ kernel_initializer: Initializer for the `kernel` weights
+ matrix. Default: `glorot_uniform`.
+ bias_initializer: The initializer to use for the bias in the
+ gates. Default: `ones`.
+ name: String, the name of the cell.
+ dtype: Default dtype of the cell.
+ **kwargs: Dict, keyword named properties for common cell attributes.
+ """
+ super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs)
+
+ # Inputs must be 2-dimensional.
+ self.input_spec = base_layer.InputSpec(ndim=2)
+
+ self.units = units
+ self.activation = activations.get(activation)
+ self.kernel_initializer = initializers.get(kernel_initializer)
+ self.bias_initializer = initializers.get(bias_initializer)
+
+ @property
+ def state_size(self):
+ return self.units
+
+ @property
+ def output_size(self):
+ return self.units
+
+ def build(self, inputs_shape):
+ if inputs_shape[-1] is None:
+ raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
+ % str(inputs_shape))
+
+ input_size = inputs_shape[-1]
+ # pylint: disable=protected-access
+ # `self.kernel` contains V_{\theta}, V_{\eta}, W.
+ # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}.
+ # `self.bias` contains b_{\theta}, b_{\eta}.
+ self.kernel = self.add_weight(
+ shape=[input_size, 3 * self.units],
+ name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.recurrent_kernel = self.add_weight(
+ shape=[self.units, 2 * self.units],
+ name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
+ initializer=self.kernel_initializer)
+ self.bias = self.add_weight(
+ shape=[2 * self.units],
+ name=rnn_cell_impl._BIAS_VARIABLE_NAME,
+ initializer=self.bias_initializer)
+ # pylint: enable=protected-access
+
+ self.built = True
+
+ def call(self, inputs, state):
+ """Run one step of CFN.
+
+ Args:
+ inputs: input Tensor, must be 2-D, `[batch, input_size]`.
+ state: state Tensor, must be 2-D, `[batch, state_size]`.
+
+ Returns:
+ A tuple containing:
+
+ - Output: A `2-D` tensor with shape `[batch_size, state_size]`.
+ - New state: A `2-D` tensor with shape `[batch_size, state_size]`.
+
+ Raises:
+ ValueError: If input size cannot be inferred from inputs via
+ static shape inference.
+ """
+ input_size = inputs.get_shape()[-1]
+ if input_size.value is None:
+ raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
+
+ # The variable names u, v, w, b are consistent with the notations in the
+ # original paper.
+ v, w = array_ops.split(
+ value=self.kernel,
+ num_or_size_splits=[2 * self.units, self.units],
+ axis=1)
+ u = self.recurrent_kernel
+ b = self.bias
+
+ gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v)
+ gates = nn_ops.bias_add(gates, b)
+ gates = math_ops.sigmoid(gates)
+ theta, eta = array_ops.split(value=gates,
+ num_or_size_splits=2,
+ axis=1)
+
+ proj_input = math_ops.matmul(inputs, w)
+
+ # The input gate is (1 - eta), which is different from the original paper.
+ # This is for the purpose of initialization. With the default
+ # bias_initializer `ones`, the input gate is initialized to a small number.
+ new_h = theta * self.activation(state) + (1 - eta) * self.activation(
+ proj_input)
+
+ return new_h, new_h
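To summarize the data flow of `call`, a minimal NumPy sketch of one CFN step (not part of the commit; `sigmoid` and `cfn_step` are illustrative helpers, and the weight shapes follow `build` above):

    import numpy as np

    def sigmoid(z):
      return 1.0 / (1.0 + np.exp(-z))

    def cfn_step(x, h, kernel, recurrent_kernel, bias, units):
      """One CFN step. x: [batch, input_size], h: [batch, units]."""
      # `kernel` packs [V_{\theta}, V_{\eta}, W]; split off W.
      v, w = kernel[:, :2 * units], kernel[:, 2 * units:]
      # Decoupled gates, computed from the previous state and the input.
      gates = sigmoid(h @ recurrent_kernel + x @ v + bias)
      theta, eta = gates[:, :units], gates[:, units:]
      # Contracted previous state plus gated projected input; the
      # (1 - eta) input gate matches the comment in `call`.
      return theta * np.tanh(h) + (1 - eta) * np.tanh(x @ w)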