diff options
Diffstat (limited to 'tensorflow/python/training/slot_creator.py')
-rw-r--r-- | tensorflow/python/training/slot_creator.py | 25 |
1 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index c631d78fdd..4d1ad44723 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -42,18 +42,29 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables +from tensorflow.python.ops import variable_scope def _create_slot_var(primary, val, scope): """Helper function for creating a slot variable.""" - slot = variables.Variable(val, name=scope, trainable=False) + # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current + # scope. + current_partitioner = variable_scope.get_variable_scope().partitioner + variable_scope.get_variable_scope().set_partitioner(None) + slot = variable_scope.get_variable(scope, initializer=val, trainable=False) + variable_scope.get_variable_scope().set_partitioner(current_partitioner) + # pylint: disable=protected-access if isinstance(primary, variables.Variable) and primary._save_slice_info: # Primary is a partitioned variable, so we need to also indicate that # the slot is a partitioned variable. Slots have the same partitioning # as their primaries. - real_slot_name = scope[len(primary.op.name + "/"):-1] + # For examples when using AdamOptimizer in linear model, slot.name + # here can be "linear//weights/Adam:0", while primary.op.name is + # "linear//weight". We want to get 'Adam' as real_slot_name, so we + # remove "'linear//weight' + '/'" and ':0'. + real_slot_name = slot.name[len(primary.op.name + "/"):-2] slice_info = primary._save_slice_info slot._set_save_slice_info(variables.Variable.SaveSliceInfo( slice_info.full_name + "/" + real_slot_name, @@ -80,12 +91,16 @@ def create_slot(primary, val, name, colocate_with_primary=True): A `Variable` object. """ # Scope the slot name in the namespace of the primary variable. - with ops.name_scope(primary.op.name + "/" + name) as scope: + # Set "primary.op.name + '/' + name" as default name, so the scope name of + # optimizer can be shared when reuse is True. Meanwhile when reuse is False + # and the same name has been previously used, the scope name will add '_N' + # as suffix for unique identifications. + with variable_scope.variable_scope(None, primary.op.name + '/' + name): if colocate_with_primary: with ops.colocate_with(primary): - return _create_slot_var(primary, val, scope) + return _create_slot_var(primary, val, '') else: - return _create_slot_var(primary, val, scope) + return _create_slot_var(primary, val, '') def create_zeros_slot(primary, name, dtype=None, colocate_with_primary=True): |