From e4a8dc831dbf2894c79659d50aea73999c1ff173 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Fri, 31 Mar 2017 02:30:28 -0800
Subject: Allow variable reuse in function.

Change: 151807504
---
 tensorflow/python/framework/function.py      |  7 ++++
 tensorflow/python/framework/function_test.py | 52 ++++++++++++++++++++++++++++
 tensorflow/python/ops/variable_scope.py      |  5 ++-
 3 files changed, 63 insertions(+), 1 deletion(-)

diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 3c663c3a9b..8b156db6dc 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -299,6 +299,7 @@ class _FuncGraph(ops.Graph):
       shape=None,
       dtype=None,
       initializer=None,
+      reuse=None,
       trainable=True,
       collections=None,  # pylint: disable=redefined-outer-name
       use_resource=None,
@@ -319,6 +320,7 @@ class _FuncGraph(ops.Graph):
           shape=shape,
           dtype=dtype,
           initializer=initializer,
+          reuse=reuse,
           trainable=trainable,
           collections=collections,
           use_resource=use_resource)
@@ -886,6 +888,11 @@ class Defun(object):
   default graph. Because the addition of the function into the graph
   is deferred, the decorator can be used anywhere in the program.
 
+  Any variables created inside of the function are hoisted into the outer
+  graph. Note that the variables are created in the variable scope that was
+  active during the first call to the function. Subsequent function calls
+  will refer to the same set of variables.
+
   Definitions of functions are frozen in a graph as soon as the graph is used
   to create a session. Therefore, nodes using the function must be created in
   the graph before the corresponding session is created.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 96bf7bde29..51a9215ac4 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -722,6 +722,58 @@ class FunctionTest(test.TestCase):
     y = Bar(array_ops.zeros([1, 2, 3]))
     self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
 
+  def testVariableReuse(self):
+    def LinearWithReuse(input_tensor, reuse=None):
+      size = input_tensor.shape.dims[1]
+      with variable_scope.variable_scope("linear", reuse=reuse):
+        w = variable_scope.get_variable("w", shape=[size, size],
+                                        dtype=input_tensor.dtype)
+      return math_ops.matmul(input_tensor, w)
+
+    @function.Defun(dtypes.float32)
+    def Foo(inputs):
+      inputs = array_ops.reshape(inputs, [32, 100])
+      hidden = LinearWithReuse(inputs)
+      return LinearWithReuse(hidden, reuse=True)
+
+    input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32)
+    output_op = Foo(input_op)
+
+    global_vars = variables.global_variables()
+    self.assertEqual(len(global_vars), 1)
+    self.assertEqual(global_vars[0].name, "linear/w:0")
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      output_val = sess.run(output_op,
+                            feed_dict={input_op: np.random.rand(32, 100)})
+      self.assertEqual(output_val.shape, (32, 100))
+
+  def testFunctionCallInDifferentVariableScopes(self):
+    @function.Defun(dtypes.float32)
+    def Foo(inputs):
+      var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32,
+                                        initializer=init_ops.ones_initializer())
+      return inputs + var
+
+    input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)
+    with variable_scope.variable_scope("vs1"):
+      out1_op = Foo(input_op)
+
+    with variable_scope.variable_scope("vs2"):
+      out2_op = Foo(input_op)
+
+    global_vars = variables.global_variables()
+    self.assertEqual(len(global_vars), 1)
+    self.assertEqual(global_vars[0].name, "vs1/var:0")
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      out1, out2 = sess.run([out1_op, out2_op],
+                            feed_dict={input_op: np.linspace(1, 10, 10)})
+      self.assertAllEqual(out1, np.linspace(2, 11, 10))
+      self.assertAllEqual(out2, np.linspace(2, 11, 10))
+
 
 class FunctionsFromProtos(test.TestCase):
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 2f97abdc79..19c5d3c3ea 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -904,6 +904,7 @@ class VariableScope(object):
                    dtype=None,
                    initializer=None,
                    regularizer=None,
+                   reuse=None,
                    trainable=True,
                    collections=None,
                    caching_device=None,
@@ -920,6 +921,8 @@ class VariableScope(object):
       partitioner = self._partitioner
     if custom_getter is None:
       custom_getter = self._custom_getter
+    if reuse is None:
+      reuse = self._reuse
 
     full_name = self.name + "/" + name if self.name else name
     # Variable names only depend on variable_scope (full_name here),
@@ -942,7 +945,7 @@ class VariableScope(object):
 
     return var_store.get_variable(
         full_name, shape=shape, dtype=dtype, initializer=initializer,
-        regularizer=regularizer, reuse=self.reuse, trainable=trainable,
+        regularizer=regularizer, reuse=reuse, trainable=trainable,
         collections=collections, caching_device=caching_device,
         partitioner=partitioner, validate_shape=validate_shape,
         use_resource=use_resource, custom_getter=custom_getter)
-- 
cgit v1.2.3
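A minimal usage sketch of the behavior this patch enables, not part of the
commit. It assumes the same TF 1.x internal modules that function_test.py
imports; the function name TwoLayersSharedWeights, the scope name "shared",
and the shapes are illustrative only.

# Sketch (not part of the commit): get_variable with reuse=True inside a
# @Defun body. Before this patch the reuse flag set by the inner
# variable_scope was not forwarded through _FuncGraph.getvar, so the second
# get_variable call below would raise "Variable shared/w already exists".
import numpy as np

from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables


@function.Defun(dtypes.float32)
def TwoLayersSharedWeights(x):
  # Argument shapes are unknown inside a Defun body, so reshape first,
  # as testVariableReuse above does.
  x = array_ops.reshape(x, [2, 4])
  # The first call creates "shared/w"; the variable is hoisted into the
  # outer graph by _FuncGraph.getvar.
  with variable_scope.variable_scope("shared"):
    w = variable_scope.get_variable("w", shape=[4, 4], dtype=x.dtype)
  hidden = math_ops.matmul(x, w)
  # With this patch, reuse=True is forwarded and the same variable is
  # returned instead of an error being raised.
  with variable_scope.variable_scope("shared", reuse=True):
    w_again = variable_scope.get_variable("w", shape=[4, 4], dtype=x.dtype)
  return math_ops.matmul(hidden, w_again)


x = array_ops.placeholder(dtypes.float32, shape=[2, 4])
y = TwoLayersSharedWeights(x)

# Only one variable was hoisted into the outer graph.
print([v.name for v in variables.global_variables()])  # ["shared/w:0"]

with session.Session() as sess:
  sess.run(variables.global_variables_initializer())
  print(sess.run(y, feed_dict={x: np.ones([2, 4], np.float32)}).shape)

Because the hoisted variables are created in the variable scope that was
active during the first call, invoking TwoLayersSharedWeights again under a
different variable_scope still resolves to the same "shared/w:0", mirroring
testFunctionCallInDifferentVariableScopes above.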