diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 95 |
1 files changed, 95 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index 73789206f3..70aaba1728 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -1549,5 +1549,100 @@ class BenchmarkLSTMCellXLA(test.Benchmark): benchmark_results["wall_time"]]])) +class WeightNormLSTMCellTest(test.TestCase): + """Compared cell output with pre-calculated values.""" + + def _cell_output(self, cell): + """Calculate cell output""" + + with self.test_session() as sess: + init = init_ops.constant_initializer(0.5) + with variable_scope.variable_scope("root", + initializer=init): + x = array_ops.zeros([1, 2]) + c0 = array_ops.zeros([1, 2]) + h0 = array_ops.zeros([1, 2]) + + state0 = rnn_cell.LSTMStateTuple(c0, h0) + + xout, sout = cell()(x, state0) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([xout, sout], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + }) + + actual_state_c = res[1].c + actual_state_h = res[1].h + + return actual_state_c, actual_state_h + + def testBasicCell(self): + """Tests cell w/o peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937078, 0.74983585]]) + expected_h = np.array([[0.44923624, 0.49362513]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonbasicCell(self): + """Tests cell with peepholes and w/o normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=False, + use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.65937084, 0.7574988]]) + expected_h = np.array([[0.4792085, 0.53470564]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + + def testBasicCellWithNorm(self): + """Tests cell w/o peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=False) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.58805949]]) + expected_h = np.array([[0.32770363, 0.37397948]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + + def testNonBasicCellWithNorm(self): + """Tests cell with peepholes and with normalisation""" + + def cell(): + return contrib_rnn_cell.WeightNormLSTMCell(2, + norm=True, + use_peepholes=True) + + actual_c, actual_h = self._cell_output(cell) + + expected_c = np.array([[0.50125383, 0.59587258]]) + expected_h = np.array([[0.35041603, 0.40873795]]) + + self.assertAllClose(expected_c, actual_c, 1e-5) + self.assertAllClose(expected_h, actual_h, 1e-5) + if __name__ == "__main__": test.main() |