diff options
Diffstat (limited to 'tensorflow/python/keras/layers/wrappers_test.py')
-rw-r--r-- | tensorflow/python/keras/layers/wrappers_test.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 3f268acf5c..0cd774ef0f 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -87,6 +87,8 @@ class TimeDistributedTest(test.TestCase): # test config model.get_config() + # check whether the model variables are present in the + # checkpointable list of objects checkpointed_objects = set(checkpointable_util.list_objects(model)) for v in model.variables: self.assertIn(v, checkpointed_objects) @@ -278,6 +280,12 @@ class BidirectionalTest(test.TestCase): model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') model.fit(x, y, epochs=1, batch_size=1) + # check whether the model variables are present in the + # checkpointable list of objects + checkpointed_objects = set(checkpointable_util.list_objects(model)) + for v in model.variables: + self.assertIn(v, checkpointed_objects) + # test compute output shape ref_shape = model.layers[-1].output.get_shape() shape = model.layers[-1].compute_output_shape( |