aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-08 16:11:08 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-08 17:17:45 -0700
commit1d2aa9451d920ad4bc8ad1ac86c06863b590e81f (patch)
treefea78c512adc2d68dd04cc1d72c50ec19e007fec
parent5e8d0251be9ffa06b83e8079a3214fe79d00d36a (diff)
Removes a few limitation of Defun:
* Defun can appear anyway. Previously, it must be under a with tf.Graph() scope; Now, it can be used to decorate module level functions. * Defun constructs the function definition lazily so that if a Defun is not called, it cost little; * Defun can be used within another Defun. The inner ones are lifted to the top-level. Change: 132620258
-rw-r--r--tensorflow/python/framework/function.py294
-rw-r--r--tensorflow/python/framework/function_test.py218
-rw-r--r--tensorflow/python/framework/ops.py56
-rw-r--r--tensorflow/python/ops/gradients.py4
-rw-r--r--tensorflow/python/ops/gradients_test.py3
-rw-r--r--tensorflow/python/training/saver_test.py8
6 files changed, 362 insertions, 221 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index c31e5c9079..49bd7f791c 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -120,7 +120,7 @@ def _add_op_node(graph, op, func):
node.op = op.type
# pylint: disable=protected-access
if graph._is_function(op.type):
- op_def = graph._get_function(op.type).signature
+ op_def = graph._get_function(op.type).definition.signature
else:
op_def = op_def_registry.get_registered_ops()[op.type]
# pylint: enable=protected-access
@@ -196,27 +196,23 @@ def graph_to_function_def(graph, name, inputs, outputs):
func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
func_arg_placeholders = set([i.name for i in inputs])
- g = ops.get_default_graph()
for op in graph.get_operations():
tensor_name = op.values()[0].name
if tensor_name not in func_arg_placeholders:
- _add_op_node(g, op, func)
+ _add_op_node(graph, op, func)
return func
-def call_function(func_def, *inputs, **kwargs):
- """Calls the function described by `func_def`.
+def call_function(func, *inputs, **kwargs):
+ """Calls the function described by `func`.
This adds a `call` op to the default graph that calls the function described
- by `func_def` with the tensors listed in `inputs` as arguments. It returns
+ by `func` with the tensors listed in `inputs` as arguments. It returns
the outputs of the call, which are one or more tensors.
- `func_def` is a
- [`FunctionDef`](
- https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
- protcol buffer describing a
- TensorFlow function. See [`define_function()`](#define_function) for an
- easy way to create one from a Python function.
+ `func` is a `_DefinedFunction` object. See
+ [`define_function()`](#define_function) for an easy way to create
+ one from a Python function.
You can pass an optional keyword parameter `name=string` to name the
added operation.
@@ -224,21 +220,23 @@ def call_function(func_def, *inputs, **kwargs):
You can pass an optional keyword parameter `noinline=True|False` to instruct
the runtime not to inline the function body into the call site.
- `func_def` is automatically added to the function library of the graph if
+ `func` is automatically added to the function library of the graph if
needed.
Args:
- func_def: A `FunctionDef` protocol buffer.
+ func: A `_DefinedFunction` object.
*inputs: A list of tensors
**kwargs: Optional keyword arguments. Can only contain 'name' or
'noinline'.
Returns:
- A list of tensors representing the outputs of the call to `func_def`.
+ A list of tensors representing the outputs of the call to `func`.
Raises:
ValueError: if the arguments are invalid.
+
"""
+ func_def = func.definition
name = kwargs.pop("name", None)
noinline = kwargs.pop("noinline", None)
if noinline is None:
@@ -356,70 +354,170 @@ def define_function(func, input_types, func_name=None, grad_func=None,
ValueError: if the arguments are invalid.
"""
- # TODO(touts): Lift the limitation that func can only receive Tensor args.
- func_name = func_name or _get_func_name(func)
- grad_func_name = _get_func_name(grad_func) if grad_func is not None else None
-
- argspec = inspect.getargspec(func)
- if argspec.keywords or argspec.defaults:
- raise ValueError("Functions with argument defaults or keyword "
- "arguments are not supported.")
- if inspect.isfunction(func):
- if argspec.varargs and (
- len(argspec.args) > len(input_types)) or not argspec.varargs and (
- len(argspec.args) != len(input_types)):
- raise ValueError("The function has fewer arguments "
- "than the number of specified input types.")
- argnames = argspec.args
- elif inspect.ismethod(func):
- if argspec.varargs and (
- len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
- len(argspec.args) != 1 + len(input_types)):
- raise ValueError("The class function has fewer arguments "
- "than the number of specified input types.")
- # 1st argument is the "class" type.
- argnames = argspec.args[1:]
-
- args = []
- if isinstance(input_types, (list, tuple)):
- for i in range(len(input_types)):
- argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
- argtype = input_types[i]
- args.append((argname, argtype))
- else:
- for name in argnames:
- if name not in input_types:
- raise ValueError("Missing type for argument: " + name)
- args.append((name, input_types[name]))
-
- # Create the func_def object.
- temp_graph = ops.Graph()
- with temp_graph.as_default():
- # List of placeholders for the function_def.
- inputs = []
- # Arglist to call 'func'
- kwargs = {}
- for (argname, argtype) in args:
- argholder = array_ops.placeholder(argtype, name=argname)
- inputs.append(argholder)
- kwargs[argname] = argholder
- # Call func and gather the output tensors.
- if isinstance(input_types, (list, tuple)):
- outputs = func(*inputs)
- else:
- outputs = func(**kwargs)
- if not isinstance(outputs, ops.Tensor) and not outputs:
- raise ValueError("Function must return at least one tensor")
- # Convenience: if func only returned one value, make it a tuple.
- if not isinstance(outputs, (list, tuple)):
- outputs = (outputs,)
- # Build the FunctionDef
- func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs)
- g = ops.get_default_graph()
+ f = _DefinedFunction(func, input_types, func_name, grad_func,
+ python_grad_func)
# pylint: disable=protected-access
- g._add_function(func_def, grad_func_name, python_grad_func=python_grad_func)
+ f.add_to_graph(ops.get_default_graph())
# pylint: enable=protected-access
- return func_def
+ return f
+
+
+class _DefinedFunction(object):
+ """_DefinedFunction encapsulates a function definition and its properties.
+
+ Attributes:
+ name: The function name.
+ definition: The definition of this function. A FunctionDef proto.
+ grad_func_name: If not None, the name of this function's gradient function.
+ python_grad_func: A python callable implementing the gradient of
+ the function python-side.
+ """
+
+ def __init__(self, func, input_types, func_name=None, grad_func=None,
+ python_grad_func=None):
+ """Creates _DefinedFunction.
+
+ Args:
+ func: A python callable which constructs a tf function body.
+ input_types: The function's argument types. Can be a tuple, list of
+ tf data types, or a dictionary of argument names to their types.
+ func_name: The function name. Defaults to None, in which derives from
+ 'func'.
+ grad_func: This function's gradient function, if not None. Defaults
+ to None.
+ python_grad_func: A python callable implementing the gradient of
+ the function python-side.
+
+ Raises:
+ ValueError: The function definition is invalid.
+ """
+ self._func = func
+ self._input_types = input_types
+ self._func_name = func_name or _get_func_name(func)
+ self._grad_func = grad_func
+ self._python_grad_func = python_grad_func
+ self._definition = None # Constructed lazily.
+
+ argspec = inspect.getargspec(func)
+ if argspec.keywords or argspec.defaults:
+ raise ValueError("Functions with argument defaults or keyword "
+ "arguments are not supported.")
+ if inspect.isfunction(func):
+ if argspec.varargs and (
+ len(argspec.args) > len(input_types)) or not argspec.varargs and (
+ len(argspec.args) != len(input_types)):
+ raise ValueError("The function has fewer arguments "
+ "than the number of specified input types.")
+ argnames = argspec.args
+ elif inspect.ismethod(func):
+ if argspec.varargs and (
+ len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
+ len(argspec.args) != 1 + len(input_types)):
+ raise ValueError("The class function has fewer arguments "
+ "than the number of specified input types.")
+ # 1st argument is the "class" type.
+ argnames = argspec.args[1:]
+
+ self._args = []
+ if isinstance(input_types, (list, tuple)):
+ for i in range(len(input_types)):
+ argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
+ argtype = input_types[i]
+ self._args.append((argname, argtype))
+ else:
+ for name in argnames:
+ if name not in input_types:
+ raise ValueError("Missing type for argument: " + name)
+ self._args.append((name, input_types[name]))
+
+ @property
+ def name(self):
+ """Function name."""
+ return self._func_name
+
+ @property
+ def definition(self):
+ """Function definition proto."""
+ self._create_definition_if_needed()
+ return self._definition
+
+ @property
+ def grad_func_name(self):
+ """Its gradient function's name."""
+ return self._grad_func.name if self._grad_func else None
+
+ @property
+ def python_grad_func(self):
+ """Python gradient function callable."""
+ return self._python_grad_func
+
+ def _create_definition_if_needed(self):
+ """Creates the function definition if it's not created yet."""
+
+ if self._definition is not None:
+ return
+
+ # Create the func_def object.
+ temp_graph = ops.Graph()
+ with temp_graph.as_default():
+ # List of placeholders for the function_def.
+ inputs = []
+ # Arglist to call 'func'
+ kwargs = {}
+ for (argname, argtype) in self._args:
+ argholder = array_ops.placeholder(argtype, name=argname)
+ inputs.append(argholder)
+ kwargs[argname] = argholder
+ # Call func and gather the output tensors.
+ if isinstance(self._input_types, (list, tuple)):
+ outputs = self._func(*inputs)
+ else:
+ outputs = self._func(**kwargs)
+ if not isinstance(outputs, ops.Tensor) and not outputs:
+ raise ValueError("Function must return at least one tensor")
+ # Convenience: if func only returned one value, make it a tuple.
+ if not isinstance(outputs, (list, tuple)):
+ outputs = (outputs,)
+
+ # Build the FunctionDef
+ self._definition = graph_to_function_def(
+ temp_graph, self._func_name, inputs, outputs)
+ # pylint: disable=protected-access
+ self._sub_functions = temp_graph._functions
+ # pylint: enable=protected-access
+ self._hash = hash(self._definition.SerializeToString())
+ for item in self._sub_functions.items(): # OrderedDict
+ self._hash = hash((self._hash, item))
+
+ def __hash__(self):
+ self._create_definition_if_needed()
+ return self._hash
+
+ def add_to_graph(self, g):
+ """Adds this function into the graph g."""
+ self._create_definition_if_needed()
+
+ # pylint: disable=protected-access
+ # If 'g' has an identical function already, do nothing.
+ prev = g._get_function(self.name)
+ if prev and (prev._hash == self._hash):
+ return
+
+ # Adds this function into 'g'.
+ g._add_function(self)
+ # pylint: enable=protected-access
+
+ # Ensures related sub-routines are defined in 'g', too.
+ for f in self._sub_functions.values():
+ f.add_to_graph(g)
+
+ # Adds its gradient function, too.
+ if self._grad_func:
+ self._grad_func.add_to_graph(g)
+
+ def __call__(self, *args, **kwargs):
+ self.add_to_graph(ops.get_default_graph())
+ return call_function(self, *args, **kwargs)
class Defun(object):
@@ -440,7 +538,10 @@ class Defun(object):
def foo(x, y):
...
- When you call the decorated function it will add `call` ops to the graph.
+ When you call the decorated function it will add `call` ops to the
+ default graph and adds the definition of the function into the
+ default graph. Because the addition of the function into the graph
+ is deferred, the decorator can be used anywhere in the program.
Example, but also see the [How To on functions](link_needed).
@@ -457,6 +558,7 @@ class Defun(object):
```
@@__init__
+
"""
def __init__(self, *input_type_list, **input_types):
@@ -491,35 +593,9 @@ class Defun(object):
def __call__(self, f):
if self._input_types:
- func_def = define_function(
- f, self._input_types,
- func_name=self._func_name, grad_func=self._grad_func,
- python_grad_func=self._python_grad_func)
+ inp_types = self._input_types
else:
- func_def = define_function(
- f, self._input_type_list,
- func_name=self._func_name, grad_func=self._grad_func,
- python_grad_func=self._python_grad_func)
-
- return _DefinedFunction(definition=func_def)
-
-
-class _DefinedFunction(object):
- """Class to store the name and definition of the function defined by Defun.
-
- This object implements a callable interface that runs `call_function`, and
- provides a `name` property to look up the name of the `Function`.
-
- An instance of `_DefinedFunction` may be passed to the `grad_func` parameter
- of `define_function` and `Defun`.
- """
-
- def __init__(self, definition):
- self._definition = definition
-
- @property
- def name(self):
- return self._definition.signature.name
-
- def __call__(self, *args, **kwargs):
- return call_function(self._definition, *args, **kwargs)
+ inp_types = self._input_type_list
+ return _DefinedFunction(
+ f, inp_types, self._func_name, self._grad_func,
+ self._python_grad_func)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 00b1d58c6c..bb69295499 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import time
import numpy as np
@@ -53,7 +54,6 @@ class FunctionTest(tf.test.TestCase):
# u = matmul(a, b) + c
# v = u^2
# w = u + v
- # TODO(zhifengc): replaces w/ a nicer @decorator sugar.
foo = tf.Graph()
with foo.as_default():
a = tf.placeholder(tf.float32, name="a")
@@ -63,7 +63,17 @@ class FunctionTest(tf.test.TestCase):
v = tf.square(u, name="v")
w = tf.add_n([u, v], name="w")
fdef = function.graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])
- g._add_function(fdef)
+
+ class Mock(function._DefinedFunction):
+
+ def __init__(self, fdef):
+ self._func_name = "foo"
+ self._definition = fdef
+ self._sub_functions = collections.OrderedDict()
+ self._grad_func = None
+ self._python_grad_func = None
+ self._hash = hash(fdef.SerializeToString())
+ g._add_function(Mock(fdef))
# Compute 2 * 3 + 4 and its square.
with g.as_default(), tf.Session() as sess:
@@ -105,11 +115,11 @@ class FunctionTest(tf.test.TestCase):
return a + b * 2
with tf.Graph().as_default():
- f_def = function.define_function(APlus2B, {"a": tf.float32,
- "b": tf.float32})
+ f = function.define_function(APlus2B, {"a": tf.float32,
+ "b": tf.float32})
one = tf.constant([1.0])
two = tf.constant([2.0])
- call = function.call_function(f_def, one, two)
+ call = function.call_function(f, one, two)
self.assertEquals("APlus2B", call.op.name)
with tf.Session() as sess:
self.assertAllEqual([5.0], sess.run(call))
@@ -146,11 +156,12 @@ class FunctionTest(tf.test.TestCase):
self.assertAllClose([0.4], sess.run(call_g))
def testTanhSymGrad(self):
+ @function.Defun(tf.float32)
+ def Forward(x):
+ return tf.reduce_sum(tf.tanh(x))
+
g = tf.Graph()
with g.as_default():
- @function.Defun(tf.float32)
- def Forward(x):
- return tf.reduce_sum(tf.tanh(x))
x = tf.placeholder(tf.float32)
y = Forward(x)
dx = tf.gradients([y], [x])
@@ -167,21 +178,21 @@ class FunctionTest(tf.test.TestCase):
self.assertAllClose(1 - np.square(np.tanh(inp)), out)
def testCustomGradient(self):
- g = tf.Graph()
dtype = tf.float32
- with g.as_default():
- @function.Defun(dtype, dtype, dtype)
- def XentLossGrad(logits, labels, dloss):
- dlogits = tf.reshape(dloss, [-1, 1]) * (tf.nn.softmax(logits) - labels)
- dlabels = tf.zeros_like(labels)
- # Takes exp(dlogits) to differentiate it from the "correct" gradient.
- return tf.exp(dlogits), dlabels
+ @function.Defun(dtype, dtype, dtype)
+ def XentLossGrad(logits, labels, dloss):
+ dlogits = tf.reshape(dloss, [-1, 1]) * (tf.nn.softmax(logits) - labels)
+ dlabels = tf.zeros_like(labels)
+ # Takes exp(dlogits) to differentiate it from the "correct" gradient.
+ return tf.exp(dlogits), dlabels
- @function.Defun(dtype, dtype, grad_func=XentLossGrad)
- def XentLoss(logits, labels):
- return tf.reduce_sum(labels * tf.log(tf.nn.softmax(logits)), 1)
+ @function.Defun(dtype, dtype, grad_func=XentLossGrad)
+ def XentLoss(logits, labels):
+ return tf.reduce_sum(labels * tf.log(tf.nn.softmax(logits)), 1)
+ g = tf.Graph()
+ with g.as_default():
logits = tf.placeholder(dtype)
labels = tf.placeholder(dtype)
loss = XentLoss(logits, labels)
@@ -197,19 +208,19 @@ class FunctionTest(tf.test.TestCase):
self.assertAllClose(out, np.exp(prob - y))
def testCustomGradientError(self):
- g = tf.Graph()
dtype = tf.float32
- with g.as_default():
- @function.Defun(dtype, dtype, dtype)
- def Grad(x, dy, dz):
- # Should have returned 1 result.
- return x, dy + dz
+ @function.Defun(dtype, dtype, dtype)
+ def Grad(x, dy, dz):
+ # Should have returned 1 result.
+ return x, dy + dz
- @function.Defun(dtype, grad_func=Grad)
- def Forward(x):
- return x, x
+ @function.Defun(dtype, grad_func=Grad)
+ def Forward(x):
+ return x, x
+ g = tf.Graph()
+ with g.as_default():
inp = tf.placeholder(dtype)
out = tf.add_n(Forward(inp))
dinp = tf.gradients(out, [inp])
@@ -238,11 +249,12 @@ class FunctionTest(tf.test.TestCase):
self.assertEquals(y.get_shape(), dy.get_shape())
def testZNoDepOnY(self):
+ @function.Defun(tf.float32, tf.float32)
+ def Foo(x, y):
+ return x * 2
+
with tf.Graph().as_default():
# z = Foo(x, y). z doe
- @function.Defun(tf.float32, tf.float32)
- def Foo(x, y):
- return x * 2
x = tf.constant(1.0)
y = tf.constant(2.0)
z = Foo(x, y)
@@ -297,23 +309,25 @@ class FunctionTest(tf.test.TestCase):
return a + b, b - a
with tf.Graph().as_default():
+ # pylint: disable=expression-not-assigned
with self.assertRaisesRegexp(ValueError, "return at least one tensor"):
- function.define_function(NoResult, {})
+ function.define_function(NoResult, {}).definition
with self.assertRaisesRegexp(ValueError, "are not supported"):
- function.define_function(DefaultArg, {})
+ function.define_function(DefaultArg, {}).definition
with self.assertRaisesRegexp(ValueError, "are not supported"):
- function.define_function(KwArgs, {})
+ function.define_function(KwArgs, {}).definition
with self.assertRaisesRegexp(ValueError, "specified input types"):
- function.define_function(PlusMinus, {})
+ function.define_function(PlusMinus, {}).definition
with self.assertRaisesRegexp(ValueError, "specified input types"):
- function.define_function(PlusMinus, {"c": tf.float32})
+ function.define_function(PlusMinus, {"c": tf.float32}).definition
with self.assertRaisesRegexp(ValueError, "type for argument: b"):
function.define_function(PlusMinus, {"a": tf.float32,
- "c": tf.float32})
+ "c": tf.float32}).definition
with self.assertRaisesRegexp(ValueError, "specified input types"):
function.define_function(PlusMinus, {"a": tf.float32,
"b": tf.float32,
- "c": tf.float32})
+ "c": tf.float32}).definition
+ # pylint: enable=expression-not-assigned
def testCallErrors(self):
@@ -356,41 +370,89 @@ class FunctionTest(tf.test.TestCase):
with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
function.call_function(plus_one, one, device="/gpu:0")
- def testFunctionDecorator(self):
-
+ def testDupDefinition(self):
+ @function.Defun(tf.float32)
+ def Foo(x):
+ return x + 1
+ @function.Defun(tf.float32, func_name="Foo")
+ def Bar(x):
+ return x + 1
+ @function.Defun(tf.float32, func_name="Foo")
+ def Baz(x):
+ return x + 2
with tf.Graph().as_default():
+ x = tf.constant(100.0)
+ y = Foo(x)
+ z = Bar(x) # OK.
+ with self.test_session():
+ self.assertAllEqual(y.eval(), z.eval())
+ with self.assertRaisesRegexp(ValueError, "already defined"):
+ z = Baz(x)
- @function.Defun(tf.float32)
- def Minus1(b):
- return b - 1.0
+ def testFunctionDecorator(self):
+ @function.Defun(tf.float32)
+ def Minus1(b):
+ return b - 1.0
+ with tf.Graph().as_default():
two = tf.constant([2.])
call1 = Minus1(two)
self.assertTrue(isinstance(Minus1, function._DefinedFunction))
self.assertEqual(Minus1.name, "Minus1")
# pylint: disable=unexpected-keyword-arg
call2 = Minus1(call1, name="next")
- # pylint:enable=unexpected-keyword-arg
+ # pylint: enable=unexpected-keyword-arg
self.assertEquals("next", call2.op.name)
with tf.Session() as sess:
self.assertAllEqual([1], sess.run(call1))
self.assertAllEqual([0], sess.run(call2))
def testNestedFunction(self):
+ @function.Defun(tf.float32)
+ def Cube(x):
+ return x * x * x
+
+ @function.Defun(tf.float32, tf.float32)
+ def CubeXPlusY(x, y):
+ return Cube(x) + y
+
with tf.Graph().as_default():
+ z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
+ with self.test_session():
+ self.assertAllEqual(z.eval(), 25.0)
+ def testNestedDefinedFunction(self):
+ @function.Defun(tf.float32, tf.float32)
+ def CubeXPlusY(x, y):
@function.Defun(tf.float32)
def Cube(x):
return x * x * x
-
- @function.Defun(tf.float32, tf.float32)
- def CubeXPlusY(x, y):
- return Cube(x) + y
-
+ return Cube(x) + y
+ with tf.Graph().as_default():
z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
with self.test_session():
self.assertAllEqual(z.eval(), 25.0)
+ def testUnusedFunction(self):
+ invoked = False
+ # pylint: disable=unused-variable
+ @function.Defun()
+ def Unused():
+ invoked = True
+ return tf.constant(42.)
+ self.assertFalse(invoked)
+ g = tf.Graph()
+ with g.as_default():
+ @function.Defun()
+ def Unused2():
+ invoked = True
+ return tf.constant(7.)
+ tf.constant(3.)
+ # pylint: enable=unused-variable
+ self.assertFalse(invoked)
+ gdef = g.as_graph_def()
+ self.assertEquals(0, len(gdef.library.function))
+
def testReduction(self):
g = tf.Graph()
@@ -400,17 +462,19 @@ class FunctionTest(tf.test.TestCase):
var = tf.reduce_mean(tf.square(x - mean)) # biased var
rstd = tf.rsqrt(var + 1e-8)
return (x - mean) * rstd
- with g.as_default():
- # Wraps BatchNorm in a tf function.
- @function.Defun(tf.float32)
- def BN1(x):
- return BN0(x)
+ # Wraps BatchNorm in a tf function.
+ @function.Defun(tf.float32)
+ def BN1(x):
+ return BN0(x)
+
+ with g.as_default():
x = tf.placeholder(tf.float32)
y0 = BN0(x) # A plain graph
y1 = BN1(x) # A tf function
dx0, = tf.gradients([y0], [x])
dx1, = tf.gradients([y1], [x])
+
# Both should produce the same result and gradient.
with self.test_session(graph=g) as sess:
vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
@@ -565,21 +629,20 @@ class FunctionInlineControlTest(tf.test.TestCase):
do_function_inlining=True,
do_constant_folding=True)))
for noinline in [False, True]:
- g = tf.Graph()
- with g.as_default():
+ @function.Defun(dtype)
+ def Cell(v):
+ # If v is a vector [n, 1], x is a big square matrix.
+ x = tf.tanh(v + tf.transpose(v, [1, 0]))
+ return tf.reduce_sum(x, 1, keep_dims=True)
- @function.Defun(dtype)
- def Cell(v):
- # If v is a vector [n, 1], x is a big square matrix.
- x = tf.tanh(v + tf.transpose(v, [1, 0]))
- return tf.reduce_sum(x, 1, keep_dims=True)
-
- @function.Defun(dtype)
- def Forward(x):
- for _ in range(10):
- x = Cell(x, noinline=noinline)
- return tf.reduce_sum(x, [0, 1])
+ @function.Defun(dtype)
+ def Forward(x):
+ for _ in range(10):
+ x = Cell(x, noinline=noinline)
+ return tf.reduce_sum(x, [0, 1])
+ g = tf.Graph()
+ with g.as_default():
x = tf.placeholder(dtype)
y = Forward(x)
dx, = tf.gradients([y], [x])
@@ -593,5 +656,26 @@ class FunctionInlineControlTest(tf.test.TestCase):
self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
+@function.Defun(*[tf.float32]*3)
+def Linear(w, b, x):
+ return tf.nn.relu(tf.matmul(x, w) + b)
+
+
+@function.Defun(*[tf.float32]*5)
+def Linear2(w1, b1, w2, b2, x):
+ return Linear(w2, b2, Linear(w1, b1, x))
+
+
+class ModuleFunctionTest(tf.test.TestCase):
+
+ def testBasic(self):
+ with tf.Graph().as_default():
+ a, b, c, d, e = [tf.constant([[_]], dtype=tf.float32) for _ in range(5)]
+ y = Linear(a, b, c)
+ z = Linear2(a, b, c, d, e)
+ with tf.Session() as sess:
+ self.assertAllEqual([[1]], sess.run(y))
+ self.assertAllEqual([[5]], sess.run(z))
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index a5e253eabf..06d043ffa2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2033,8 +2033,6 @@ class Graph(object):
self._finalized = False
# Functions defined in the graph
self._functions = collections.OrderedDict()
- self._function_gradient = collections.OrderedDict()
- self._function_python_gradient = collections.OrderedDict()
# Default GraphDef versions
self._graph_def_versions = versions_pb2.VersionDef(
producer=versions.GRAPH_DEF_VERSION,
@@ -2200,15 +2198,15 @@ class Graph(object):
raise ValueError("GraphDef cannot be larger than 2GB.")
if self._functions:
for f in self._functions.values():
- bytesize += f.ByteSize()
+ bytesize += f.definition.ByteSize()
if bytesize >= (1 << 31) or bytesize < 0:
raise ValueError("GraphDef cannot be larger than 2GB.")
- graph.library.function.extend(self._functions.values())
- for func in self._function_gradient:
- grad_def = function_pb2.GradientDef()
- grad_def.function_name = func
- grad_def.gradient_func = self._function_gradient[func]
- graph.library.gradient.extend([grad_def])
+ graph.library.function.extend([f.definition])
+ if f.grad_func_name:
+ grad_def = function_pb2.GradientDef()
+ grad_def.function_name = f.name
+ grad_def.gradient_func = f.grad_func_name
+ graph.library.gradient.extend([grad_def])
return graph, self._version
def as_graph_def(self, from_version=None, add_shapes=False):
@@ -2255,49 +2253,31 @@ class Graph(object):
Returns:
The function def proto.
"""
- return self._functions[name]
+ return self._functions.get(name, None)
- def _add_function(self, function_def, grad_function_name=None,
- python_grad_func=None):
+ def _add_function(self, function):
"""Adds a function to the graph.
- The function is specified as a [`FunctionDef`]
- (https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
- protocol buffer.
-
After the function has been added, you can call to the function by
passing the function name in place of an op name to
`Graph.create_op()`.
Args:
- function_def: A `FunctionDef` protocol buffer.
- grad_function_name: If not None, this specifies the name of a function
- that shall be used as the gradient function of
- the function being added.
- python_grad_func: If not None, specifies the gradient function with the
- same interface as expected by `tf.RegisterGradient`.
- No more than one of {grad_function_name,
- python_grad_func} may be specified.
+ function: A `_DefinedFunction` object.
Raises:
ValueError: if another function is defined with the same name.
"""
- name = function_def.signature.name
- previous_def = self._functions.get(name, None)
- if previous_def:
- if previous_def != function_def:
- raise ValueError("Another function is already defined with that name")
- else:
- # No need to add again.
- return
- self._functions[name] = function_def
- if grad_function_name is not None and python_grad_func is not None:
+ name = function.name
+ previous = self._functions.get(name, None)
+ if previous:
+ raise ValueError("Another function is already defined with that name")
+ # Sanity checks on gradient definition.
+ if (function.grad_func_name is not None) and (
+ function.python_grad_func is not None):
raise ValueError("Gradient defined twice for function %s" % name)
- if grad_function_name is not None:
- self._function_gradient[name] = grad_function_name
- if python_grad_func is not None:
- self._function_python_gradient[name] = python_grad_func
+ self._functions[name] = function
# Helper functions to create operations.
def create_op(self, op_type, inputs, dtypes,
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 400ec0277a..4553ab4006 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -432,8 +432,8 @@ def gradients(ys,
has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
if has_out_grads and (op._id not in stop_ops):
if is_func_call:
- grad_fn = ops.get_default_graph()._function_python_gradient.get(
- op.type, None)
+ grad_fn = ops.get_default_graph()._get_function(
+ op.type).python_grad_func
# pylint: enable=protected-access
else:
# A grad_fn must be defined, either as a function or as None
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 91e69e8bd1..70e6b46d6a 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -376,8 +376,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
grad_func = function.Defun(x=tf.float32, b=tf.float32, g=tf.float32)(
self.XSquarePlusBGradient)
with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
- _ = self._GetFunc(grad_func=grad_func,
+ f = self._GetFunc(grad_func=grad_func,
python_grad_func=self._PythonGradient)
+ f.add_to_graph(tf.Graph())
class StopGradientTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 92700c9414..3f8a65ca1b 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1520,12 +1520,12 @@ class MetaGraphTest(tf.test.TestCase):
def testStrippedOpListNestedFunctions(self):
with self.test_session():
# Square two levels deep
+ @function.Defun(tf.int32)
def f0(x):
return tf.square(x)
- f0 = function.define_function(f0, {"x": tf.int32})
+ @function.Defun(tf.int32)
def f1(x):
- return function.call_function(f0, x)
- f1 = function.define_function(f1, {"x": tf.int32})
+ return f0(x)
# At this point we've defined two functions but haven't called them, so
# there should be no used ops.
@@ -1534,7 +1534,7 @@ class MetaGraphTest(tf.test.TestCase):
self.assertEquals(len(op_list.op), 0)
# If we call the function on a constant, there should be two ops
- function.call_function(f1, tf.constant(7))
+ _ = f1(tf.constant(7))
op_list = tf.contrib.util.stripped_op_list_for_graph(
tf.get_default_graph().as_graph_def())
self.assertEquals(["Const", "Square"], [op.name for op in op_list.op])