aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/variable_scope_test.py
diff options
context:
space:
mode:
authorGravatar Lukasz Kaiser <lukaszkaiser@google.com>2017-03-26 19:59:45 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-26 21:13:49 -0700
commitc4330cad1e951ec02988e254a44244ca6fc2db69 (patch)
treefb3cfa076cd8d06cac7a3e0e7f5ee7eddedfe2f6 /tensorflow/python/kernel_tests/variable_scope_test.py
parent26be523ed4ab3a573af7771aec770832d9c30f90 (diff)
Add a test documenting custom_getter behaviour when reuse=True.
Change: 151282782
Diffstat (limited to 'tensorflow/python/kernel_tests/variable_scope_test.py')
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 22a4fe6a12..69d1a6f60e 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -951,6 +952,25 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
self.assertEqual(v3, v4)
self.assertEqual(3, called[0]) # skipped one in the first new_scope
+ def testCustomGetterWithReuse(self):
+ # Custom getter can choose to behave differently on reused variables.
+ def custom_getter(getter, *args, **kwargs):
+ var = getter(*args, **kwargs)
+ if kwargs["reuse"]:
+ # This can be used, e.g., for changing the caching device if needed.
+ return array_ops.identity(var, name="reused")
+ else:
+ return array_ops.identity(var, name="not_reused")
+
+ with variable_scope.variable_scope(
+ "scope", custom_getter=custom_getter) as scope:
+ v = variable_scope.get_variable("v", [1])
+ with variable_scope.variable_scope(scope, reuse=True):
+ v2 = variable_scope.get_variable("v", [1])
+
+ self.assertEqual(v.name, "not_reused:0")
+ self.assertEqual(v2.name, "reused:0")
+
def testGetterThatCreatesTwoVariablesAndSumsThem(self):
def custom_getter(getter, name, *args, **kwargs):