diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 63155faf1e..b5d81b7caa 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -140,6 +140,20 @@ class RNNCellTest(test.TestCase): # Smoke test self.assertAllClose(res[0], [[0.156736, 0.156736]]) + def testSRUCell(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + m = array_ops.zeros([1, 2]) + g, _ = contrib_rnn_cell.SRUCell(2)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g], {x.name: np.array([[1., 1.]]), + m.name: np.array([[0.1, 0.1]])}) + # Smoke test + self.assertAllClose(res[0], [[0.509682, 0.509682]]) + def testBasicLSTMCell(self): for dtype in [dtypes.float16, dtypes.float32]: np_dtype = dtype.as_numpy_dtype |