diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2017-03-01 15:40:12 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-01 15:48:48 -0800 |
commit | a66d90285ef16c34818bd6f54569310174940fb0 (patch) | |
tree | 9084a9471724825668113ce60c6c2e3bcc5c3b6f | |
parent | 076cb52c34efe33771be0ea97940f201565d9390 (diff) |
VariableScope custom_getters now nest: a child VS with a new custom_getter will
call the parent scope's custom_getter instead of overriding it. The child-most
scope passes the "true" VariableScope getter all the way through to the
parent-most getter for the very innermost variable access.
Change: 148940536
-rw-r--r-- | tensorflow/contrib/layers/python/layers/layers.py | 5 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/variable_scope_test.py | 46 | ||||
-rw-r--r-- | tensorflow/python/layers/base.py | 5 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 27 |
4 files changed, 73 insertions, 10 deletions
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 8fa734b089..3667f23697 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -1318,11 +1318,8 @@ def _model_variable_getter(getter, name, shape=None, dtype=None, def _build_variable_getter(rename=None): """Build a model variable getter that respects scope getter and renames.""" - # Respect current getter, if one is set. - current_custom_getter = variable_scope.get_variable_scope().custom_getter + # VariableScope will nest the getters def layer_variable_getter(getter, *args, **kwargs): - if current_custom_getter is not None: - getter = functools.partial(current_custom_getter, getter) kwargs['rename'] = rename return _model_variable_getter(getter, *args, **kwargs) return layer_variable_getter diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 751acfc854..fb27782562 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -921,6 +921,52 @@ class VariableScopeWithCustomGetterTest(test.TestCase): np_vars, np_v = sess.run([true_vars, v]) self.assertAllClose(np_v, sum(np_vars)) + def testNestedCustomGetters(self): + + def sum_getter(getter, name, *args, **kwargs): + g_0 = getter("%s/sum_0" % name, *args, **kwargs) + g_1 = getter("%s/sum_1" % name, *args, **kwargs) + with ops.name_scope("sum_getter"): + return g_0 + g_1 + + def prod_getter(getter, name, *args, **kwargs): + g_0 = getter("%s/prod_0" % name, *args, **kwargs) + g_1 = getter("%s/prod_1" % name, *args, **kwargs) + with ops.name_scope("prod_getter"): + return g_0 * g_1 + + with variable_scope.variable_scope( + "prod_scope", custom_getter=prod_getter): + with variable_scope.variable_scope( + "sum_scope", custom_getter=sum_getter): + with variable_scope.variable_scope( + "inner_sum_scope", custom_getter=sum_getter): + # take sums of sums of products + v = variable_scope.get_variable("v", [1, 2, 3]) + + self.assertEqual([1, 2, 3], v.get_shape()) + true_vars = variables_lib.trainable_variables() + self.assertEqual(8, len(true_vars)) + template = ( + "prod_scope/sum_scope/inner_sum_scope/v/sum_%d/sum_%d/prod_%d:0") + self.assertEqual(template % (0, 0, 0), true_vars[0].name) + self.assertEqual(template % (0, 0, 1), true_vars[1].name) + self.assertEqual(template % (0, 1, 0), true_vars[2].name) + self.assertEqual(template % (0, 1, 1), true_vars[3].name) + self.assertEqual(template % (1, 0, 0), true_vars[4].name) + self.assertEqual(template % (1, 0, 1), true_vars[5].name) + self.assertEqual(template % (1, 1, 0), true_vars[6].name) + self.assertEqual(template % (1, 1, 1), true_vars[7].name) + + with self.test_session() as sess: + variables_lib.global_variables_initializer().run() + np_vars, np_v = sess.run([true_vars, v]) + # take products of sums of products + self.assertAllClose( + np_v, + (((np_vars[0] * np_vars[1]) + (np_vars[2] * np_vars[3])) + + ((np_vars[4] * np_vars[5]) + (np_vars[6] * np_vars[7])))) + class PartitionInfoTest(test.TestCase): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index d2c7f12ba1..629c410bec 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -266,12 +266,9 @@ class _Layer(object): Output tensor(s). """ # Define a custom getter to override tf.get_variable when creating layer - # variables. We respect current custom getter, if one is set. - current_custom_getter = vs.get_variable_scope().custom_getter + # variables. The current custom getter is nested by the variable scope. def variable_getter(getter, name, shape, dtype=None, initializer=None, regularizer=None, trainable=True, **kwargs): - if current_custom_getter is not None: - getter = functools.partial(current_custom_getter, getter) return self._add_variable( name, shape, initializer=initializer, regularizer=regularizer, dtype=dtype, trainable=trainable, diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 1aef9dbaf2..8637d7513b 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1306,7 +1306,9 @@ def _pure_variable_scope(name_or_scope, if partitioner is not None: default_varscope[0].set_partitioner(partitioner) if custom_getter is not None: - default_varscope[0].set_custom_getter(custom_getter) + default_varscope[0].set_custom_getter( + _maybe_wrap_custom_getter( + custom_getter, name_or_scope.custom_getter)) if dtype is not None: default_varscope[0].set_dtype(dtype) if use_resource is not None: @@ -1337,7 +1339,8 @@ def _pure_variable_scope(name_or_scope, if partitioner is not None: default_varscope[0].set_partitioner(partitioner) if custom_getter is not None: - default_varscope[0].set_custom_getter(custom_getter) + default_varscope[0].set_custom_getter( + _maybe_wrap_custom_getter(custom_getter, old.custom_getter)) if dtype is not None: default_varscope[0].set_dtype(dtype) if use_resource is not None: @@ -1351,6 +1354,26 @@ def _pure_variable_scope(name_or_scope, default_varscope[0] = old +def _maybe_wrap_custom_getter(custom_getter, old_getter): + """Wrap a call to a custom_getter to use the old_getter internally.""" + if old_getter is None: + return custom_getter + + # The new custom_getter should call the old one + def wrapped_custom_getter(getter, *args, **kwargs): + # Call: + # custom_getter( + # lambda: old_getter(true_getter, ...), *args, **kwargs) + # which means custom_getter will call old_getter, which + # will call the true_getter, perform any intermediate + # processing, and return the results to the current + # getter, which will also perform additional processing. + return custom_getter( + functools.partial(old_getter, getter), + *args, **kwargs) + return wrapped_custom_getter + + def _get_unique_variable_scope(prefix): """Get a name with the given prefix unique in the current variable scope.""" var_store = _get_default_variable_store() |