diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py | 58 |
1 files changed, 58 insertions, 0 deletions
# Reconstructed from a collapsed diff hunk (@@ -904,6 +904,64 @@) against
# tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py. In that file
# this is a method of `class RNNCellTest(test.TestCase)`, inserted directly
# before `class LayerNormBasicLSTMCellTest(test.TestCase)`.
def testGLSTMCell(self):
  """Compares GLSTMCell outputs against LSTMCell outputs.

  Two scenarios are exercised, each in its own session and variable scope
  whose default initializer is the constant 0.5:

  1. With `number_of_groups = 1`, the G-LSTM output and state are expected
     to match those of an LSTMCell with the same `num_units`.
  2. With `number_of_groups = 2`, each half of the G-LSTM output is
     expected to match the output of an LSTMCell with
     `num_units // number_of_groups` units fed an all-ones input of that
     same width.
  """
  batch_size = 2
  num_units = 4

  # Ensure that G-LSTM matches LSTM when number_of_groups = 1
  number_of_groups = 1

  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root1", initializer=init_ops.constant_initializer(0.5)):
      x = array_ops.ones([batch_size, num_units])
      # When number_of_groups = 1, G-LSTM is equivalent to regular LSTM
      gcell = rnn_cell.GLSTMCell(num_units=num_units,
                                 number_of_groups=number_of_groups)
      cell = core_rnn_cell_impl.LSTMCell(num_units=num_units)
      self.assertIsInstance(gcell.state_size, tuple)
      # Both cells are fed the zero state produced by the G-LSTM cell.
      zero_state = gcell.zero_state(batch_size=batch_size,
                                    dtype=dtypes.float32)
      gh, gs = gcell(x, zero_state)
      h, s = cell(x, zero_state)

      sess.run([variables.global_variables_initializer()])
      glstm_result = sess.run([gh, gs])
      lstm_result = sess.run([h, s])

      self.assertAllClose(glstm_result[0], lstm_result[0], rtol=1e-5)
      self.assertAllClose(glstm_result[1], lstm_result[1], rtol=1e-5)

  # Test that G-LSTM subgroups act like corresponding sub-LSTMs
  number_of_groups = 2
  # Floor division keeps the per-group width an int, so it is valid both as a
  # tensor shape dimension and as a slice bound.
  group_size = num_units // number_of_groups

  with self.test_session() as sess:
    with variable_scope.variable_scope(
        "root2", initializer=init_ops.constant_initializer(0.5)):
      # input for G-LSTM with 2 groups
      glstm_input = array_ops.ones([batch_size, num_units])
      gcell = rnn_cell.GLSTMCell(num_units=num_units,
                                 number_of_groups=number_of_groups)
      gcell_zero_state = gcell.zero_state(batch_size=batch_size,
                                          dtype=dtypes.float32)
      gh, _ = gcell(glstm_input, gcell_zero_state)

      # input for LSTM cell simulating single G-LSTM group
      lstm_input = array_ops.ones([batch_size, group_size])
      # Note the reduced size: this one cell simulates a single G-LSTM group.
      cell = core_rnn_cell_impl.LSTMCell(num_units=group_size)
      cell_zero_state = cell.zero_state(batch_size=batch_size,
                                        dtype=dtypes.float32)
      h, _ = cell(lstm_input, cell_zero_state)

      sess.run([variables.global_variables_initializer()])
      gh_res, h_res = sess.run([gh, h])
      # Each group's slice of the G-LSTM output is compared to the small LSTM.
      self.assertAllClose(gh_res[:, :group_size], h_res, rtol=1e-5)
      self.assertAllClose(gh_res[:, group_size:], h_res, rtol=1e-5)