aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py95
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 73789206f3..70aaba1728 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -1549,5 +1549,100 @@ class BenchmarkLSTMCellXLA(test.Benchmark):
benchmark_results["wall_time"]]]))
+class WeightNormLSTMCellTest(test.TestCase):
+ """Compared cell output with pre-calculated values."""
+
+ def _cell_output(self, cell):
+ """Calculate cell output"""
+
+ with self.test_session() as sess:
+ init = init_ops.constant_initializer(0.5)
+ with variable_scope.variable_scope("root",
+ initializer=init):
+ x = array_ops.zeros([1, 2])
+ c0 = array_ops.zeros([1, 2])
+ h0 = array_ops.zeros([1, 2])
+
+ state0 = rnn_cell.LSTMStateTuple(c0, h0)
+
+ xout, sout = cell()(x, state0)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([xout, sout], {
+ x.name: np.array([[1., 1.]]),
+ c0.name: 0.1 * np.asarray([[0, 1]]),
+ h0.name: 0.1 * np.asarray([[2, 3]]),
+ })
+
+ actual_state_c = res[1].c
+ actual_state_h = res[1].h
+
+ return actual_state_c, actual_state_h
+
+ def testBasicCell(self):
+ """Tests cell w/o peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937078, 0.74983585]])
+ expected_h = np.array([[0.44923624, 0.49362513]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonbasicCell(self):
+ """Tests cell with peepholes and w/o normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=False,
+ use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.65937084, 0.7574988]])
+ expected_h = np.array([[0.4792085, 0.53470564]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+
+ def testBasicCellWithNorm(self):
+ """Tests cell w/o peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=False)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.58805949]])
+ expected_h = np.array([[0.32770363, 0.37397948]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
+ def testNonBasicCellWithNorm(self):
+ """Tests cell with peepholes and with normalisation"""
+
+ def cell():
+ return contrib_rnn_cell.WeightNormLSTMCell(2,
+ norm=True,
+ use_peepholes=True)
+
+ actual_c, actual_h = self._cell_output(cell)
+
+ expected_c = np.array([[0.50125383, 0.59587258]])
+ expected_h = np.array([[0.35041603, 0.40873795]])
+
+ self.assertAllClose(expected_c, actual_c, 1e-5)
+ self.assertAllClose(expected_h, actual_h, 1e-5)
+
if __name__ == "__main__":
test.main()