diff options
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent_test.py')
-rw-r--r-- | tensorflow/python/keras/layers/recurrent_test.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py index 802374d2d2..fefb92826b 100644 --- a/tensorflow/python/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/layers/recurrent_test.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import test +from tensorflow.python.training.checkpointable import util as checkpointable_util class RNNTest(test.TestCase): @@ -556,5 +557,22 @@ class RNNTest(test.TestCase): [tuple(o.as_list()) for o in output_shape], expected_output_shape) + def test_checkpointable_dependencies(self): + rnn = keras.layers.SimpleRNN + with self.test_session(): + x = np.random.random((2, 2, 2)) + y = np.random.random((2, 2)) + model = keras.models.Sequential() + model.add(rnn(2)) + model.compile(optimizer='rmsprop', loss='mse') + model.fit(x, y, epochs=1, batch_size=1) + + # check whether the model variables are present in the + # checkpointable list of objects + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + + if __name__ == '__main__': test.main() |