aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sergio Guadarrama <sguada@google.com>2017-11-28 15:31:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-28 15:35:23 -0800
commita6ee905de83834c35e7cf01182270309ec2425f3 (patch)
tree75782357442e6762f567c83327c32d5c5ce26d68
parenta99e9a2c56a4922e76c367b8d3a9c43ea0a4ef61 (diff)
Add non_trainable_variables to templates.
Add aliases for weights, trainable_weights and non_trainable_weights. PiperOrigin-RevId: 177228107
-rw-r--r--tensorflow/python/kernel_tests/template_test.py37
-rw-r--r--tensorflow/python/ops/template.py37
2 files changed, 70 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 40c0ade62a..f0354374ac 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -34,9 +34,10 @@ from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
-def variable_scoped_function():
+def variable_scoped_function(trainable=True):
return variable_scope.get_variable(
- "dummy", shape=[1], initializer=init_ops.zeros_initializer())
+ "dummy", shape=[1], trainable=trainable,
+ initializer=init_ops.zeros_initializer())
def internally_variable_scoped_function(scope_name):
@@ -413,7 +414,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(custom_getter_count[0], 2)
# Test that custom getter is called when the variable scope is created
- # during construction
+ # during construction
custom_getter_count[0] = 0
tmpl2 = template.make_template(
"s2",
@@ -539,6 +540,36 @@ class TemplateTest(test.TestCase):
# Ensure we can get the scopes before either template is actually called.
self.assertEqual(1, len(ta.trainable_variables))
self.assertEqual(1, len(tb.trainable_variables))
+ # None non-trainable variable was created.
+ self.assertEqual([], list(ta.non_trainable_variables))
+ self.assertEqual([], list(tb.non_trainable_variables))
+ # Ensure variables returns all the variables.
+ self.assertEqual(1, len(ta.variables))
+ self.assertEqual(1, len(tb.variables))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_non_trainable_variables(self):
+ # Make sure non_trainable_variables are created.
+ with variable_scope.variable_scope("foo2"):
+ ta = template.make_template("a", variable_scoped_function,
+ trainable=True)
+ tb = template.make_template("b", variable_scoped_function,
+ trainable=False)
+ # Initially there are not variables created.
+ self.assertEqual([], list(ta.variables))
+ self.assertEqual([], list(tb.variables))
+ # After calling there are variables created.
+ ta()
+ tb()
+ # Check the trainable and non_trainable variables.
+ self.assertEqual(1, len(ta.trainable_variables))
+ self.assertEqual([], list(ta.non_trainable_variables))
+
+ self.assertEqual([], list(tb.trainable_variables))
+ self.assertEqual(1, len(tb.non_trainable_variables))
+ # Ensure variables returns all the variables.
+ self.assertEqual(1, len(ta.variables))
+ self.assertEqual(1, len(tb.variables))
# TODO(apassos) handle local variables in Eager
def test_local_variables(self):
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 98578b799a..07796b28d9 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -308,6 +308,12 @@ class Template(object):
return name if name[-1] == "/" else name + "/"
@property
+ def variables(self):
+ """Returns the list of global and local variables created by the Template.
+ """
+ return self.global_variables + self.local_variables
+
+ @property
def trainable_variables(self):
"""Returns the list of trainable variables created by the Template."""
if self._variables_created:
@@ -317,6 +323,14 @@ class Template(object):
return []
@property
+ def non_trainable_variables(self):
+ """Returns the list of non-trainable variables created by the Template."""
+ # TODO(apassos) Make sure it matches Eager when using local variables.
+ global_variables = self.global_variables
+ trainable_variables = set(self.trainable_variables)
+ return [x for x in global_variables if x not in trainable_variables]
+
+ @property
def global_variables(self):
"""Returns the list of global variables created by the Template."""
if self._variables_created:
@@ -335,6 +349,21 @@ class Template(object):
return []
@property
+ def weights(self):
+ """List of weights/variables created by the Template."""
+ return self.variables
+
+ @property
+ def trainable_weights(self):
+ """List of trainable weights/variables created by the Template."""
+ return self.trainable_variables
+
+ @property
+ def non_trainable_weights(self):
+ """List of non-trainable weights/variables created by the Template."""
+ return self.non_trainable_variables
+
+ @property
@deprecated(
"2017-02-21", "The .var_scope property is deprecated. Please change your "
"code to use the .variable_scope property")
@@ -501,7 +530,7 @@ class EagerTemplate(Template):
@property
def variables(self):
- """Returns the list of trainable variables created by the Template."""
+ """Returns the list of variables created by the Template."""
# Currently there is no local variable in Eager mode.
return self._eager_variable_store.variables()
@@ -512,6 +541,12 @@ class EagerTemplate(Template):
return self._eager_variable_store.trainable_variables()
@property
+ def non_trainable_variables(self):
+ """Returns the list of non-trainable variables created by the Template."""
+ # Currently there is no local variable in Eager mode.
+ return self._eager_variable_store.non_trainable_variables()
+
+ @property
def global_variables(self):
"""Returns the list of global variables created by the Template."""
# Currently there is no local variable in Eager mode.