diff options
author | Francois Chollet <fchollet@google.com> | 2017-06-13 20:31:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-13 22:14:12 -0700 |
commit | 3df6e18b59b63ea4f5b68ba8c8ec878940a1ada1 (patch) | |
tree | 07f5e0f2bb3d3f5cee9562dc1eea5276c0dac942 /tensorflow/contrib/keras | |
parent | 6ffa51f1e0e76c87ee2164a8d421279768372501 (diff) |
Fix reset_uids Keras layers utility
PiperOrigin-RevId: 158935673
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r-- | tensorflow/contrib/keras/python/keras/backend.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/keras/python/keras/backend_test.py | 5 |
2 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index b7adf9461a..9f02fc0958 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -269,9 +269,10 @@ def get_uid(prefix=''): def reset_uids(): - layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS') - if layer_name_uids_collection: - layer_name_uids_collection.pop() + per_graph_layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS + keys = list(per_graph_layer_name_uids.keys()) + for key in keys: + del per_graph_layer_name_uids[key] def clear_session(): diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py index 2da5aee58e..a2bc95e4a1 100644 --- a/tensorflow/contrib/keras/python/keras/backend_test.py +++ b/tensorflow/contrib/keras/python/keras/backend_test.py @@ -105,10 +105,13 @@ class BackendUtilsTest(test.TestCase): self.assertEqual(keras.backend.image_data_format(), image_data_format) keras.backend.set_image_data_format('channels_last') - def test_get_uid(self): + def test_get_reset_uids(self): self.assertEqual(keras.backend.get_uid('foo'), 1) self.assertEqual(keras.backend.get_uid('foo'), 2) + keras.backend.reset_uids() + self.assertEqual(keras.backend.get_uid('foo'), 1) + class BackendVariableTest(test.TestCase): |