aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py29
1 files changed, 27 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 06954f51d8..c14463bdad 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -210,7 +210,7 @@ class RNNCellTest(test.TestCase):
sess.run([variables_lib.global_variables_initializer()])
sess.run([g, out_m],
{x.name: 1 * np.ones([batch_size, input_size]),
- m.name: 0.1 * np.ones([batch_size - 1, state_size])})
+ m.name: 0.1 * np.ones([batch_size - 1, state_size])})
def testBasicLSTMCellStateSizeError(self):
"""Tests that state_size must be num_units * 2."""
@@ -218,7 +218,7 @@ class RNNCellTest(test.TestCase):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):
num_units = 2
- state_size = num_units * 3 # state_size must be num_units * 2
+ state_size = num_units * 3 # state_size must be num_units * 2
batch_size = 3
input_size = 4
x = array_ops.zeros([batch_size, input_size])
@@ -406,6 +406,31 @@ class RNNCellTest(test.TestCase):
# States are left untouched
self.assertAllClose(res[2], res[3])
+ def testResidualWrapperWithSlice(self):
+ with self.test_session() as sess:
+ with variable_scope.variable_scope(
+ "root", initializer=init_ops.constant_initializer(0.5)):
+ x = array_ops.zeros([1, 5])
+ m = array_ops.zeros([1, 3])
+ base_cell = rnn_cell_impl.GRUCell(3)
+ g, m_new = base_cell(x, m)
+ variable_scope.get_variable_scope().reuse_variables()
+ def residual_with_slice_fn(inp, out):
+ inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
+ return inp_sliced + out
+ g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
+ base_cell, residual_with_slice_fn)(x, m)
+ sess.run([variables_lib.global_variables_initializer()])
+ res_g, res_g_res, res_m_new, res_m_new_res = sess.run(
+ [g, g_res, m_new, m_new_res], {
+ x: np.array([[1., 1., 1., 1., 1.]]),
+ m: np.array([[0.1, 0.1, 0.1]])
+ })
+ # Residual connections
+ self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
+ # States are left untouched
+ self.assertAllClose(res_m_new, res_m_new_res)
+
def testDeviceWrapper(self):
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5)):