aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index d41fc0b3ac..e512e8db53 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -483,7 +483,12 @@ class RNNCellTest(test.TestCase):
base_cell = rnn_cell_impl.GRUCell(3)
g, m_new = base_cell(x, m)
variable_scope.get_variable_scope().reuse_variables()
- g_res, m_new_res = rnn_cell_impl.ResidualWrapper(base_cell)(x, m)
+ wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell)
+ (name, dep), = wrapper_object._checkpoint_dependencies
+ self.assertIs(dep, base_cell)
+ self.assertEqual("cell", name)
+
+ g_res, m_new_res = wrapper_object(x, m)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run([g, g_res, m_new, m_new_res], {
x: np.array([[1., 1., 1.]]),
@@ -526,7 +531,12 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 3])
m = array_ops.zeros([1, 3])
- cell = rnn_cell_impl.DeviceWrapper(rnn_cell_impl.GRUCell(3), "/cpu:14159")
+ wrapped = rnn_cell_impl.GRUCell(3)
+ cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159")
+ (name, dep), = cell._checkpoint_dependencies
+ self.assertIs(dep, wrapped)
+ self.assertEqual("cell", name)
+
outputs, _ = cell(x, m)
self.assertTrue("cpu:14159" in outputs.device.lower())