aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py146
1 files changed, 146 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index fb91fe14f4..ebd4564f12 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -875,6 +875,152 @@ class RNNCellTest(test.TestCase):
self.assertAllClose(res[1].c, expected_state_c)
self.assertAllClose(res[1].h, expected_state_h)
+ def testConv1DLSTMCell(self):
+ with self.test_session() as sess:
+ shape = [2,1]
+ filter_size = [3]
+ num_features = 1
+ batch_size = 2
+ expected_state_c = np.array(
+ [[[1.4375670191], [1.4375670191]],
+ [[2.7542609292], [2.7542609292]]],
+ dtype=np.float32)
+ expected_state_h = np.array(
+ [[[0.6529865603], [0.6529865603]],
+ [[0.8736877431], [0.8736877431]]],
+ dtype=np.float32)
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(1.0/2.0)):
+ x = array_ops.placeholder(dtypes.float32, [None, None, 1])
+ cell = contrib_rnn_cell.Conv1DLSTMCell(input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
+ hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
+ output, state = cell(x, hidden)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state], {
+ hidden[0].name:
+ np.array([[[1.],[1.]],
+ [[2.],[2.]]]),
+ x.name:
+ np.array([[[1.],[1.]],
+ [[2.],[2.]]]),
+ })
+ # This is a smoke test, making sure expected values are unchanged.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], res[1].h)
+ self.assertAllClose(res[1].c, expected_state_c)
+ self.assertAllClose(res[1].h, expected_state_h)
+
+ def testConv2DLSTMCell(self):
+ with self.test_session() as sess:
+ shape = [2,2,1]
+ filter_size = [3,3]
+ num_features = 1
+ batch_size = 2
+ expected_state_c = np.array(
+ [[[[1.4375670191], [1.4375670191]],
+ [[1.4375670191], [1.4375670191]]],
+ [[[2.7542609292], [2.7542609292]],
+ [[2.7542609292], [2.7542609292]]]],
+ dtype=np.float32)
+ expected_state_h = np.array(
+ [[[[0.6529865603], [0.6529865603]],
+ [[0.6529865603], [0.6529865603]]],
+ [[[0.8736877431], [0.8736877431]],
+ [[0.8736877431], [0.8736877431]]]],
+ dtype=np.float32)
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(1.0/4.0)):
+ x = array_ops.placeholder(dtypes.float32, [None, None, None, 1])
+ cell = contrib_rnn_cell.Conv2DLSTMCell(input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
+ hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
+ output, state = cell(x, hidden)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state], {
+ hidden[0].name:
+ np.array([[[[1.],[1.]],
+ [[1.],[1.]]],
+ [[[2.],[2.]],
+ [[2.],[2.]]]]),
+ x.name:
+ np.array([[[[1.],[1.]],
+ [[1.],[1.]]],
+ [[[2.],[2.]],
+ [[2.],[2.]]]]),
+ })
+ # This is a smoke test, making sure expected values are unchanged.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], res[1].h)
+ self.assertAllClose(res[1].c, expected_state_c)
+ self.assertAllClose(res[1].h, expected_state_h)
+
+ def testConv3DLSTMCell(self):
+ with self.test_session() as sess:
+ shape = [2,2,2,1]
+ filter_size = [3,3,3]
+ num_features = 1
+ batch_size = 2
+ expected_state_c = np.array(
+ [[[[[1.4375670191], [1.4375670191]],
+ [[1.4375670191], [1.4375670191]]],
+ [[[1.4375670191], [1.4375670191]],
+ [[1.4375670191], [1.4375670191]]]],
+ [[[[2.7542609292], [2.7542609292]],
+ [[2.7542609292], [2.7542609292]]],
+ [[[2.7542609292], [2.7542609292]],
+ [[2.7542609292], [2.7542609292]]]]],
+ dtype=np.float32)
+ expected_state_h = np.array(
+ [[[[[0.6529865603], [0.6529865603]],
+ [[0.6529865603], [0.6529865603]]],
+ [[[0.6529865603], [0.6529865603]],
+ [[0.6529865603], [0.6529865603]]]],
+ [[[[0.8736877431], [0.8736877431]],
+ [[0.8736877431], [0.8736877431]]],
+ [[[0.8736877431], [0.8736877431]],
+ [[0.8736877431], [0.8736877431]]]]],
+ dtype=np.float32)
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(1.0/8.0)):
+ x = array_ops.placeholder(dtypes.float32, [None, None, None, None, 1])
+ cell = contrib_rnn_cell.Conv3DLSTMCell(input_shape=shape,
+ kernel_shape=filter_size,
+ output_channels=num_features)
+ hidden = cell.zero_state(array_ops.shape(x)[0], dtypes.float32)
+ output, state = cell(x, hidden)
+
+ sess.run([variables.global_variables_initializer()])
+ res = sess.run([output, state], {
+ hidden[0].name:
+ np.array([[[[[1.],[1.]],
+ [[1.],[1.]]],
+ [[[1.],[1.]],
+ [[1.],[1.]]]],
+ [[[[2.],[2.]],
+ [[2.],[2.]]],
+ [[[2.],[2.]],
+ [[2.],[2.]]]]]),
+ x.name:
+ np.array([[[[[1.],[1.]],
+ [[1.],[1.]]],
+ [[[1.],[1.]],
+ [[1.],[1.]]]],
+ [[[[2.],[2.]],
+ [[2.],[2.]]],
+ [[[2.],[2.]],
+ [[2.],[2.]]]]])
+ })
+ # This is a smoke test, making sure expected values are unchanged.
+ self.assertEqual(len(res), 2)
+ self.assertAllClose(res[0], res[1].h)
+ self.assertAllClose(res[1].c, expected_state_c)
+ self.assertAllClose(res[1].h, expected_state_h)
+
def testHighwayWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(