From ea8bd26b997ec75ca2b8eb48b2ffe4e3d0e7c855 Mon Sep 17 00:00:00 2001
From: Po-Hsien Chu
Date: Tue, 30 Jan 2018 05:33:48 +0800
Subject: remove SRU num_units == x.shape[-1] restriction (#16404)

---
 .../rnn/python/kernel_tests/core_rnn_cell_test.py | 14 +++++++++++++
 tensorflow/contrib/rnn/python/ops/rnn_cell.py     | 24 ++++------------------
 2 files changed, 18 insertions(+), 20 deletions(-)

diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index cafeb56ad8..5711f41cc3 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -153,6 +153,20 @@ class RNNCellTest(test.TestCase):
                   m.name: np.array([[0.1, 0.1]])})
         # Smoke test
         self.assertAllClose(res[0], [[0.509682, 0.509682]])
+
+  def testSRUCellWithDiffSize(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root", initializer=init_ops.constant_initializer(0.5)):
+        x = array_ops.zeros([1, 3])
+        m = array_ops.zeros([1, 2])
+        g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
+        sess.run([variables_lib.global_variables_initializer()])
+        res = sess.run(
+            [g], {x.name: np.array([[1., 1., 1.]]),
+                  m.name: np.array([[0.1, 0.1]])})
+        # Smoke test
+        self.assertAllClose(res[0], [[0.55255556, 0.55255556]])
 
   def testBasicLSTMCell(self):
     for dtype in [dtypes.float16, dtypes.float32]:
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 8adf5dce6e..5fee2e93e4 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2729,25 +2729,9 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
 
     input_depth = inputs_shape[1].value
 
-    # Here the contributor believes that the following constraints
-    # are implied. The reasoning is explained here with reference to
-    # the paper https://arxiv.org/pdf/1709.02755.pdf upon which this
-    # implementation is based.
-    # In section 2.1 Equation 5, specifically:
-    # h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t
-    # the pointwise operation between r_t and x_t means they have
-    # the same shape (since we are implementing an RNN cell, braodcasting
-    # does not happen to input of a single timestep); by the same
-    # reasons, x_t has the same shape as h_t, essentially mandating that
-    # input_depth = unit_num.
-    if input_depth != self._num_units:
-      raise ValueError("SRU requires input_depth == num_units, got "
-                       "input_depth = %s, num_units = %s" % (input_depth,
-                                                             self._num_units))
-
     self._kernel = self.add_variable(
         rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-        shape=[input_depth, 3 * self._num_units])
+        shape=[input_depth, 4 * self._num_units])
 
     self._bias = self.add_variable(
         rnn_cell_impl._BIAS_VARIABLE_NAME,
@@ -2760,8 +2744,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
     """Simple recurrent unit (SRU) with num_units cells."""
 
     U = math_ops.matmul(inputs, self._kernel)
-    x_bar, f_intermediate, r_intermediate = array_ops.split(
-        value=U, num_or_size_splits=3, axis=1)
+    x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
+        value=U, num_or_size_splits=4, axis=1)
 
     f_r = math_ops.sigmoid(
         nn_ops.bias_add(
@@ -2769,7 +2753,7 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
     f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
 
     c = f * state + (1.0 - f) * x_bar
-    h = r * self._activation(c) + (1.0 - r) * inputs
+    h = r * self._activation(c) + (1.0 - r) * x_tx
 
     return h, c
-- 
cgit v1.2.3
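
What the change does, for readers skimming the diff: the deleted comment argued from Equation 5 of the SRU paper (https://arxiv.org/pdf/1709.02755.pdf), h_t = r_t (.) g(c_t) + (1 - r_t) (.) x_t, that the pointwise highway term forces input_depth == num_units. The patch lifts that restriction by widening the kernel from 3 to 4 blocks per unit and routing the input through the fourth block (x_tx), so the highway term operates on a learned num_units-wide projection of x_t rather than on x_t itself. Below is a minimal standalone NumPy sketch of the patched forward step; it is illustrative only (the helper names sru_step and sigmoid and the random weights are made up here), not the TensorFlow implementation shown above.

# A minimal NumPy sketch of the patched SRU step (illustrative, not the
# code in tensorflow/contrib/rnn/python/ops/rnn_cell.py).
import numpy as np

def sigmoid(z):
  return 1.0 / (1.0 + np.exp(-z))

def sru_step(x, c_prev, kernel, bias, activation=np.tanh):
  # x: [batch, input_depth], c_prev: [batch, num_units].
  # kernel: [input_depth, 4 * num_units]; the fourth block (x_tx) projects
  # the input to num_units width, which is what removes the old
  # input_depth == num_units requirement.
  # bias: [2 * num_units], added to the f and r pre-activations only,
  # mirroring the diff's bias_add over the f/r intermediates.
  num_units = c_prev.shape[1]
  U = x @ kernel
  x_bar, f_int, r_int, x_tx = np.split(U, 4, axis=1)
  f = sigmoid(f_int + bias[:num_units])
  r = sigmoid(r_int + bias[num_units:])
  c = f * c_prev + (1.0 - f) * x_bar        # internal state (paper eq. 4)
  h = r * activation(c) + (1.0 - r) * x_tx  # highway now uses the projection
  return h, c

# input_depth=3 with num_units=2: raised ValueError before this patch.
rng = np.random.RandomState(0)
h, c = sru_step(rng.randn(1, 3), np.zeros((1, 2)),
                kernel=rng.randn(3, 8), bias=np.zeros(4))
print(h.shape, c.shape)  # (1, 2) (1, 2)

Note the design choice visible in the diff: only the f and r pre-activations receive a bias, so the kernel grows to 4 * num_units while the bias stays untouched by this change.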