diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 146 |
1 files changed, 146 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index fb91fe14f4..ebd4564f12 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -875,6 +875,152 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[1].c, expected_state_c) self.assertAllClose(res[1].h, expected_state_h) + def testConv1DLSTMCell(self): + with self.test_session() as sess: + shape = [2,1] + filter_size = [3] + num_features = 1 + batch_size = 2 + expected_state_c = np.array( + [[[1.4375670191], [1.4375670191]], + [[2.7542609292], [2.7542609292]]], + dtype=np.float32) + expected_state_h = np.array( + [[[0.6529865603], [0.6529865603]], + [[0.8736877431], [0.8736877431]]], + dtype=np.float32) + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(1.0/2.0)): + x = array_ops.placeholder(dtypes.float32, [None, None, 1]) + cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) + hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) + output, state = cell(x, hidden) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([output, state], { + hidden[0].name: + np.array([[[1.],[1.]], + [[2.],[2.]]]), + x.name: + np.array([[[1.],[1.]], + [[2.],[2.]]]), + }) + # This is a smoke test, making sure expected values are unchanged. + self.assertEqual(len(res), 2) + self.assertAllClose(res[0], res[1].h) + self.assertAllClose(res[1].c, expected_state_c) + self.assertAllClose(res[1].h, expected_state_h) + + def testConv2DLSTMCell(self): + with self.test_session() as sess: + shape = [2,2,1] + filter_size = [3,3] + num_features = 1 + batch_size = 2 + expected_state_c = np.array( + [[[[1.4375670191], [1.4375670191]], + [[1.4375670191], [1.4375670191]]], + [[[2.7542609292], [2.7542609292]], + [[2.7542609292], [2.7542609292]]]], + dtype=np.float32) + expected_state_h = np.array( + [[[[0.6529865603], [0.6529865603]], + [[0.6529865603], [0.6529865603]]], + [[[0.8736877431], [0.8736877431]], + [[0.8736877431], [0.8736877431]]]], + dtype=np.float32) + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(1.0/4.0)): + x = array_ops.placeholder(dtypes.float32, [None, None, None, 1]) + cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) + hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) + output, state = cell(x, hidden) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([output, state], { + hidden[0].name: + np.array([[[[1.],[1.]], + [[1.],[1.]]], + [[[2.],[2.]], + [[2.],[2.]]]]), + x.name: + np.array([[[[1.],[1.]], + [[1.],[1.]]], + [[[2.],[2.]], + [[2.],[2.]]]]), + }) + # This is a smoke test, making sure expected values are unchanged. + self.assertEqual(len(res), 2) + self.assertAllClose(res[0], res[1].h) + self.assertAllClose(res[1].c, expected_state_c) + self.assertAllClose(res[1].h, expected_state_h) + + def testConv3DLSTMCell(self): + with self.test_session() as sess: + shape = [2,2,2,1] + filter_size = [3,3,3] + num_features = 1 + batch_size = 2 + expected_state_c = np.array( + [[[[[1.4375670191], [1.4375670191]], + [[1.4375670191], [1.4375670191]]], + [[[1.4375670191], [1.4375670191]], + [[1.4375670191], [1.4375670191]]]], + [[[[2.7542609292], [2.7542609292]], + [[2.7542609292], [2.7542609292]]], + [[[2.7542609292], [2.7542609292]], + [[2.7542609292], [2.7542609292]]]]], + dtype=np.float32) + expected_state_h = np.array( + [[[[[0.6529865603], [0.6529865603]], + [[0.6529865603], [0.6529865603]]], + [[[0.6529865603], [0.6529865603]], + [[0.6529865603], [0.6529865603]]]], + [[[[0.8736877431], [0.8736877431]], + [[0.8736877431], [0.8736877431]]], + [[[0.8736877431], [0.8736877431]], + [[0.8736877431], [0.8736877431]]]]], + dtype=np.float32) + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(1.0/8.0)): + x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1]) + cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape, + kernel_shape=filter_size, + output_channels=num_features) + hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32) + output, state = cell(x, hidden) + + sess.run([variables.global_variables_initializer()]) + res = sess.run([output, state], { + hidden[0].name: + np.array([[[[[1.],[1.]], + [[1.],[1.]]], + [[[1.],[1.]], + [[1.],[1.]]]], + [[[[2.],[2.]], + [[2.],[2.]]], + [[[2.],[2.]], + [[2.],[2.]]]]]), + x.name: + np.array([[[[[1.],[1.]], + [[1.],[1.]]], + [[[1.],[1.]], + [[1.],[1.]]]], + [[[[2.],[2.]], + [[2.],[2.]]], + [[[2.],[2.]], + [[2.],[2.]]]]]) + }) + # This is a smoke test, making sure expected values are unchanged. + self.assertEqual(len(res), 2) + self.assertAllClose(res[0], res[1].h) + self.assertAllClose(res[1].c, expected_state_c) + self.assertAllClose(res[1].h, expected_state_h) + def testHighwayWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( |