diff options
author | Sergio Guadarrama <sguada@google.com> | 2017-11-28 15:31:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-28 15:35:23 -0800 |
commit | a6ee905de83834c35e7cf01182270309ec2425f3 (patch) | |
tree | 75782357442e6762f567c83327c32d5c5ce26d68 | |
parent | a99e9a2c56a4922e76c367b8d3a9c43ea0a4ef61 (diff) |
Add non_trainable_variables to templates.
Add aliases for weights, trainable_weights and non_trainable_weights.
PiperOrigin-RevId: 177228107
-rw-r--r-- | tensorflow/python/kernel_tests/template_test.py | 37 | ||||
-rw-r--r-- | tensorflow/python/ops/template.py | 37 |
2 files changed, 70 insertions, 4 deletions
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py index 40c0ade62a..f0354374ac 100644 --- a/tensorflow/python/kernel_tests/template_test.py +++ b/tensorflow/python/kernel_tests/template_test.py @@ -34,9 +34,10 @@ from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent -def variable_scoped_function(): +def variable_scoped_function(trainable=True): return variable_scope.get_variable( - "dummy", shape=[1], initializer=init_ops.zeros_initializer()) + "dummy", shape=[1], trainable=trainable, + initializer=init_ops.zeros_initializer()) def internally_variable_scoped_function(scope_name): @@ -413,7 +414,7 @@ class TemplateTest(test.TestCase): self.assertEqual(custom_getter_count[0], 2) # Test that custom getter is called when the variable scope is created - # during construction + # during construction custom_getter_count[0] = 0 tmpl2 = template.make_template( "s2", @@ -539,6 +540,36 @@ class TemplateTest(test.TestCase): # Ensure we can get the scopes before either template is actually called. self.assertEqual(1, len(ta.trainable_variables)) self.assertEqual(1, len(tb.trainable_variables)) + # None non-trainable variable was created. + self.assertEqual([], list(ta.non_trainable_variables)) + self.assertEqual([], list(tb.non_trainable_variables)) + # Ensure variables returns all the variables. + self.assertEqual(1, len(ta.variables)) + self.assertEqual(1, len(tb.variables)) + + @test_util.run_in_graph_and_eager_modes() + def test_non_trainable_variables(self): + # Make sure non_trainable_variables are created. + with variable_scope.variable_scope("foo2"): + ta = template.make_template("a", variable_scoped_function, + trainable=True) + tb = template.make_template("b", variable_scoped_function, + trainable=False) + # Initially there are not variables created. + self.assertEqual([], list(ta.variables)) + self.assertEqual([], list(tb.variables)) + # After calling there are variables created. + ta() + tb() + # Check the trainable and non_trainable variables. + self.assertEqual(1, len(ta.trainable_variables)) + self.assertEqual([], list(ta.non_trainable_variables)) + + self.assertEqual([], list(tb.trainable_variables)) + self.assertEqual(1, len(tb.non_trainable_variables)) + # Ensure variables returns all the variables. + self.assertEqual(1, len(ta.variables)) + self.assertEqual(1, len(tb.variables)) # TODO(apassos) handle local variables in Eager def test_local_variables(self): diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 98578b799a..07796b28d9 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -308,6 +308,12 @@ class Template(object): return name if name[-1] == "/" else name + "/" @property + def variables(self): + """Returns the list of global and local variables created by the Template. + """ + return self.global_variables + self.local_variables + + @property def trainable_variables(self): """Returns the list of trainable variables created by the Template.""" if self._variables_created: @@ -317,6 +323,14 @@ class Template(object): return [] @property + def non_trainable_variables(self): + """Returns the list of non-trainable variables created by the Template.""" + # TODO(apassos) Make sure it matches Eager when using local variables. + global_variables = self.global_variables + trainable_variables = set(self.trainable_variables) + return [x for x in global_variables if x not in trainable_variables] + + @property def global_variables(self): """Returns the list of global variables created by the Template.""" if self._variables_created: @@ -335,6 +349,21 @@ class Template(object): return [] @property + def weights(self): + """List of weights/variables created by the Template.""" + return self.variables + + @property + def trainable_weights(self): + """List of trainable weights/variables created by the Template.""" + return self.trainable_variables + + @property + def non_trainable_weights(self): + """List of non-trainable weights/variables created by the Template.""" + return self.non_trainable_variables + + @property @deprecated( "2017-02-21", "The .var_scope property is deprecated. Please change your " "code to use the .variable_scope property") @@ -501,7 +530,7 @@ class EagerTemplate(Template): @property def variables(self): - """Returns the list of trainable variables created by the Template.""" + """Returns the list of variables created by the Template.""" # Currently there is no local variable in Eager mode. return self._eager_variable_store.variables() @@ -512,6 +541,12 @@ class EagerTemplate(Template): return self._eager_variable_store.trainable_variables() @property + def non_trainable_variables(self): + """Returns the list of non-trainable variables created by the Template.""" + # Currently there is no local variable in Eager mode. + return self._eager_variable_store.non_trainable_variables() + + @property def global_variables(self): """Returns the list of global variables created by the Template.""" # Currently there is no local variable in Eager mode. |