diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 16:29:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 16:35:34 -0700 |
commit | 4a00f2fc6514ad5ee60ab0a9645863fdf263499f (patch) | |
tree | e3ef5a887c31e9da97dd09339c34bcaa2cea75c1 /tensorflow/contrib/rnn | |
parent | 863f61412fcc654840c6b67473b742ea4e5e964e (diff) |
Add Chaos Free Network (CFN) cell.
The implementation is based on: https://openreview.net/pdf?id=S1dIzvclg.
PiperOrigin-RevId: 215824867
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 65 | ||||
-rw-r--r-- | tensorflow/contrib/rnn/python/ops/rnn_cell.py | 129 |
2 files changed, 194 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 0a27200015..aa1d7d2b01 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1120,6 +1120,71 @@ class RNNCellTest(test.TestCase): r"input size \(3\) must be divisible by number_of_groups \(2\)"): gcell(glstm_input, gcell_zero_state) + def testCFNCell(self): + with self.cached_session() as sess: + with variable_scope.variable_scope("root"): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.CFNCell( + units=2, + kernel_initializer=initializers.Constant(0.5)) + g, _ = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.17188203, 0.17188203]]) + with variable_scope.variable_scope("other"): + # Test CFN with input_size != num_units. 
+ x = array_ops.zeros([1, 3]) + m = array_ops.zeros([1, 2]) + cell = contrib_rnn_cell.CFNCell( + units=2, + kernel_initializer=initializers.Constant(0.5)) + g, _ = cell(x, m) + sess.run([variables.global_variables_initializer()]) + res = sess.run([g], { + x.name: np.array([[1., 1., 1.]]), + m.name: np.array([[0.1, 0.1]]) + }) + # Smoke test + self.assertAllClose(res[0], [[0.15535763, 0.15535763]]) + + def testCFNCellEndToEnd(self): + with self.cached_session() as sess: + input_shape = 10 + output_shape = 5 + timestep = 4 + batch = 100 + (x_train, y_train), _ = testing_utils.get_test_data( + train_samples=batch, + test_samples=0, + input_shape=(timestep, input_shape), + num_classes=output_shape) + y_train = utils.to_categorical(y_train) + cell = contrib_rnn_cell.CFNCell(output_shape) + + inputs = array_ops.placeholder( + dtypes.float32, shape=(None, timestep, input_shape)) + predict = array_ops.placeholder( + dtypes.float32, shape=(None, output_shape)) + + outputs, state = rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32) + self.assertEqual(outputs.shape.as_list(), [None, timestep, output_shape]) + self.assertEqual(state.shape.as_list(), [None, output_shape]) + loss = losses.softmax_cross_entropy(predict, state) + train_op = training.GradientDescentOptimizer(0.001).minimize(loss) + + sess.run([variables.global_variables_initializer()]) + _, outputs, state = sess.run( + [train_op, outputs, state], {inputs: x_train, predict: y_train}) + + self.assertEqual(len(outputs), batch) + self.assertEqual(len(state), batch) + def testMinimalRNNCell(self): with self.cached_session() as sess: with variable_scope.variable_scope( diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 59a61af7b3..78cea8feb4 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -3510,3 +3510,132 @@ class MinimalRNNCell(rnn_cell_impl.LayerRNNCell): new_h = u * state + (1 - u) * 
feedforward return new_h, new_h + + +class CFNCell(rnn_cell_impl.LayerRNNCell): + """Chaos Free Network cell. + + The implementation is based on: + + https://openreview.net/pdf?id=S1dIzvclg + + Thomas Laurent, James von Brecht. + "A recurrent neural network without chaos." ICLR, 2017. + + A CFN cell first projects the input to the hidden space. The hidden state + goes through a contractive mapping. The new hidden state is then calculated + as a linear combination of the projected input and the contracted previous + hidden state, using decoupled input and forget gates. + """ + + def __init__(self, + units, + activation="tanh", + kernel_initializer="glorot_uniform", + bias_initializer="ones", + name=None, + dtype=None, + **kwargs): + """Initialize the parameters for a CFN cell. + + Args: + units: int, The number of units in the CFN cell. + activation: Nonlinearity to use. Default: `tanh`. + kernel_initializer: Initializer for the `kernel` weights + matrix. Default: `glorot_uniform`. + bias_initializer: The initializer to use for the bias in the + gates. Default: `ones`. + name: String, the name of the cell. + dtype: Default dtype of the cell. + **kwargs: Dict, keyword named properties for common cell attributes. + """ + super(CFNCell, self).__init__(name=name, dtype=dtype, **kwargs) + + # Inputs must be 2-dimensional. + self.input_spec = base_layer.InputSpec(ndim=2) + + self.units = units + self.activation = activations.get(activation) + self.kernel_initializer = initializers.get(kernel_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + @property + def state_size(self): + return self.units + + @property + def output_size(self): + return self.units + + def build(self, inputs_shape): + if inputs_shape[-1] is None: + raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" + % str(inputs_shape)) + + input_size = inputs_shape[-1] + # pylint: disable=protected-access + # `self.kernel` contains V_{\theta}, V_{\eta}, W. 
+ # `self.recurrent_kernel` contains U_{\theta}, U_{\eta}. + # `self.bias` contains b_{\theta}, b_{\eta}. + self.kernel = self.add_weight( + shape=[input_size, 3 * self.units], + name=rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + initializer=self.kernel_initializer) + self.recurrent_kernel = self.add_weight( + shape=[self.units, 2 * self.units], + name="recurrent_%s" % rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + initializer=self.kernel_initializer) + self.bias = self.add_weight( + shape=[2 * self.units], + name=rnn_cell_impl._BIAS_VARIABLE_NAME, + initializer=self.bias_initializer) + # pylint: enable=protected-access + + self.built = True + + def call(self, inputs, state): + """Run one step of CFN. + + Args: + inputs: input Tensor, must be 2-D, `[batch, input_size]`. + state: state Tensor, must be 2-D, `[batch, state_size]`. + + Returns: + A tuple containing: + + - Output: A `2-D` tensor with shape `[batch_size, state_size]`. + - New state: A `2-D` tensor with shape `[batch_size, state_size]`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + input_size = inputs.get_shape()[-1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + # The variable names u, v, w, b are consistent with the notations in the + # original paper. + v, w = array_ops.split( + value=self.kernel, + num_or_size_splits=[2 * self.units, self.units], + axis=1) + u = self.recurrent_kernel + b = self.bias + + gates = math_ops.matmul(state, u) + math_ops.matmul(inputs, v) + gates = nn_ops.bias_add(gates, b) + gates = math_ops.sigmoid(gates) + theta, eta = array_ops.split(value=gates, + num_or_size_splits=2, + axis=1) + + proj_input = math_ops.matmul(inputs, w) + + # The input gate is (1 - eta), which is different from the original paper. + # This is for the purpose of initialization. With the default + # bias_initializer `ones`, the input gate is initialized to a small number. 
+ new_h = theta * self.activation(state) + (1 - eta) * self.activation( + proj_input) + + return new_h, new_h |