aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/keras/python/keras/backend.py14
-rw-r--r--tensorflow/python/layers/base.py28
2 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py
index 905ef13e14..ed2b251b31 100644
--- a/tensorflow/contrib/keras/python/keras/backend.py
+++ b/tensorflow/contrib/keras/python/keras/backend.py
@@ -33,6 +33,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.layers import base as tf_base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
@@ -261,16 +262,9 @@ def get_uid(prefix=''):
2
```
"""
- layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
- if not layer_name_uids_collection:
- layer_name_uids = {}
- ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
- else:
- layer_name_uids = layer_name_uids_collection[0]
- if prefix not in layer_name_uids:
- layer_name_uids[prefix] = 1
- else:
- layer_name_uids[prefix] += 1
+ graph = ops.get_default_graph()
+ layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph]
+ layer_name_uids[prefix] += 1
return layer_name_uids[prefix]
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 8410f12f3e..a37308f702 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -23,9 +23,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import copy
import functools
import re
+
from six.moves import xrange # pylint: disable=redefined-builtin
import numpy as np
import six
@@ -650,10 +652,10 @@ def _to_list(x):
return [x]
-def _add_elements_to_collection(elements, collections):
+def _add_elements_to_collection(elements, collection_list):
elements = _to_list(elements)
- collections = _to_list(collections)
- for name in collections:
+ collection_list = _to_list(collection_list)
+ for name in collection_list:
collection = ops.get_collection_ref(name)
collection_set = set(collection)
for element in elements:
@@ -666,6 +668,13 @@ def _object_list_uid(object_list):
return ', '.join([str(abs(id(x))) for x in object_list])
+# A global dictionary mapping graph objects to an index of counters used
+# for various layer names in each graph.
+# Allows to give unique autogenerated names to layers, in a graph-specific way.
+PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict(
+ lambda: collections.defaultdict(int))
+
+
def _unique_layer_name(name):
"""Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
@@ -684,14 +693,7 @@ def _unique_layer_name(name):
dense_2
```
"""
- layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS')
- if not layer_name_uids_collection:
- layer_name_uids = {}
- ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids)
- else:
- layer_name_uids = layer_name_uids_collection[0]
- if name not in layer_name_uids:
- layer_name_uids[name] = 1
- else:
- layer_name_uids[name] += 1
+ graph = ops.get_default_graph()
+ layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph]
+ layer_name_uids[name] += 1
return name + '_' + str(layer_name_uids[name])