aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
index 33fd35c1a3..334baa5f9c 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py
@@ -904,6 +904,64 @@ class RNNCellTest(test.TestCase):
# States are left untouched
self.assertAllClose(res[2], res[3])
+ def testGLSTMCell(self):
+ # Ensure that G-LSTM matches LSTM when number_of_groups = 1
+ batch_size = 2
+ num_units = 4
+ number_of_groups = 1
+
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root1", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.ones([batch_size, num_units])
+ # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
+ gcell = rnn_cell.GLSTMCell(num_units=num_units,
+ number_of_groups=number_of_groups)
+ cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
+ self.assertTrue(isinstance(gcell.state_size, tuple))
+ zero_state = gcell.zero_state(batch_size=batch_size,
+ dtype=dtypes.float32)
+ gh, gs = gcell(x, zero_state)
+ h, g = cell(x, zero_state)
+
+ sess.run([variables.global_variables_initializer()])
+ glstm_result = sess.run([gh, gs])
+ lstm_result = sess.run([h, g])
+
+ self.assertAllClose(glstm_result[0], lstm_result[0], 1e-5)
+ self.assertAllClose(glstm_result[1], lstm_result[1], 1e-5)
+
+ # Test that G-LSTM subgroup act like corresponding sub-LSTMs
+ batch_size = 2
+ num_units = 4
+ number_of_groups = 2
+
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root2", initializer=init_ops.constant_initializer(0.5)):
+ # input for G-LSTM with 2 groups
+ glstm_input = array_ops.ones([batch_size, num_units])
+ gcell = rnn_cell.GLSTMCell(num_units=num_units,
+ number_of_groups=number_of_groups)
+ gcell_zero_state = gcell.zero_state(batch_size=batch_size,
+ dtype=dtypes.float32)
+ gh, gs = gcell(glstm_input, gcell_zero_state)
+
+ # input for LSTM cell simulating single G-LSTM group
+ lstm_input = array_ops.ones([batch_size, num_units / number_of_groups])
+ # note division by number_of_groups. This cell one simulates G-LSTM group
+ cell = core_rnn_cell_impl.LSTMCell(num_units=
+ int(num_units / number_of_groups))
+ cell_zero_state = cell.zero_state(batch_size=batch_size,
+ dtype=dtypes.float32)
+ h, g = cell(lstm_input, cell_zero_state)
+
+ sess.run([variables.global_variables_initializer()])
+ [gh_res, h_res] = sess.run([gh, h])
+ self.assertAllClose(gh_res[:, 0:int(num_units / number_of_groups)],
+ h_res, 1e-5)
+ self.assertAllClose(gh_res[:, int(num_units / number_of_groups):],
+ h_res, 1e-5)
class LayerNormBasicLSTMCellTest(test.TestCase):