aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/base.py')
-rw-r--r--tensorflow/python/layers/base.py6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 36491fc9c7..ebf108d8ca 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -27,6 +27,7 @@ import collections
import copy
import functools
import re
+import weakref
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
@@ -681,8 +682,7 @@ def _object_list_uid(object_list):
# A global dictionary mapping graph objects to an index of counters used
# for various layer names in each graph.
# Allows to give unique autogenerated names to layers, in a graph-specific way.
-PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict(
- lambda: collections.defaultdict(int))
+PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()
def _unique_layer_name(name):
@@ -704,6 +704,8 @@ def _unique_layer_name(name):
```
"""
graph = ops.get_default_graph()
+ if graph not in PER_GRAPH_LAYER_NAME_UIDS:
+ PER_GRAPH_LAYER_NAME_UIDS[graph] = collections.defaultdict(int)
layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
layer_name_uids[name] += 1
return name + '_' + str(layer_name_uids[name])