diff options
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 8637d7513b..1de95f1291 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -683,6 +683,9 @@ class _VariableStore(object): init_val = initializer variable_dtype = None else: + # Instantiate initializer if provided initializer is a type object. + if isinstance(initializer, type(init_ops.Initializer)): + initializer = initializer(dtype=dtype) init_val = lambda: initializer( # pylint: disable=g-long-lambda shape.as_list(), dtype=dtype, partition_info=partition_info) variable_dtype = dtype.base_dtype @@ -881,6 +884,19 @@ class VariableScope(object): """Set custom getter for this scope.""" self._custom_getter = custom_getter + def get_collection(self, name): + """Get this scope's variables.""" + scope = self._name + "/" if self._name else "" + return ops.get_collection(name, scope) + + def trainable_variables(self): + """Get this scope's trainable variables.""" + return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + + def global_variables(self): + """Get this scope's global variables.""" + return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + def get_variable(self, var_store, name, |