aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index f4589e3d9e..89ad0fcd75 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -194,6 +194,44 @@ class RNNCellTest(test.TestCase):
m.name: 0.1 * np.ones([1, 4])})
self.assertEqual(len(res), 2)
+ def testBasicLSTMCellDimension0Error(self):
+ """Tests that dimension 0 in both(x and m) shape must be equal."""
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ num_units = 2
+ state_size = num_units * 2
+ batch_size = 3
+ input_size = 4
+ x = array_ops.zeros([batch_size, input_size])
+ m = array_ops.zeros([batch_size - 1, state_size])
+ with self.assertRaises(ValueError):
+ g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+ num_units, state_is_tuple=False)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ sess.run([g, out_m],
+ {x.name: 1 * np.ones([batch_size, input_size]),
+ m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+
+ def testBasicLSTMCellStateSizeError(self):
+ """Tests that state_size must be num_units * 2."""
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ num_units = 2
+ state_size = num_units * 3 # state_size must be num_units * 2
+ batch_size = 3
+ input_size = 4
+ x = array_ops.zeros([batch_size, input_size])
+ m = array_ops.zeros([batch_size, state_size])
+ with self.assertRaises(ValueError):
+ g, out_m = core_rnn_cell_impl.BasicLSTMCell(
+ num_units, state_is_tuple=False)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ sess.run([g, out_m],
+ {x.name: 1 * np.ones([batch_size, input_size]),
+ m.name: 0.1 * np.ones([batch_size, state_size])})
+
def testBasicLSTMCellStateTupleType(self):
with self.test_session():
with variable_scope.variable_scope(