diff options
author | Allen Lavoie <allenl@google.com> | 2018-05-17 14:32:47 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-17 14:37:26 -0700 |
commit | 620bcf01283abc434b1971106863269168cb8a5a (patch) | |
tree | 6317560b74256003ae8fe21667da635c56505024 /tensorflow/contrib/rnn | |
parent | 5704164f0e7d6a040d8213fd5993cd2c64959fd7 (diff) |
Basic usability fixes for RNNCell wrappers
They weren't calling their parent constructors (for the Keras base Layer), so a bunch of their methods threw odd errors. There may still be issues, but hopefully not so blatent.
Fixes #19208. For real this time.
PiperOrigin-RevId: 197052962
Diffstat (limited to 'tensorflow/contrib/rnn')
-rw-r--r-- | tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index e512e8db53..b8840a8f24 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import functools +import os import numpy as np @@ -30,6 +31,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -39,6 +41,7 @@ from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_utils # pylint: enable=protected-access Linear = core_rnn_cell._Linear # pylint: disable=invalid-name @@ -189,6 +192,7 @@ class RNNCellTest(test.TestCase): self.assertEqual(cell.dtype, None) self.assertEqual("cell-0", cell._checkpoint_dependencies[0].name) self.assertEqual("cell-1", cell._checkpoint_dependencies[1].name) + cell.get_config() # Should not throw an error g, out_m = cell(x, m) # Layer infers the input type. self.assertEqual(cell.dtype, dtype.name) @@ -439,6 +443,26 @@ class RNNCellTest(test.TestCase): self.assertTrue( float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + @test_util.run_in_graph_and_eager_modes() + def testWrapperCheckpointing(self): + for wrapper_type in [ + rnn_cell_impl.DropoutWrapper, + rnn_cell_impl.ResidualWrapper, + lambda cell: rnn_cell_impl.MultiRNNCell([cell])]: + with self.test_session(): + cell = rnn_cell_impl.BasicRNNCell(1) + wrapper = wrapper_type(cell) + wrapper(array_ops.ones([1, 1]), + state=wrapper.zero_state(batch_size=1, dtype=dtypes.float32)) + self.evaluate([v.initializer for v in cell.variables]) + checkpoint = checkpointable_utils.Checkpoint(wrapper=wrapper) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(cell._bias.assign([40.])) + save_path = checkpoint.save(prefix) + self.evaluate(cell._bias.assign([0.])) + checkpoint.restore(save_path).assert_consumed().run_restore_ops() + self.assertAllEqual([40.], self.evaluate(cell._bias)) + def testOutputProjectionWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -485,6 +509,7 @@ class RNNCellTest(test.TestCase): variable_scope.get_variable_scope().reuse_variables() wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell) (name, dep), = wrapper_object._checkpoint_dependencies + wrapper_object.get_config() # Should not throw an error self.assertIs(dep, base_cell) self.assertEqual("cell", name) @@ -534,6 +559,7 @@ class RNNCellTest(test.TestCase): wrapped = rnn_cell_impl.GRUCell(3) cell = rnn_cell_impl.DeviceWrapper(wrapped, "/cpu:14159") (name, dep), = cell._checkpoint_dependencies + cell.get_config() # Should not throw an error self.assertIs(dep, wrapped) self.assertEqual("cell", name) |