aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variable_scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--tensorflow/python/ops/variable_scope.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 8637d7513b..1de95f1291 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -683,6 +683,9 @@ class _VariableStore(object):
init_val = initializer
variable_dtype = None
else:
+ # Instantiate initializer if provided initializer is a type object.
+ if isinstance(initializer, type(init_ops.Initializer)):
+ initializer = initializer(dtype=dtype)
init_val = lambda: initializer( # pylint: disable=g-long-lambda
shape.as_list(), dtype=dtype, partition_info=partition_info)
variable_dtype = dtype.base_dtype
@@ -881,6 +884,19 @@ class VariableScope(object):
"""Set custom getter for this scope."""
self._custom_getter = custom_getter
+ def get_collection(self, name):
+ """Get this scope's variables."""
+ scope = self._name + "/" if self._name else ""
+ return ops.get_collection(name, scope)
+
+ def trainable_variables(self):
+ """Get this scope's trainable variables."""
+ return self.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
+
+ def global_variables(self):
+ """Get this scope's global variables."""
+ return self.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
+
def get_variable(self,
var_store,
name,