diff options
Diffstat (limited to 'tensorflow/contrib/eager/python/network.py')
-rw-r--r-- | tensorflow/contrib/eager/python/network.py | 63 |
1 files changed, 21 insertions, 42 deletions
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 1a5c6e8aec..c6e628b074 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -244,12 +244,6 @@ class Network(base.Layer): self._owned_layers = {} # The scope to use if we end up without a parent. self._default_parent_variable_scope = variable_scope.get_variable_scope() - # Hold on to the variable scope counts from init to check whether a scope - # with the name we want was ever created in our parent scope. Without this - # check we might have name collisions if the parent scope on init gets - # closed before build is called. - self._variable_scope_counts_on_init = ( - variable_scope._get_default_variable_store().variable_scopes_count) self._custom_getter, self._deferred_restorations = ( _make_custom_getter_for_deferred_restorations()) @@ -267,29 +261,18 @@ class Network(base.Layer): def _finalize_name(self, parent_network): if not self._name: + if not parent_network: + name_uid_map = base._get_default_graph_uid_map() + else: + name_uid_map = parent_network._sub_layer_name_uids # Were were not passed a name explicitly (or it was blank), so this is an # anonymous Network. We make up a unique name. if parent_network: avoid_names = parent_network._owned_layers - name_uid_map = parent_network._sub_layer_name_uids else: - name_uid_map = base._get_default_graph_uid_map() - # Figure out which names we have to avoid based on which variable scope - # we're nested in. - strip_name = self._default_parent_variable_scope.name - if strip_name: - strip_name += "/" - def _strip_on_init_scope(name): - if name.startswith(strip_name): - return name[len(strip_name):] - else: - return None - avoid_names = set( - _strip_on_init_scope(name) - for name in self._variable_scope_counts_on_init.keys() if name) + avoid_names = None self._name, self._base_name = self._make_unique_name( - name_uid_map=name_uid_map, avoid_names=avoid_names, - namespace=self._default_parent_variable_scope.name) + name_uid_map=name_uid_map, avoid_names=avoid_names) if self._first_parent is None or (self._first_parent # False = no parent and self._first_parent() is None): # Save a pointer to the parent Network so that we can later check that the @@ -319,13 +302,7 @@ class Network(base.Layer): parent_scope = first_parent._scope else: parent_scope = self._default_parent_variable_scope - with variable_scope.variable_scope(parent_scope) as parent_vs: - expected_scope_name = parent_vs.name + "/" + self._name - if expected_scope_name in self._variable_scope_counts_on_init: - raise ValueError( - ("A Network named '%s' already exists (or a variable_scope was " - "created with this name). Names must be unique.") % ( - self._name,)) + with variable_scope.variable_scope(parent_scope): # Make sure variables with this prefix will be unique. with variable_scope.variable_scope( None, use_resource=True, default_name=self._name) as scope: @@ -342,22 +319,25 @@ class Network(base.Layer): "created with this name). Names must be unique.") % ( self._name,)) if (first_parent - and scope_prefix[:-1] != first_parent.scope_name): + and scope_prefix[:-1] != first_parent._scope.name): raise ValueError( ("Network variable names must match a nesting of sub-Network " "names. Expected prefix '%s' from parent network, but got " "'%s' when attempting to create a variable_scope for Network " "'%s'. Likely an explicit variable_scope was inserted into " "the nesting.") % ( - first_parent.scope_name, + first_parent._scope.name, scope_prefix[:-1], self._name)) elif not first_parent and scope_prefix: # For the case when this Network is not nested inside any other - # Network, but is in a variable_scope. This Network's name takes on - # the full variable scope prefix. - self._name = scope_name - + # Network, but is in a variable_scope. This is an error for now. + raise ValueError( + "Creating Networks inside named variable_scopes is currently " + "not supported (to ensure that variable names match the names " + "of Networks in which they were first created). To set " + "options, try `with tf.variable_scope(''):`. If this " + "limitation bothers you, please file a feature request.") for non_network_sublayer in self._non_network_sublayers: self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) @@ -375,7 +355,8 @@ class Network(base.Layer): raise ValueError( ("The parent of a Layer added to Network %s was garbage collected " "before the Layer was built. If this limitation bothers you " - "please file a feature request.") % + "please, comment on " + "https://github.com/tensorflow/tensorflow/issues/14164.") % (self.name,)) with variable_scope.variable_scope(parent_scope): # Horrid hack to make Layer variable names which are direct @@ -439,9 +420,7 @@ class Network(base.Layer): # name, and we should respect it (subject to error checking). layer._name, layer._base_name = layer._make_unique_name( name_uid_map=self._sub_layer_name_uids, - avoid_names=self._owned_layers - # No namespace required, since we've specified our own UID map. - ) + avoid_names=self._owned_layers) layer._first_parent = weakref.ref(self) self._non_network_sublayers.append(layer) if (not layer.built @@ -577,7 +556,7 @@ class Network(base.Layer): if os.path.isdir(save_path): # If we were passed a directory, default to naming based on the Network # name. - save_path = os.path.join(save_path, self.name.replace("/", "_")) + save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: map_func = _make_prefix_stripping_map_fn(self.scope_name) @@ -771,7 +750,7 @@ class Network(base.Layer): self._set_scope() # scope_name should be available to map_funcs if os.path.isdir(save_path): # If we don't have a name yet, set no parent. - save_path = os.path.join(save_path, self.name.replace("/", "_")) + save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: map_func = _make_prefix_stripping_map_fn(self.scope_name) |