aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py44
1 files changed, 0 insertions, 44 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index b4a5f2d7eb..ebd4564f12 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -37,7 +37,6 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import rnn_cell
-from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -1276,49 +1275,6 @@ class LayerNormBasicLSTMCellTest(test.TestCase):
self.assertAllClose(res[2].c, expected_c1, 1e-5)
self.assertAllClose(res[2].h, expected_h1, 1e-5)
-
- def testBasicLSTMCellWithStateTupleLayerNorm(self):
- """The results of LSTMCell and LayerNormBasicLSTMCell
- should be same. """
- with self.test_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 2])
- c0 = array_ops.zeros([1, 2])
- h0 = array_ops.zeros([1, 2])
- state0 = rnn_cell_impl.LSTMStateTuple(c0, h0)
- c1 = array_ops.zeros([1, 2])
- h1 = array_ops.zeros([1, 2])
- state1 = rnn_cell_impl.LSTMStateTuple(c1, h1)
- cell = rnn_cell_impl.MultiRNNCell(
- [contrib_rnn_cell.LayerNormLSTMCell(
- 2,
- layer_norm=True,
- norm_gain=1.0,
- norm_shift=0.0) for _ in range(2)])
- h, (s0, s1) = cell(x, (state0, state1))
- sess.run([variables.global_variables_initializer()])
- res = sess.run([h, s0, s1], {
- x.name: np.array([[1., 1.]]),
- c0.name: 0.1 * np.asarray([[0, 1]]),
- h0.name: 0.1 * np.asarray([[2, 3]]),
- c1.name: 0.1 * np.asarray([[4, 5]]),
- h1.name: 0.1 * np.asarray([[6, 7]]),
- })
-
- expected_h = np.array([[-0.38079708, 0.38079708]])
- expected_h0 = np.array([[-0.38079708, 0.38079708]])
- expected_c0 = np.array([[-1.0, 1.0]])
- expected_h1 = np.array([[-0.38079708, 0.38079708]])
- expected_c1 = np.array([[-1.0, 1.0]])
-
- self.assertEqual(len(res), 3)
- self.assertAllClose(res[0], expected_h, 1e-5)
- self.assertAllClose(res[1].c, expected_c0, 1e-5)
- self.assertAllClose(res[1].h, expected_h0, 1e-5)
- self.assertAllClose(res[2].c, expected_c1, 1e-5)
- self.assertAllClose(res[2].h, expected_h1, 1e-5)
-
def testBasicLSTMCellWithDropout(self):
def _is_close(x, y, digits=4):