author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-03-31 02:30:28 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-03-31 03:53:24 -0700
commit     e4a8dc831dbf2894c79659d50aea73999c1ff173 (patch)
tree       3f1f21f58707b6522421a319df87f6558ac2d656
parent     a68d64495ae90eada08d028d7577343429f2555a (diff)
Allow variable reuse in function.
Change: 151807504
-rw-r--r--  tensorflow/python/framework/function.py       7
-rw-r--r--  tensorflow/python/framework/function_test.py  52
-rw-r--r--  tensorflow/python/ops/variable_scope.py       5
3 files changed, 63 insertions(+), 1 deletion(-)
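
In effect, the patch lets a function defined with @function.Defun create a
variable on its first call and reuse it on later calls through the standard
variable_scope reuse mechanism. A minimal sketch of the enabled pattern in
the TF 1.x user-facing API (the helper name shared_linear is illustrative,
not part of the patch):

    import tensorflow as tf
    from tensorflow.python.framework import function

    def shared_linear(x, reuse=None):
      # Inside a Defun body, get_variable is routed through _FuncGraph.getvar,
      # which now forwards the per-call `reuse` flag to the outer scope.
      with tf.variable_scope("linear", reuse=reuse):
        w = tf.get_variable("w", shape=[100, 100], dtype=x.dtype)
      return tf.matmul(x, w)

    @function.Defun(tf.float32)
    def TwoApplications(x):
      h = shared_linear(x)                 # first call creates linear/w
      return shared_linear(h, reuse=True)  # second call reuses linear/w

    inp = tf.placeholder(tf.float32, shape=[32, 100])
    out = TwoApplications(inp)
    # Exactly one variable, linear/w:0, is hoisted into the outer graph.
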
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 3c663c3a9b..8b156db6dc 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -299,6 +299,7 @@ class _FuncGraph(ops.Graph):
       shape=None,
       dtype=None,
       initializer=None,
+      reuse=None,
       trainable=True,
       collections=None,  # pylint: disable=redefined-outer-name
       use_resource=None,
@@ -319,6 +320,7 @@ class _FuncGraph(ops.Graph):
           shape=shape,
           dtype=dtype,
           initializer=initializer,
+          reuse=reuse,
           trainable=trainable,
           collections=collections,
           use_resource=use_resource)
@@ -886,6 +888,11 @@ class Defun(object):
   default graph. Because the addition of the function into the graph
   is deferred, the decorator can be used anywhere in the program.
 
+  Any variables created inside of the function are hoisted into the outer graph.
+  Note that the variables are created in the variable scope that was active
+  during the first call to the function. Subsequent function calls will refer to
+  the same set of variables.
+
   Definitions of functions are frozen in a graph as soon as the graph is used to
   create a session. Therefore, nodes using the function must be created in the
   graph before the corresponding session is created.
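
To make the added docstring paragraph concrete, here is a hedged sketch of
the first-call scope binding it describes (hypothetical TF 1.x user code,
mirroring testFunctionCallInDifferentVariableScopes below): the hoisted
variable lands under the scope active at the first call, and later calls
from other scopes keep referring to it.

    import tensorflow as tf
    from tensorflow.python.framework import function

    @function.Defun(tf.float32)
    def AddOnes(x):
      # `var` is hoisted into the outer graph, under whatever variable
      # scope is active the first time AddOnes is called.
      v = tf.get_variable("var", shape=[10], dtype=tf.float32,
                          initializer=tf.ones_initializer())
      return x + v

    inp = tf.placeholder(tf.float32, shape=[10])
    with tf.variable_scope("vs1"):
      out1 = AddOnes(inp)  # creates vs1/var
    with tf.variable_scope("vs2"):
      out2 = AddOnes(inp)  # still refers to vs1/var, not vs2/var

    print([v.name for v in tf.global_variables()])  # ['vs1/var:0']
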
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 96bf7bde29..51a9215ac4 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -722,6 +722,58 @@ class FunctionTest(test.TestCase):
     y = Bar(array_ops.zeros([1, 2, 3]))
     self.assertAllEqual(y.get_shape().as_list(), [1, 1, 2, 3])
 
+  def testVariableReuse(self):
+    def LinearWithReuse(input_tensor, reuse=None):
+      size = input_tensor.shape.dims[1]
+      with variable_scope.variable_scope("linear", reuse=reuse):
+        w = variable_scope.get_variable("w", shape=[size, size],
+                                        dtype=input_tensor.dtype)
+      return math_ops.matmul(input_tensor, w)
+
+    @function.Defun(dtypes.float32)
+    def Foo(inputs):
+      inputs = array_ops.reshape(inputs, [32, 100])
+      hidden = LinearWithReuse(inputs)
+      return LinearWithReuse(hidden, reuse=True)
+
+    input_op = array_ops.placeholder(shape=[32, 100], dtype=dtypes.float32)
+    output_op = Foo(input_op)
+
+    global_vars = variables.global_variables()
+    self.assertEqual(len(global_vars), 1)
+    self.assertEqual(global_vars[0].name, "linear/w:0")
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      output_val = sess.run(output_op,
+                            feed_dict={input_op: np.random.rand(32, 100)})
+      self.assertEqual(output_val.shape, (32, 100))
+
+  def testFunctionCallInDifferentVariableScopes(self):
+    @function.Defun(dtypes.float32)
+    def Foo(inputs):
+      var = variable_scope.get_variable("var", shape=[10], dtype=dtypes.float32,
+                                        initializer=init_ops.ones_initializer())
+      return inputs + var
+
+    input_op = array_ops.placeholder(shape=[10], dtype=dtypes.float32)
+    with variable_scope.variable_scope("vs1"):
+      out1_op = Foo(input_op)
+
+    with variable_scope.variable_scope("vs2"):
+      out2_op = Foo(input_op)
+
+    global_vars = variables.global_variables()
+    self.assertEqual(len(global_vars), 1)
+    self.assertEqual(global_vars[0].name, "vs1/var:0")
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      out1, out2 = sess.run([out1_op, out2_op],
+                            feed_dict={input_op: np.linspace(1, 10, 10)})
+      self.assertAllEqual(out1, np.linspace(2, 11, 10))
+      self.assertAllEqual(out2, np.linspace(2, 11, 10))
+
 
 class FunctionsFromProtos(test.TestCase):
 
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 2f97abdc79..19c5d3c3ea 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -904,6 +904,7 @@ class VariableScope(object):
                    dtype=None,
                    initializer=None,
                    regularizer=None,
+                   reuse=None,
                    trainable=True,
                    collections=None,
                    caching_device=None,
@@ -920,6 +921,8 @@
       partitioner = self._partitioner
     if custom_getter is None:
       custom_getter = self._custom_getter
+    if reuse is None:
+      reuse = self._reuse
 
     full_name = self.name + "/" + name if self.name else name
     # Variable names only depend on variable_scope (full_name here),
@@ -942,7 +945,7 @@
 
     return var_store.get_variable(
         full_name, shape=shape, dtype=dtype, initializer=initializer,
-        regularizer=regularizer, reuse=self.reuse, trainable=trainable,
+        regularizer=regularizer, reuse=reuse, trainable=trainable,
         collections=collections, caching_device=caching_device,
         partitioner=partitioner, validate_shape=validate_shape,
        use_resource=use_resource, custom_getter=custom_getter)
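
The variable_scope.py change gives VariableScope.get_variable a per-call
reuse argument that falls back to the scope's own flag when left as None.
A rough sketch of those semantics (illustration only; it leans on the
private _get_default_variable_store helper, so treat it as a sketch rather
than supported usage):

    from tensorflow.python.ops import variable_scope as vs

    scope = vs.get_variable_scope()            # current scope, reuse=False
    store = vs._get_default_variable_store()   # private API, for illustration

    # reuse=None (the default) falls back to scope.reuse, so this creates "w".
    w1 = scope.get_variable(store, "w", shape=[2, 2])

    # An explicit reuse=True overrides the scope's flag and returns the
    # already-created variable.
    w2 = scope.get_variable(store, "w", shape=[2, 2], reuse=True)
    assert w1 is w2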