aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py42
1 files changed, 0 insertions, 42 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 16b6d145e3..909c6aba2b 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -38,9 +38,6 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
-from tensorflow.python.framework import test_util
-from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell
-
# pylint: enable=protected-access
@@ -361,45 +358,6 @@ class RNNCellTest(test.TestCase):
self.assertEquals(variables[2].op.name,
"root/lstm_cell/projection/kernel")
- def testLSTMCellLayerNorm(self):
- with self.test_session() as sess:
- num_units = 2
- num_proj = 3
- batch_size = 1
- input_size = 4
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([batch_size, input_size])
- c = array_ops.zeros([batch_size, num_units])
- h = array_ops.zeros([batch_size, num_proj])
- state = rnn_cell_impl.LSTMStateTuple(c, h)
- cell = contrib_rnn_cell.LayerNormLSTMCell(
- num_units=num_units,
- num_proj=num_proj,
- forget_bias=1.0,
- layer_norm=True,
- norm_gain=1.0,
- norm_shift=0.0)
- g, out_m = cell(x, state)
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g, out_m], {
- x.name: np.ones((batch_size, input_size)),
- c.name: 0.1 * np.ones((batch_size, num_units)),
- h.name: 0.1 * np.ones((batch_size, num_proj))
- })
- self.assertEqual(len(res), 2)
- # The numbers in results were not calculated, this is mostly just a
- # smoke test.
- self.assertEqual(res[0].shape, (batch_size, num_proj))
- self.assertEqual(res[1][0].shape, (batch_size, num_units))
- self.assertEqual(res[1][1].shape, (batch_size, num_proj))
- # Different inputs so different outputs and states
- for i in range(1, batch_size):
- self.assertTrue(
- float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6)
- self.assertTrue(
- float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
-
def testOutputProjectionWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(