author    Po-Hsien Chu <stegben.benjamin@gmail.com>  2018-01-30 05:33:48 +0800
committer ebrevdo <ebrevdo@users.noreply.github.com>  2018-01-29 13:33:48 -0800
commit    ea8bd26b997ec75ca2b8eb48b2ffe4e3d0e7c855
tree      9c8b55542bded72ab7581cc04df8a27d93ee3d5c
parent    28c3c5dd38e3b397c2cf0acdaa6388dcbf0349f7
remove SRU num_units == x.shape[-1] restriction (#16404)
 tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 14 ++++++++
 tensorflow/contrib/rnn/python/ops/rnn_cell.py                    | 24 ++------
 2 files changed, 18 insertions(+), 20 deletions(-)
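What the patch does, in brief: the old SRUCell fed the raw input x_t straight into the highway term h_t = r_t * g(c_t) + (1 - r_t) * x_t, which forces input_depth == num_units. The patch instead learns a fourth projection of the input (x_tx below) inside the existing matmul, widening the kernel from 3 * num_units to 4 * num_units columns, so the highway term always has num_units columns. A minimal NumPy sketch of the patched forward step (shapes and names are illustrative, not the cell's actual API; tanh stands in for the cell's configurable activation):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

batch, input_depth, num_units = 1, 3, 2          # mismatched sizes, as in the new test
rng = np.random.RandomState(0)
kernel = rng.randn(input_depth, 4 * num_units)   # was 3 * num_units before the patch
bias = np.zeros(2 * num_units)                   # bias covers the f and r gates only

x = np.ones((batch, input_depth))
c_prev = np.full((batch, num_units), 0.1)

U = x @ kernel                                   # [batch, 4 * num_units]
x_bar, f_pre, r_pre, x_tx = np.split(U, 4, axis=1)
f, r = np.split(sigmoid(np.concatenate([f_pre, r_pre], axis=1) + bias), 2, axis=1)

c = f * c_prev + (1.0 - f) * x_bar               # cell state, [batch, num_units]
h = r * np.tanh(c) + (1.0 - r) * x_tx            # highway uses x_tx, not the raw x
```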
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index cafeb56ad8..5711f41cc3 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -153,6 +153,20 @@ class RNNCellTest(test.TestCase):
                            m.name: np.array([[0.1, 0.1]])})
       # Smoke test
       self.assertAllClose(res[0], [[0.509682, 0.509682]])
+
+  def testSRUCellWithDiffSize(self):
+    with self.test_session() as sess:
+      with variable_scope.variable_scope(
+          "root", initializer=init_ops.constant_initializer(0.5)):
+        x = array_ops.zeros([1, 3])
+        m = array_ops.zeros([1, 2])
+        g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
+        sess.run([variables_lib.global_variables_initializer()])
+        res = sess.run(
+            [g], {x.name: np.array([[1., 1., 1.]]),
+                  m.name: np.array([[0.1, 0.1]])})
+        # Smoke test
+        self.assertAllClose(res[0], [[0.55255556, 0.55255556]])
 
   def testBasicLSTMCell(self):
     for dtype in [dtypes.float16, dtypes.float32]:
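With the size check gone, the cell also composes with the stock TF 1.x RNN machinery when the input depth differs from num_units. A hedged usage sketch (the tf.contrib.rnn.SRUCell export path is an assumption based on where the class lives in this tree):

```python
import numpy as np
import tensorflow as tf

# input_depth=3 feeding num_units=2 -- this raised ValueError before the patch.
inputs = tf.placeholder(tf.float32, [None, 10, 3])   # [batch, time, input_depth]
cell = tf.contrib.rnn.SRUCell(num_units=2)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(outputs, {inputs: np.ones((4, 10, 3), np.float32)})
    print(out.shape)  # (4, 10, 2): outputs now carry num_units features per step
```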
diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
index 8adf5dce6e..5fee2e93e4 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py
@@ -2729,25 +2729,9 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
     input_depth = inputs_shape[1].value
 
-    # Here the contributor believes that the following constraints
-    # are implied. The reasoning is explained here with reference to
-    # the paper https://arxiv.org/pdf/1709.02755.pdf, upon which this
-    # implementation is based.
-    # In section 2.1, Equation 5, specifically:
-    #   h_t = r_t \odot g(c_t) + (1 - r_t) \odot x_t
-    # the pointwise operation between r_t and x_t means they have
-    # the same shape (since we are implementing an RNN cell, broadcasting
-    # does not happen to the input of a single timestep); for the same
-    # reason, x_t has the same shape as h_t, essentially mandating that
-    # input_depth == num_units.
-    if input_depth != self._num_units:
-      raise ValueError("SRU requires input_depth == num_units, got "
-                       "input_depth = %s, num_units = %s" % (input_depth,
-                                                             self._num_units))
-
     self._kernel = self.add_variable(
         rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
-        shape=[input_depth, 3 * self._num_units])
+        shape=[input_depth, 4 * self._num_units])
 
     self._bias = self.add_variable(
         rnn_cell_impl._BIAS_VARIABLE_NAME,
@@ -2760,8 +2744,8 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
"""Simple recurrent unit (SRU) with num_units cells."""
U = math_ops.matmul(inputs, self._kernel)
- x_bar, f_intermediate, r_intermediate = array_ops.split(
- value=U, num_or_size_splits=3, axis=1)
+ x_bar, f_intermediate, r_intermediate, x_tx = array_ops.split(
+ value=U, num_or_size_splits=4, axis=1)
f_r = math_ops.sigmoid(
nn_ops.bias_add(
@@ -2769,7 +2753,7 @@ class SRUCell(rnn_cell_impl._LayerRNNCell):
     f, r = array_ops.split(value=f_r, num_or_size_splits=2, axis=1)
 
     c = f * state + (1.0 - f) * x_bar
-    h = r * self._activation(c) + (1.0 - r) * inputs
+    h = r * self._activation(c) + (1.0 - r) * x_tx
 
     return h, c
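A consistency note on the 4-way split: the extra kernel block costs input_depth * num_units parameters, and when input_depth == num_units with the new projection pinned to the identity, the patched recurrence reduces exactly to the old h = r * g(c) + (1 - r) * x. A quick NumPy check of that reduction (illustrative only; sru_step is a hypothetical helper, and tanh stands in for the configurable activation):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def sru_step(x, c_prev, kernel, bias, old_highway=False):
    # One step with the patched 4-way kernel split.
    x_bar, f_pre, r_pre, x_tx = np.split(x @ kernel, 4, axis=1)
    f, r = np.split(sigmoid(np.concatenate([f_pre, r_pre], axis=1) + bias), 2, axis=1)
    c = f * c_prev + (1.0 - f) * x_bar
    if old_highway:
        x_tx = x  # pre-patch behaviour: highway term is the raw input
    return r * np.tanh(c) + (1.0 - r) * x_tx, c

n = 2                          # input_depth == num_units: the only case the old code allowed
rng = np.random.RandomState(0)
kernel = rng.randn(n, 4 * n)
kernel[:, 3 * n:] = np.eye(n)  # pin the highway projection to the identity
bias = rng.randn(2 * n)
x, c0 = rng.randn(1, n), rng.randn(1, n)

h_new, _ = sru_step(x, c0, kernel, bias)
h_old, _ = sru_step(x, c0, kernel, bias, old_highway=True)
assert np.allclose(h_new, h_old)  # identical when the projection is the identity
```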