aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/keras
diff options
context:
space:
mode:
authorGravatar Francois Chollet <fchollet@google.com>2017-06-13 20:31:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-13 22:14:12 -0700
commit3df6e18b59b63ea4f5b68ba8c8ec878940a1ada1 (patch)
tree07f5e0f2bb3d3f5cee9562dc1eea5276c0dac942 /tensorflow/contrib/keras
parent6ffa51f1e0e76c87ee2164a8d421279768372501 (diff)
Fix reset_uids Keras layers utility
PiperOrigin-RevId: 158935673
Diffstat (limited to 'tensorflow/contrib/keras')
-rw-r--r--tensorflow/contrib/keras/python/keras/backend.py7
-rw-r--r--tensorflow/contrib/keras/python/keras/backend_test.py5
2 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index b7adf9461a..9f02fc0958 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -269,9 +269,10 @@ def get_uid(prefix=''):
def reset_uids():
- layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS')
- if layer_name_uids_collection:
- layer_name_uids_collection.pop()
+ per_graph_layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS
+ keys = list(per_graph_layer_name_uids.keys())
+ for key in keys:
+ del per_graph_layer_name_uids[key]
def clear_session():
diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py
index 2da5aee58e..a2bc95e4a1 100644
--- a/tensorflow/contrib/keras/python/keras/backend_test.py
+++ b/tensorflow/contrib/keras/python/keras/backend_test.py
@@ -105,10 +105,13 @@ class BackendUtilsTest(test.TestCase):
self.assertEqual(keras.backend.image_data_format(), image_data_format)
keras.backend.set_image_data_format('channels_last')
- def test_get_uid(self):
+ def test_get_reset_uids(self):
self.assertEqual(keras.backend.get_uid('foo'), 1)
self.assertEqual(keras.backend.get_uid('foo'), 2)
+ keras.backend.reset_uids()
+ self.assertEqual(keras.backend.get_uid('foo'), 1)
+
class BackendVariableTest(test.TestCase):