diff options
author | Lukasz Kaiser <lukaszkaiser@google.com> | 2017-03-26 19:59:45 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-26 21:13:49 -0700 |
commit | c4330cad1e951ec02988e254a44244ca6fc2db69 (patch) | |
tree | fb3cfa076cd8d06cac7a3e0e7f5ee7eddedfe2f6 /tensorflow/python/kernel_tests/variable_scope_test.py | |
parent | 26be523ed4ab3a573af7771aec770832d9c30f90 (diff) |
Add a test documenting custom_getter behaviour when reuse=True.
Change: 151282782
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 22a4fe6a12..69d1a6f60e 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -951,6 +952,25 @@ class VariableScopeWithCustomGetterTest(test.TestCase): self.assertEqual(v3, v4) self.assertEqual(3, called[0]) # skipped one in the first new_scope + def testCustomGetterWithReuse(self): + # Custom getter can choose to behave differently on reused variables. + def custom_getter(getter, *args, **kwargs): + var = getter(*args, **kwargs) + if kwargs["reuse"]: + # This can be used, e.g., for changing the caching device if needed. + return array_ops.identity(var, name="reused") + else: + return array_ops.identity(var, name="not_reused") + + with variable_scope.variable_scope( + "scope", custom_getter=custom_getter) as scope: + v = variable_scope.get_variable("v", [1]) + with variable_scope.variable_scope(scope, reuse=True): + v2 = variable_scope.get_variable("v", [1]) + + self.assertEqual(v.name, "not_reused:0") + self.assertEqual(v2.name, "reused:0") + def testGetterThatCreatesTwoVariablesAndSumsThem(self): def custom_getter(getter, name, *args, **kwargs): |