aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/rnn
diff options
context:
space:
mode:
authorGravatar Allen Lavoie <allenl@google.com>2018-05-17 14:32:47 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-17 14:37:26 -0700
commit620bcf01283abc434b1971106863269168cb8a5a (patch)
tree6317560b74256003ae8fe21667da635c56505024 /tensorflow/contrib/rnn
parent5704164f0e7d6a040d8213fd5993cd2c64959fd7 (diff)
Basic usability fixes for RNNCell wrappers
They weren't calling their parent constructors (for the Keras base Layer), so a bunch of their methods threw odd errors. There may still be issues, but hopefully not so blatent. Fixes #19208. For real this time. PiperOrigin-RevId: 197052962
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r--tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index e512e8db53..b8840a8f24 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import functools
+import os
import numpy as np
@@ -30,6 +31,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -39,6 +41,7 @@ from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_utils
# pylint: enable=protected-access
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
@@ -189,6 +192,7 @@ class RNNCellTest(test.TestCase):
self.assertEqual(cell.dtype, None)
self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name)
self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name)
+ cell.get_config() # Should not throw an error
g, out_m = cell(x, m)
# Layer infers the input type.
self.assertEqual(cell.dtype, dtype.name)
@@ -439,6 +443,26 @@ class RNNCellTest(test.TestCase):
self.assertTrue(
float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6)
+ @test_util.run_in_graph_and_eager_modes()
+ def testWrapperCheckpointing(self):
+ for wrapper_type in [
+ rnn_cell_impl.DropoutWrapper,
+ rnn_cell_impl.ResidualWrapper,
+ lambda cell: rnn_cell_impl.MultiRNNCell([cell])]:
+ with self.test_session():
+ cell = rnn_cell_impl.BasicRNNCell(1)
+ wrapper = wrapper_type(cell)
+ wrapper(array_ops.ones([1, 1]),
+ state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32))
+ self.evaluate([v.initializer for v in cell.variables])
+ checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(cell._bias.assign([40.]))
+ save_path = checkpoint.save(prefix)
+ self.evaluate(cell._bias.assign([0.]))
+ checkpoint.restore(save_path).assert_consumed().run_restore_ops()
+ self.assertAllEqual([40.], self.evaluate(cell._bias))
+
def testOutputProjectionWrapper(self):
with self.test_session() as sess:
with variable_scope.variable_scope(
@@ -485,6 +509,7 @@ class RNNCellTest(test.TestCase):
variable_scope.get_variable_scope().reuse_variables()
wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell)
(name, dep), = wrapper_object._checkpoint_dependencies
+ wrapper_object.get_config() # Should not throw an error
self.assertIs(dep, base_cell)
self.assertEqual("cell", name)
@@ -534,6 +559,7 @@ class RNNCellTest(test.TestCase):
wrapped = rnn_cell_impl.GRUCell(3)
cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159")
(name, dep), = cell._checkpoint_dependencies
+ cell.get_config() # Should not throw an error
self.assertIs(dep, wrapped)
self.assertEqual("cell", name)