aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager/python/network.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/eager/python/network.py')
-rw-r--r--tensorflow/contrib/eager/python/network.py63
1 files changed, 21 insertions, 42 deletions
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 1a5c6e8aec..c6e628b074 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -244,12 +244,6 @@ class Network(base.Layer):
self._owned_layers = {}
# The scope to use if we end up without a parent.
self._default_parent_variable_scope = variable_scope.get_variable_scope()
- # Hold on to the variable scope counts from init to check whether a scope
- # with the name we want was ever created in our parent scope. Without this
- # check we might have name collisions if the parent scope on init gets
- # closed before build is called.
- self._variable_scope_counts_on_init = (
- variable_scope._get_default_variable_store().variable_scopes_count)
self._custom_getter, self._deferred_restorations = (
_make_custom_getter_for_deferred_restorations())
@@ -267,29 +261,18 @@ class Network(base.Layer):
def _finalize_name(self, parent_network):
if not self._name:
+ if not parent_network:
+ name_uid_map = base._get_default_graph_uid_map()
+ else:
+ name_uid_map = parent_network._sub_layer_name_uids
# Were were not passed a name explicitly (or it was blank), so this is an
# anonymous Network. We make up a unique name.
if parent_network:
avoid_names = parent_network._owned_layers
- name_uid_map = parent_network._sub_layer_name_uids
else:
- name_uid_map = base._get_default_graph_uid_map()
- # Figure out which names we have to avoid based on which variable scope
- # we're nested in.
- strip_name = self._default_parent_variable_scope.name
- if strip_name:
- strip_name += "/"
- def _strip_on_init_scope(name):
- if name.startswith(strip_name):
- return name[len(strip_name):]
- else:
- return None
- avoid_names = set(
- _strip_on_init_scope(name)
- for name in self._variable_scope_counts_on_init.keys() if name)
+ avoid_names = None
self._name, self._base_name = self._make_unique_name(
- name_uid_map=name_uid_map, avoid_names=avoid_names,
- namespace=self._default_parent_variable_scope.name)
+ name_uid_map=name_uid_map, avoid_names=avoid_names)
if self._first_parent is None or (self._first_parent # False = no parent
and self._first_parent() is None):
# Save a pointer to the parent Network so that we can later check that the
@@ -319,13 +302,7 @@ class Network(base.Layer):
parent_scope = first_parent._scope
else:
parent_scope = self._default_parent_variable_scope
- with variable_scope.variable_scope(parent_scope) as parent_vs:
- expected_scope_name = parent_vs.name + "/" + self._name
- if expected_scope_name in self._variable_scope_counts_on_init:
- raise ValueError(
- ("A Network named '%s' already exists (or a variable_scope was "
- "created with this name). Names must be unique.") % (
- self._name,))
+ with variable_scope.variable_scope(parent_scope):
# Make sure variables with this prefix will be unique.
with variable_scope.variable_scope(
None, use_resource=True, default_name=self._name) as scope:
@@ -342,22 +319,25 @@ class Network(base.Layer):
"created with this name). Names must be unique.") % (
self._name,))
if (first_parent
- and scope_prefix[:-1] != first_parent.scope_name):
+ and scope_prefix[:-1] != first_parent._scope.name):
raise ValueError(
("Network variable names must match a nesting of sub-Network "
"names. Expected prefix '%s' from parent network, but got "
"'%s' when attempting to create a variable_scope for Network "
"'%s'. Likely an explicit variable_scope was inserted into "
"the nesting.") % (
- first_parent.scope_name,
+ first_parent._scope.name,
scope_prefix[:-1],
self._name))
elif not first_parent and scope_prefix:
# For the case when this Network is not nested inside any other
- # Network, but is in a variable_scope. This Network's name takes on
- # the full variable scope prefix.
- self._name = scope_name
-
+ # Network, but is in a variable_scope. This is an error for now.
+ raise ValueError(
+ "Creating Networks inside named variable_scopes is currently "
+ "not supported (to ensure that variable names match the names "
+ "of Networks in which they were first created). To set "
+ "options, try `with tf.variable_scope(''):`. If this "
+ "limitation bothers you, please file a feature request.")
for non_network_sublayer in self._non_network_sublayers:
self._set_scope_for_nonnetwork_sublayer(non_network_sublayer)
@@ -375,7 +355,8 @@ class Network(base.Layer):
raise ValueError(
("The parent of a Layer added to Network %s was garbage collected "
"before the Layer was built. If this limitation bothers you "
- "please file a feature request.") %
+ "please, comment on "
+ "https://github.com/tensorflow/tensorflow/issues/14164.") %
(self.name,))
with variable_scope.variable_scope(parent_scope):
# Horrid hack to make Layer variable names which are direct
@@ -439,9 +420,7 @@ class Network(base.Layer):
# name, and we should respect it (subject to error checking).
layer._name, layer._base_name = layer._make_unique_name(
name_uid_map=self._sub_layer_name_uids,
- avoid_names=self._owned_layers
- # No namespace required, since we've specified our own UID map.
- )
+ avoid_names=self._owned_layers)
layer._first_parent = weakref.ref(self)
self._non_network_sublayers.append(layer)
if (not layer.built
@@ -577,7 +556,7 @@ class Network(base.Layer):
if os.path.isdir(save_path):
# If we were passed a directory, default to naming based on the Network
# name.
- save_path = os.path.join(save_path, self.name.replace("/", "_"))
+ save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = _make_prefix_stripping_map_fn(self.scope_name)
@@ -771,7 +750,7 @@ class Network(base.Layer):
self._set_scope() # scope_name should be available to map_funcs
if os.path.isdir(save_path):
# If we don't have a name yet, set no parent.
- save_path = os.path.join(save_path, self.name.replace("/", "_"))
+ save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = _make_prefix_stripping_map_fn(self.scope_name)