aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/slot_creator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r--tensorflow/python/training/slot_creator.py25
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py
index c631d78fdd..4d1ad44723 100644
--- a/tensorflow/python/training/slot_creator.py
+++ b/tensorflow/python/training/slot_creator.py
@@ -42,18 +42,29 @@ from __future__ import print_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
+from tensorflow.python.ops import variable_scope
def _create_slot_var(primary, val, scope):
"""Helper function for creating a slot variable."""
- slot = variables.Variable(val, name=scope, trainable=False)
+ # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
+ # scope.
+ current_partitioner = variable_scope.get_variable_scope().partitioner
+ variable_scope.get_variable_scope().set_partitioner(None)
+ slot = variable_scope.get_variable(scope, initializer=val, trainable=False)
+ variable_scope.get_variable_scope().set_partitioner(current_partitioner)
+
# pylint: disable=protected-access
if isinstance(primary, variables.Variable) and primary._save_slice_info:
# Primary is a partitioned variable, so we need to also indicate that
# the slot is a partitioned variable. Slots have the same partitioning
# as their primaries.
- real_slot_name = scope[len(primary.op.name + "/"):-1]
+ # For examples when using AdamOptimizer in linear model, slot.name
+ # here can be "linear//weights/Adam:0", while primary.op.name is
+ # "linear//weight". We want to get 'Adam' as real_slot_name, so we
+ # remove "'linear//weight' + '/'" and ':0'.
+ real_slot_name = slot.name[len(primary.op.name + "/"):-2]
slice_info = primary._save_slice_info
slot._set_save_slice_info(variables.Variable.SaveSliceInfo(
slice_info.full_name + "/" + real_slot_name,
@@ -80,12 +91,16 @@ def create_slot(primary, val, name, colocate_with_primary=True):
A `Variable` object.
"""
# Scope the slot name in the namespace of the primary variable.
- with ops.name_scope(primary.op.name + "/" + name) as scope:
+ # Set "primary.op.name + '/' + name" as default name, so the scope name of
+ # optimizer can be shared when reuse is True. Meanwhile when reuse is False
+ # and the same name has been previously used, the scope name will add '_N'
+ # as suffix for unique identifications.
+ with variable_scope.variable_scope(None, primary.op.name + '/' + name):
if colocate_with_primary:
with ops.colocate_with(primary):
- return _create_slot_var(primary, val, scope)
+ return _create_slot_var(primary, val, '')
else:
- return _create_slot_var(primary, val, scope)
+ return _create_slot_var(primary, val, '')
def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True):