aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/layers/wrappers_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/layers/wrappers_test.py')
-rw-r--r--tensorflow/python/keras/layers/wrappers_test.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 3f268acf5c..0cd774ef0f 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -87,6 +87,8 @@ class TimeDistributedTest(test.TestCase):
# test config
model.get_config()
+ # check whether the model variables are present in the
+ # checkpointable list of objects
checkpointed_objects = set(checkpointable_util.list_objects(model))
for v in model.variables:
self.assertIn(v, checkpointed_objects)
@@ -278,6 +280,12 @@ class BidirectionalTest(test.TestCase):
model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse')
model.fit(x, y, epochs=1, batch_size=1)
+ # check whether the model variables are present in the
+ # checkpointable list of objects
+ checkpointed_objects = set(checkpointable_util.list_objects(model))
+ for v in model.variables:
+ self.assertIn(v, checkpointed_objects)
+
# test compute output shape
ref_shape = model.layers[-1].output.get_shape()
shape = model.layers[-1].compute_output_shape(