aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/recurrent_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/recurrent_test.py')
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 802374d2d2..fefb92826b 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
+from tensorflow.python.training.checkpointable import util as checkpointable_util
class RNNTest(test.TestCase):
@@ -556,5 +557,22 @@ class RNNTest(test.TestCase):
[tuple(o.as_list()) for o in output_shape],
expected_output_shape)
+ def test_checkpointable_dependencies(self):
+ rnn = keras.layers.SimpleRNN
+ with self.test_session():
+ x = np.random.random((2, 2, 2))
+ y = np.random.random((2, 2))
+ model = keras.models.Sequential()
+ model.add(rnn(2))
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.fit(x, y, epochs=1, batch_size=1)
+
+ # check whether the model variables are present in the
+ # checkpointable list of objects
+ checkpointed_objects = set(checkpointable_util.list_objects(model))
+ for v in model.variables:
+ self.assertIn(v, checkpointed_objects)
+
+
if __name__ == '__main__':
test.main()