diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 75 |
1 files changed, 55 insertions, 20 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index ebd4564f12..46823fa364 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -995,26 +996,19 @@ class RNNCellTest(test.TestCase): output, state = cell(x, hidden) sess.run([variables.global_variables_initializer()]) - res = sess.run([output, state], { - hidden[0].name: - np.array([[[[[1.],[1.]], - [[1.],[1.]]], - [[[1.],[1.]], - [[1.],[1.]]]], - [[[[2.],[2.]], - [[2.],[2.]]], - [[[2.],[2.]], - [[2.],[2.]]]]]), - x.name: - np.array([[[[[1.],[1.]], - [[1.],[1.]]], - [[[1.],[1.]], - [[1.],[1.]]]], - [[[[2.],[2.]], - [[2.],[2.]]], - [[[2.],[2.]], - [[2.],[2.]]]]]) - }) + res = sess.run( + [output, state], { + hidden[0].name: + np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[ + 1. + ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], + [[[2.], [2.]], [[2.], [2.]]]]]), + x.name: + np.array([[[[[1.], [1.]], [[1.], [1.]]], [[[1.], [1.]], [[ + 1. + ], [1.]]]], [[[[2.], [2.]], [[2.], [2.]]], [[[2.], [2.]], + [[2.], [2.]]]]]) + }) # This is a smoke test, making sure expected values are unchanged. self.assertEqual(len(res), 2) self.assertAllClose(res[0], res[1].h) @@ -1275,6 +1269,47 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[2].c, expected_c1, 1e-5) self.assertAllClose(res[2].h, expected_h1, 1e-5) + def testBasicLSTMCellWithStateTupleLayerNorm(self): + """The results of LSTMCell and LayerNormBasicLSTMCell should be the same.""" + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + c0 = array_ops.zeros([1, 2]) + h0 = array_ops.zeros([1, 2]) + state0 = rnn_cell_impl.LSTMStateTuple(c0, h0) + c1 = array_ops.zeros([1, 2]) + h1 = array_ops.zeros([1, 2]) + state1 = rnn_cell_impl.LSTMStateTuple(c1, h1) + cell = rnn_cell_impl.MultiRNNCell([ + contrib_rnn_cell.LayerNormLSTMCell( + 2, layer_norm=True, norm_gain=1.0, norm_shift=0.0) + for _ in range(2) + ]) + h, (s0, s1) = cell(x, (state0, state1)) + sess.run([variables.global_variables_initializer()]) + res = sess.run( + [h, s0, s1], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + c1.name: 0.1 * np.asarray([[4, 5]]), + h1.name: 0.1 * np.asarray([[6, 7]]), + }) + + expected_h = np.array([[-0.38079708, 0.38079708]]) + expected_h0 = np.array([[-0.38079708, 0.38079708]]) + expected_c0 = np.array([[-1.0, 1.0]]) + expected_h1 = np.array([[-0.38079708, 0.38079708]]) + expected_c1 = np.array([[-1.0, 1.0]]) + + self.assertEqual(len(res), 3) + self.assertAllClose(res[0], expected_h, 1e-5) + self.assertAllClose(res[1].c, expected_c0, 1e-5) + self.assertAllClose(res[1].h, expected_h0, 1e-5) + self.assertAllClose(res[2].c, expected_c1, 1e-5) + self.assertAllClose(res[2].h, expected_h1, 1e-5) + def testBasicLSTMCellWithDropout(self): def _is_close(x, y, digits=4): |