aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 63155faf1e..b5d81b7caa 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -140,6 +140,20 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.156736, 0.156736]])
+ def testSRUCell(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run(
+ [g], {x.name: np.array([[1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])})
+ # Smoke test
+ self.assertAllClose(res[0], [[0.509682, 0.509682]])
+
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype