aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Eugene Brevdo <ebrevdo@google.com>2017-03-01 15:40:12 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-03-01 15:48:48 -0800
commita66d90285ef16c34818bd6f54569310174940fb0 (patch)
tree9084a9471724825668113ce60c6c2e3bcc5c3b6f
parent076cb52c34efe33771be0ea97940f201565d9390 (diff)
VariableScope custom_getters now nest: a child VS with a new custom_getter will
call the parent scope's custom_getter instead of overriding it. The child-most scope passes the "true" VariableScope getter all the way through to the parent-most getter for the very innermost variable access. Change: 148940536
-rw-r--r--tensorflow/contrib/layers/python/layers/layers.py5
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py46
-rw-r--r--tensorflow/python/layers/base.py5
-rw-r--r--tensorflow/python/ops/variable_scope.py27
4 files changed, 73 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 8fa734b089..3667f23697 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -1318,11 +1318,8 @@ def _model_variable_getter(getter, name, shape=None, dtype=None,
def _build_variable_getter(rename=None):
"""Build a model variable getter that respects scope getter and renames."""
- # Respect current getter, if one is set.
- current_custom_getter = variable_scope.get_variable_scope().custom_getter
+ # VariableScope will nest the getters
def layer_variable_getter(getter, *args, **kwargs):
- if current_custom_getter is not None:
- getter = functools.partial(current_custom_getter, getter)
kwargs['rename'] = rename
return _model_variable_getter(getter, *args, **kwargs)
return layer_variable_getter
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 751acfc854..fb27782562 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -921,6 +921,52 @@ class VariableScopeWithCustomGetterTest(test.TestCase):
np_vars, np_v = sess.run([true_vars, v])
self.assertAllClose(np_v, sum(np_vars))
+ def testNestedCustomGetters(self):
+
+ def sum_getter(getter, name, *args, **kwargs):
+ g_0 = getter("%s/sum_0" % name, *args, **kwargs)
+ g_1 = getter("%s/sum_1" % name, *args, **kwargs)
+ with ops.name_scope("sum_getter"):
+ return g_0 + g_1
+
+ def prod_getter(getter, name, *args, **kwargs):
+ g_0 = getter("%s/prod_0" % name, *args, **kwargs)
+ g_1 = getter("%s/prod_1" % name, *args, **kwargs)
+ with ops.name_scope("prod_getter"):
+ return g_0 * g_1
+
+ with variable_scope.variable_scope(
+ "prod_scope", custom_getter=prod_getter):
+ with variable_scope.variable_scope(
+ "sum_scope", custom_getter=sum_getter):
+ with variable_scope.variable_scope(
+ "inner_sum_scope", custom_getter=sum_getter):
+ # take sums of sums of products
+ v = variable_scope.get_variable("v", [1, 2, 3])
+
+ self.assertEqual([1, 2, 3], v.get_shape())
+ true_vars = variables_lib.trainable_variables()
+ self.assertEqual(8, len(true_vars))
+ template = (
+ "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0")
+ self.assertEqual(template % (0, 0, 0), true_vars[0].name)
+ self.assertEqual(template % (0, 0, 1), true_vars[1].name)
+ self.assertEqual(template % (0, 1, 0), true_vars[2].name)
+ self.assertEqual(template % (0, 1, 1), true_vars[3].name)
+ self.assertEqual(template % (1, 0, 0), true_vars[4].name)
+ self.assertEqual(template % (1, 0, 1), true_vars[5].name)
+ self.assertEqual(template % (1, 1, 0), true_vars[6].name)
+ self.assertEqual(template % (1, 1, 1), true_vars[7].name)
+
+ with self.test_session() as sess:
+ variables_lib.global_variables_initializer().run()
+ np_vars, np_v = sess.run([true_vars, v])
+ # take products of sums of products
+ self.assertAllClose(
+ np_v,
+ (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3]))
+ + ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7]))))
+
class PartitionInfoTest(test.TestCase):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index d2c7f12ba1..629c410bec 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -266,12 +266,9 @@ class _Layer(object):
Output tensor(s).
"""
# Define a custom getter to override tf.get_variable when creating layer
- # variables. We respect current custom getter, if one is set.
- current_custom_getter = vs.get_variable_scope().custom_getter
+ # variables. The current custom getter is nested by the variable scope.
def variable_getter(getter, name, shape, dtype=None, initializer=None,
regularizer=None, trainable=True, **kwargs):
- if current_custom_getter is not None:
- getter = functools.partial(current_custom_getter, getter)
return self._add_variable(
name, shape, initializer=initializer, regularizer=regularizer,
dtype=dtype, trainable=trainable,
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 1aef9dbaf2..8637d7513b 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1306,7 +1306,9 @@ def _pure_variable_scope(name_or_scope,
if partitioner is not None:
default_varscope[0].set_partitioner(partitioner)
if custom_getter is not None:
- default_varscope[0].set_custom_getter(custom_getter)
+ default_varscope[0].set_custom_getter(
+ _maybe_wrap_custom_getter(
+ custom_getter, name_or_scope.custom_getter))
if dtype is not None:
default_varscope[0].set_dtype(dtype)
if use_resource is not None:
@@ -1337,7 +1339,8 @@ def _pure_variable_scope(name_or_scope,
if partitioner is not None:
default_varscope[0].set_partitioner(partitioner)
if custom_getter is not None:
- default_varscope[0].set_custom_getter(custom_getter)
+ default_varscope[0].set_custom_getter(
+ _maybe_wrap_custom_getter(custom_getter, old.custom_getter))
if dtype is not None:
default_varscope[0].set_dtype(dtype)
if use_resource is not None:
@@ -1351,6 +1354,26 @@ def _pure_variable_scope(name_or_scope,
default_varscope[0] = old
+def _maybe_wrap_custom_getter(custom_getter, old_getter):
+ """Wrap a call to a custom_getter to use the old_getter internally."""
+ if old_getter is None:
+ return custom_getter
+
+ # The new custom_getter should call the old one
+ def wrapped_custom_getter(getter, *args, **kwargs):
+ # Call:
+ # custom_getter(
+ # lambda: old_getter(true_getter, ...), *args, **kwargs)
+ # which means custom_getter will call old_getter, which
+ # will call the true_getter, perform any intermediate
+ # processing, and return the results to the current
+ # getter, which will also perform additional processing.
+ return custom_getter(
+ functools.partial(old_getter, getter),
+ *args, **kwargs)
+ return wrapped_custom_getter
+
+
def _get_unique_variable_scope(prefix):
"""Get a name with the given prefix unique in the current variable scope."""
var_store = _get_default_variable_store()