aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py75
1 files changed, 55 insertions, 20 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index ebd4564f12..46823fa364 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
+from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -995,26 +996,19 @@ class RNNCellTest(test.TestCase):
output, state = cell(x, hidden)
sess.run([variables.global_variables_initializer()])
- res = sess.run([output, state], {
- hidden[0].name:
- np.array([[[[[1.],[1.]],
- [[1.],[1.]]],
- [[[1.],[1.]],
- [[1.],[1.]]]],
- [[[[2.],[2.]],
- [[2.],[2.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]]),
- x.name:
- np.array([[[[[1.],[1.]],
- [[1.],[1.]]],
- [[[1.],[1.]],
- [[1.],[1.]]]],
- [[[[2.],[2.]],
- [[2.],[2.]]],
- [[[2.],[2.]],
- [[2.],[2.]]]]])
- })
+ res = sess.run(
+ [output, state], {
+ hidden[0].name:
+ np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
+ 1.
+ ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]],
+ [[[2.], [2.]], [[2.], [2.]]]]]),
+ x.name:
+ np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[
+ 1.
+ ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]],
+ [[2.], [2.]]]]])
+ })
# This is a smoke test, making sure expected values are unchanged.
self.assertEqual(len(res), 2)
self.assertAllClose(res[0], res[1].h)
@@ -1275,6 +1269,47 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[2].c, expected_c1, 1e-5)
self.assertAllClose(res[2].h, expected_h1, 1e-5)
+ def testBasicLSTMCellWithStateTupleLayerNorm(self):
+ """The results of LSTMCell and LayerNormBasicLSTMCell should be the same."""
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 2])
+ c0 = array_ops.zeros([1, 2])
+ h0 = array_ops.zeros([1, 2])
+ state0 = rnn_cell_impl.LSTMStateTuple(c0, h0)
+ c1 = array_ops.zeros([1, 2])
+ h1 = array_ops.zeros([1, 2])
+ state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
+ cell = rnn_cell_impl.MultiRNNCell([
+ contrib_rnn_cell.LayerNormLSTMCell(
+ 2, layer_norm=True, norm_gain=1.0, norm_shift=0.0)
+ for _ in range(2)
+ ])
+ h, (s0, s1) = cell(x, (state0, state1))
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run(
+ [h, s0, s1], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ c1.name: 0.1 * np.asarray([[4, 5]]),
+ h1.name: 0.1 * np.asarray([[6, 7]]),
+ })
+
+ expected_h = np.array([[-0.38079708, 0.38079708]])
+ expected_h0 = np.array([[-0.38079708, 0.38079708]])
+ expected_c0 = np.array([[-1.0, 1.0]])
+ expected_h1 = np.array([[-0.38079708, 0.38079708]])
+ expected_c1 = np.array([[-1.0, 1.0]])
+
+ self.assertEqual(len(res), 3)
+ self.assertAllClose(res[0], expected_h, 1e-5)
+ self.assertAllClose(res[1].c, expected_c0, 1e-5)
+ self.assertAllClose(res[1].h, expected_h0, 1e-5)
+ self.assertAllClose(res[2].c, expected_c1, 1e-5)
+ self.assertAllClose(res[2].h, expected_h1, 1e-5)
+
def testBasicLSTMCellWithDropout(self):
def _is_close(x, y, digits=4):