aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-02-07 14:36:00 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-07 14:39:49 -0800
commitd90054e7c0f41f4bab81df0548577a73b939a87a (patch)
treea15aea686a9d3f305e316d2a6ada0859ad8170d1 /tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
parent8461760f9f6cde8ed97507484d2a879140141032 (diff)
Merge changes from github.
PiperOrigin-RevId: 184897758
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 09527e8473..0e62b315b6 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -157,6 +157,21 @@ class RNNCellTest(test.TestCase):
# Smoke test
self.assertAllClose(res[0], [[0.509682, 0.509682]])
+ def testSRUCellWithDiffSize(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 3])
+ m = array_ops.zeros([1, 2])
+ g, _ = contrib_rnn_cell.SRUCell(2)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g], {
+ x.name: np.array([[1., 1., 1.]]),
+ m.name: np.array([[0.1, 0.1]])
+ })
+ # Smoke test
+ self.assertAllClose(res[0], [[0.55255556, 0.55255556]])
+
def testBasicLSTMCell(self):
for dtype in [dtypes.float16, dtypes.float32]:
np_dtype = dtype.as_numpy_dtype