aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py42
1 files changed, 42 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 909c6aba2b..16b6d145e3 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -38,6 +38,9 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
+from tensorflow.python.framework import test_util
+from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
+
# pylint: enable=protected-access
@@ -358,6 +361,45 @@ class RNNCellTest(test.TestCase):
self.assertEquals(variables[2].op.name,
"root/lstm_cell/projection/kernel")
+ def testLSTMCellLayerNorm(self):
+ with self.test_session() as sess:
+ num_units = 2
+ num_proj = 3
+ batch_size = 1
+ input_size = 4
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([batch_size, input_size])
+ c = array_ops.zeros([batch_size, num_units])
+ h = array_ops.zeros([batch_size, num_proj])
+ state = rnn_cell_impl.LSTMStateTuple(c, h)
+ cell = contrib_rnn_cell.LayerNormLSTMCell(
+ num_units=num_units,
+ num_proj=num_proj,
+ forget_bias=1.0,
+ layer_norm=True,
+ norm_gain=1.0,
+ norm_shift=0.0)
+ g, out_m = cell(x, state)
+ sess.run([variables_lib.global_variables_initializer()])
+ res = sess.run([g, out_m], {
+ x.name: np.ones((batch_size, input_size)),
+ c.name: 0.1 * np.ones((batch_size, num_units)),
+ h.name: 0.1 * np.ones((batch_size, num_proj))
+ })
+ self.assertEqual(len(res), 2)
+ # The numbers in results were not calculated, this is mostly just a
+ # smoke test.
+ self.assertEqual(res[0].shape, (batch_size, num_proj))
+ self.assertEqual(res[1][0].shape, (batch_size, num_units))
+ self.assertEqual(res[1][1].shape, (batch_size, num_proj))
+ # Different inputs so different outputs and states
+ for i in range(1, batch_size):
+ self.assertTrue(
+ float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
+ self.assertTrue(
+ float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
+
def testOutputProjectionWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(