diff options
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index d41fc0b3ac..e512e8db53 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -483,7 +483,12 @@ class RNNCellTest(test.TestCase): base_cell = rnn_cell_impl.GRUCell(3) g, m_new = base_cell(x, m) variable_scope.get_variable_scope().reuse_variables() - g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m) + wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) + (name, dep), = wrapper_object._checkpoint_dependencies + self.assertIs(dep, base_cell) + self.assertEqual("cell", name) + + g_res, m_new_res = wrapper_object(x, m) sess.run([variables_lib.global_variables_initializer()]) res = sess.run([g, g_res, m_new, m_new_res], { x: np.array([[1., 1., 1.]]), @@ -526,7 +531,12 @@ class RNNCellTest(test.TestCase): "root", initializer=init_ops.constant_initializer(0.5)): x = array_ops.zeros([1, 3]) m = array_ops.zeros([1, 3]) - cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159") + wrapped = rnn_cell_impl.GRUCell(3) + cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") + (name, dep), = cell._checkpoint_dependencies + self.assertIs(dep, wrapped) + self.assertEqual("cell", name) + outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) |