diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 29 |
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index 06954f51d8..c14463bdad 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -210,7 +210,7 @@ class RNNCellTest(test.TestCase): sess.run([variables_lib.global_variables_initializer()]) sess.run([g, out_m], {x.name: 1 * np.ones([batch_size, input_size]), - m.name: 0.1 * np.ones([batch_size - 1, state_size])}) + m.name: 0.1 * np.ones([batch_size - 1, state_size])}) def testBasicLSTMCellStateSizeError(self): """Tests that state_size must be num_units * 2.""" @@ -218,7 +218,7 @@ class RNNCellTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): num_units = 2 - state_size = num_units * 3 # state_size must be num_units * 2 + state_size = num_units * 3 # state_size must be num_units * 2 batch_size = 3 input_size = 4 x = array_ops.zeros([batch_size, input_size]) @@ -406,6 +406,31 @@ class RNNCellTest(test.TestCase): # States are left untouched self.assertAllClose(res[2], res[3]) + def testResidualWrapperWithSlice(self): + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 5]) + m = array_ops.zeros([1, 3]) + base_cell = rnn_cell_impl.GRUCell(3) + g, m_new = base_cell(x, m) + variable_scope.get_variable_scope().reuse_variables() + def residual_with_slice_fn(inp, out): + inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3]) + return inp_sliced + out + g_res, m_new_res = rnn_cell_impl.ResidualWrapper( + base_cell, residual_with_slice_fn)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res_g, res_g_res, res_m_new, res_m_new_res = sess.run( + [g, g_res, m_new, m_new_res], { + x: np.array([[1., 1., 1., 1., 1.]]), + m: np.array([[0.1, 0.1, 0.1]]) + }) + # Residual connections + self.assertAllClose(res_g_res, res_g + [1., 1., 1.]) + # States are left untouched + self.assertAllClose(res_m_new, res_m_new_res) + def testDeviceWrapper(self): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(0.5)): |