diff options
-rw-r--r-- | tensorflow/contrib/keras/python/keras/backend.py | 14 | ||||
-rw-r--r-- | tensorflow/python/layers/base.py | 28 |
2 files changed, 19 insertions, 23 deletions
diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index 905ef13e14..ed2b251b31 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops @@ -261,16 +262,9 @@ def get_uid(prefix=''): 2 ``` """ - layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS') - if not layer_name_uids_collection: - layer_name_uids = {} - ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids) - else: - layer_name_uids = layer_name_uids_collection[0] - if prefix not in layer_name_uids: - layer_name_uids[prefix] = 1 - else: - layer_name_uids[prefix] += 1 + graph = ops.get_default_graph() + layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS[graph] + layer_name_uids[prefix] += 1 return layer_name_uids[prefix] diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 8410f12f3e..a37308f702 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -23,9 +23,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy import functools import re + from six.moves import xrange # pylint: disable=redefined-builtin import numpy as np import six @@ -650,10 +652,10 @@ def _to_list(x): return [x] -def _add_elements_to_collection(elements, collections): +def _add_elements_to_collection(elements, collection_list): elements = _to_list(elements) - collections = _to_list(collections) - for name in collections: + collection_list = _to_list(collection_list) + for name in collection_list: collection = ops.get_collection_ref(name) collection_set = set(collection) for element in elements: @@ -666,6 +668,13 @@ def _object_list_uid(object_list): return ', '.join([str(abs(id(x))) for x in object_list]) +# A global dictionary mapping graph objects to an index of counters used +# for various layer names in each graph. +# Allows to give unique autogenerated names to layers, in a graph-specific way. +PER_GRAPH_LAYER_NAME_UIDS = collections.defaultdict( + lambda: collections.defaultdict(int)) + + def _unique_layer_name(name): """Makes a layer name (or arbitrary string) unique within a TensorFlow graph. @@ -684,14 +693,7 @@ def _unique_layer_name(name): dense_2 ``` """ - layer_name_uids_collection = ops.get_collection('LAYER_NAME_UIDS') - if not layer_name_uids_collection: - layer_name_uids = {} - ops.add_to_collection('LAYER_NAME_UIDS', layer_name_uids) - else: - layer_name_uids = layer_name_uids_collection[0] - if name not in layer_name_uids: - layer_name_uids[name] = 1 - else: - layer_name_uids[name] += 1 + graph = ops.get_default_graph() + layer_name_uids = PER_GRAPH_LAYER_NAME_UIDS[graph] + layer_name_uids[name] += 1 return name + '_' + str(layer_name_uids[name]) |