diff options
Diffstat (limited to 'tensorflow/python/ops/variables.py')
-rw-r--r-- | tensorflow/python/ops/variables.py | 68 |
1 files changed, 55 insertions, 13 deletions
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index a706193a40..f0d2b8bf8c 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -1198,7 +1198,7 @@ class PartitionedVariable(object): "assign() has not been implemented for PartitionedVariable.") -def global_variables(): +def global_variables(scope=None): """Returns global variables. Global variables are variables that are shared across machines in a @@ -1210,10 +1210,17 @@ def global_variables(): An alternative to global variables are local variables. See @{tf.local_variables} + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of `Variable` objects. """ - return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) @deprecated("2017-03-02", "Please use tf.global_variables instead.") @@ -1222,18 +1229,25 @@ def all_variables(): return global_variables() -def _all_saveable_objects(): +def _all_saveable_objects(scope=None): """Returns all variables and `SaveableObject`s that must be checkpointed. + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of `Variable` and `SaveableObject` to be checkpointed """ # TODO(andreasst): make this function public once things are settled. - return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + - ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS)) + return (ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope) + + ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope)) -def local_variables(): +def local_variables(scope=None): """Returns local variables. Local variables - per process variables, usually not saved/restored to @@ -1247,22 +1261,36 @@ def local_variables(): An alternative to local variables are global variables. See @{tf.global_variables} + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of local `Variable` objects. """ - return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES) + return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope) -def model_variables(): +def model_variables(scope=None): """Returns all variables in the MODEL_VARIABLES collection. + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of local Variable objects. """ - return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES) + return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope) -def trainable_variables(): +def trainable_variables(scope=None): """Returns all variables created with `trainable=True`. When passed `trainable=True`, the `Variable()` constructor automatically @@ -1270,13 +1298,20 @@ def trainable_variables(): `GraphKeys.TRAINABLE_VARIABLES`. This convenience function returns the contents of that collection. + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of Variable objects. """ - return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope) -def moving_average_variables(): +def moving_average_variables(scope=None): """Returns all variables that maintain their moving averages. If an `ExponentialMovingAverage` object is created and the `apply()` @@ -1284,10 +1319,17 @@ def moving_average_variables(): be added to the `GraphKeys.MOVING_AVERAGE_VARIABLES` collection. This convenience function returns the contents of that collection. + Args: + scope: (Optional.) A string. If supplied, the resulting list is filtered + to include only items whose `name` attribute matches `scope` using + `re.match`. Items without a `name` attribute are never returned if a + scope is supplied. The choice of `re.match` means that a `scope` without + special tokens filters by prefix. + Returns: A list of Variable objects. """ - return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES) + return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope) def variables_initializer(var_list, name="init"): |