author    | 2016-09-08 16:11:08 -0800
committer | 2016-09-08 17:17:45 -0700
commit    | 1d2aa9451d920ad4bc8ad1ac86c06863b590e81f (patch)
tree      | fea78c512adc2d68dd04cc1d72c50ec19e007fec
parent    | 5e8d0251be9ffa06b83e8079a3214fe79d00d36a (diff)
Removes a few limitations of Defun:
* Defun can appear anywhere. Previously, it had to be under a with tf.Graph() scope;
  now it can also be used to decorate module-level functions.
* Defun constructs the function definition lazily, so an uncalled Defun costs little;
* Defun can be used within another Defun. The inner ones are lifted to the top level.
Change: 132620258
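Below is a minimal sketch of the usage patterns this change enables, assembled from the tests added in the patch (testNestedFunction, testUnusedFunction, ModuleFunctionTest); the function names and constants are illustrative:

```python
import tensorflow as tf
from tensorflow.python.framework import function

# Module-level Defun: no enclosing `with tf.Graph()` scope is required,
# because the FunctionDef is now built lazily.
@function.Defun(tf.float32)
def Cube(x):
  return x * x * x

# A Defun may call another Defun; the inner definition is lifted to the
# top level of the graph's function library when the outer one is built.
@function.Defun(tf.float32, tf.float32)
def CubeXPlusY(x, y):
  return Cube(x) + y

# Nothing has been added to any graph yet, so unused Defuns cost little.
# Calling the function registers both definitions and emits the call op.
with tf.Graph().as_default():
  z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
  with tf.Session() as sess:
    print(sess.run(z))  # 25.0
```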
-rw-r--r-- | tensorflow/python/framework/function.py      | 294
-rw-r--r-- | tensorflow/python/framework/function_test.py | 218
-rw-r--r-- | tensorflow/python/framework/ops.py           |  56
-rw-r--r-- | tensorflow/python/ops/gradients.py           |   4
-rw-r--r-- | tensorflow/python/ops/gradients_test.py      |   3
-rw-r--r-- | tensorflow/python/training/saver_test.py     |   8
6 files changed, 362 insertions, 221 deletions
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index c31e5c9079..49bd7f791c 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -120,7 +120,7 @@ def _add_op_node(graph, op, func):
   node.op = op.type
   # pylint: disable=protected-access
   if graph._is_function(op.type):
-    op_def = graph._get_function(op.type).signature
+    op_def = graph._get_function(op.type).definition.signature
   else:
     op_def = op_def_registry.get_registered_ops()[op.type]
   # pylint: enable=protected-access
@@ -196,27 +196,23 @@ def graph_to_function_def(graph, name, inputs, outputs):
   func.signature.input_arg.extend([_tensor_to_argdef(i) for i in inputs])
   func.signature.output_arg.extend([_tensor_to_argdef(o) for o in outputs])
   func_arg_placeholders = set([i.name for i in inputs])
-  g = ops.get_default_graph()
   for op in graph.get_operations():
     tensor_name = op.values()[0].name
     if tensor_name not in func_arg_placeholders:
-      _add_op_node(g, op, func)
+      _add_op_node(graph, op, func)
   return func
 
 
-def call_function(func_def, *inputs, **kwargs):
-  """Calls the function described by `func_def`.
+def call_function(func, *inputs, **kwargs):
+  """Calls the function described by `func`.
 
   This adds a `call` op to the default graph that calls the function described
-  by `func_def` with the tensors listed in `inputs` as arguments. It returns
+  by `func` with the tensors listed in `inputs` as arguments. It returns
   the outputs of the call, which are one or more tensors.
 
-  `func_def` is a
-  [`FunctionDef`](
-  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
-  protcol buffer describing a
-  TensorFlow function. See [`define_function()`](#define_function) for an
-  easy way to create one from a Python function.
+  `func` is a `_DefinedFunction` object. See
+  [`define_function()`](#define_function) for an easy way to create
+  one from a Python function.
 
   You can pass an optional keyword parameter `name=string` to name the added
   operation.
@@ -224,21 +220,23 @@ def call_function(func_def, *inputs, **kwargs):
   You can pass an optional keyword parameter `noinline=True|False` to instruct
   the runtime not to inline the function body into the call site.
 
-  `func_def` is automatically added to the function library of the graph if
+  `func` is automatically added to the function library of the graph if
   needed.
 
   Args:
-    func_def: A `FunctionDef` protocol buffer.
+    func: A `_DefinedFunction` object.
     *inputs: A list of tensors
     **kwargs: Optional keyword arguments. Can only contain 'name' or
       'noinline'.
 
   Returns:
-    A list of tensors representing the outputs of the call to `func_def`.
+    A list of tensors representing the outputs of the call to `func`.
 
   Raises:
     ValueError: if the arguments are invalid.
+
   """
+  func_def = func.definition
   name = kwargs.pop("name", None)
   noinline = kwargs.pop("noinline", None)
   if noinline is None:
@@ -356,70 +354,170 @@ def define_function(func, input_types, func_name=None, grad_func=None,
     ValueError: if the arguments are invalid.
 
   """
-  # TODO(touts): Lift the limitation that func can only receive Tensor args.
-  func_name = func_name or _get_func_name(func)
-  grad_func_name = _get_func_name(grad_func) if grad_func is not None else None
-
-  argspec = inspect.getargspec(func)
-  if argspec.keywords or argspec.defaults:
-    raise ValueError("Functions with argument defaults or keyword "
-                     "arguments are not supported.")
-  if inspect.isfunction(func):
-    if argspec.varargs and (
-        len(argspec.args) > len(input_types)) or not argspec.varargs and (
-            len(argspec.args) != len(input_types)):
-      raise ValueError("The function has fewer arguments "
-                       "than the number of specified input types.")
-    argnames = argspec.args
-  elif inspect.ismethod(func):
-    if argspec.varargs and (
-        len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
-            len(argspec.args) != 1 + len(input_types)):
-      raise ValueError("The class function has fewer arguments "
-                       "than the number of specified input types.")
-    # 1st argument is the "class" type.
-    argnames = argspec.args[1:]
-
-  args = []
-  if isinstance(input_types, (list, tuple)):
-    for i in range(len(input_types)):
-      argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
-      argtype = input_types[i]
-      args.append((argname, argtype))
-  else:
-    for name in argnames:
-      if name not in input_types:
-        raise ValueError("Missing type for argument: " + name)
-      args.append((name, input_types[name]))
-
-  # Create the func_def object.
-  temp_graph = ops.Graph()
-  with temp_graph.as_default():
-    # List of placeholders for the function_def.
-    inputs = []
-    # Arglist to call 'func'
-    kwargs = {}
-    for (argname, argtype) in args:
-      argholder = array_ops.placeholder(argtype, name=argname)
-      inputs.append(argholder)
-      kwargs[argname] = argholder
-    # Call func and gather the output tensors.
-    if isinstance(input_types, (list, tuple)):
-      outputs = func(*inputs)
-    else:
-      outputs = func(**kwargs)
-    if not isinstance(outputs, ops.Tensor) and not outputs:
-      raise ValueError("Function must return at least one tensor")
-    # Convenience: if func only returned one value, make it a tuple.
-    if not isinstance(outputs, (list, tuple)):
-      outputs = (outputs,)
-  # Build the FunctionDef
-  func_def = graph_to_function_def(temp_graph, func_name, inputs, outputs)
-  g = ops.get_default_graph()
+  f = _DefinedFunction(func, input_types, func_name, grad_func,
+                       python_grad_func)
   # pylint: disable=protected-access
-  g._add_function(func_def, grad_func_name, python_grad_func=python_grad_func)
+  f.add_to_graph(ops.get_default_graph())
   # pylint: enable=protected-access
-  return func_def
+  return f
+
+
+class _DefinedFunction(object):
+  """_DefinedFunction encapsulates a function definition and its properties.
+
+  Attributes:
+    name: The function name.
+    definition: The definition of this function. A FunctionDef proto.
+    grad_func_name: If not None, the name of this function's gradient function.
+    python_grad_func: A python callable implementing the gradient of
+      the function python-side.
+  """
+
+  def __init__(self, func, input_types, func_name=None, grad_func=None,
+               python_grad_func=None):
+    """Creates _DefinedFunction.
+
+    Args:
+      func: A python callable which constructs a tf function body.
+      input_types: The function's argument types. Can be a tuple, list of
+        tf data types, or a dictionary of argument names to their types.
+      func_name: The function name. Defaults to None, in which derives from
+        'func'.
+      grad_func: This function's gradient function, if not None. Defaults
+        to None.
+      python_grad_func: A python callable implementing the gradient of
+        the function python-side.
+
+    Raises:
+      ValueError: The function definition is invalid.
+    """
+    self._func = func
+    self._input_types = input_types
+    self._func_name = func_name or _get_func_name(func)
+    self._grad_func = grad_func
+    self._python_grad_func = python_grad_func
+    self._definition = None  # Constructed lazily.
+
+    argspec = inspect.getargspec(func)
+    if argspec.keywords or argspec.defaults:
+      raise ValueError("Functions with argument defaults or keyword "
+                       "arguments are not supported.")
+    if inspect.isfunction(func):
+      if argspec.varargs and (
+          len(argspec.args) > len(input_types)) or not argspec.varargs and (
+              len(argspec.args) != len(input_types)):
+        raise ValueError("The function has fewer arguments "
+                         "than the number of specified input types.")
+      argnames = argspec.args
+    elif inspect.ismethod(func):
+      if argspec.varargs and (
+          len(argspec.args) > 1 + len(input_types)) or not argspec.varargs and (
+              len(argspec.args) != 1 + len(input_types)):
+        raise ValueError("The class function has fewer arguments "
+                         "than the number of specified input types.")
+      # 1st argument is the "class" type.
+      argnames = argspec.args[1:]
+
+    self._args = []
+    if isinstance(input_types, (list, tuple)):
+      for i in range(len(input_types)):
+        argname = argnames[i] if i < len(argnames) else ("arg%d" % i)
+        argtype = input_types[i]
+        self._args.append((argname, argtype))
+    else:
+      for name in argnames:
+        if name not in input_types:
+          raise ValueError("Missing type for argument: " + name)
+        self._args.append((name, input_types[name]))
+
+  @property
+  def name(self):
+    """Function name."""
+    return self._func_name
+
+  @property
+  def definition(self):
+    """Function definition proto."""
+    self._create_definition_if_needed()
+    return self._definition
+
+  @property
+  def grad_func_name(self):
+    """Its gradient function's name."""
+    return self._grad_func.name if self._grad_func else None
+
+  @property
+  def python_grad_func(self):
+    """Python gradient function callable."""
+    return self._python_grad_func
+
+  def _create_definition_if_needed(self):
+    """Creates the function definition if it's not created yet."""
+
+    if self._definition is not None:
+      return
+
+    # Create the func_def object.
+    temp_graph = ops.Graph()
+    with temp_graph.as_default():
+      # List of placeholders for the function_def.
+      inputs = []
+      # Arglist to call 'func'
+      kwargs = {}
+      for (argname, argtype) in self._args:
+        argholder = array_ops.placeholder(argtype, name=argname)
+        inputs.append(argholder)
+        kwargs[argname] = argholder
+      # Call func and gather the output tensors.
+      if isinstance(self._input_types, (list, tuple)):
+        outputs = self._func(*inputs)
+      else:
+        outputs = self._func(**kwargs)
+      if not isinstance(outputs, ops.Tensor) and not outputs:
+        raise ValueError("Function must return at least one tensor")
+      # Convenience: if func only returned one value, make it a tuple.
+      if not isinstance(outputs, (list, tuple)):
+        outputs = (outputs,)
+
+    # Build the FunctionDef
+    self._definition = graph_to_function_def(
+        temp_graph, self._func_name, inputs, outputs)
+    # pylint: disable=protected-access
+    self._sub_functions = temp_graph._functions
+    # pylint: enable=protected-access
+    self._hash = hash(self._definition.SerializeToString())
+    for item in self._sub_functions.items():  # OrderedDict
+      self._hash = hash((self._hash, item))
+
+  def __hash__(self):
+    self._create_definition_if_needed()
+    return self._hash
+
+  def add_to_graph(self, g):
+    """Adds this function into the graph g."""
+    self._create_definition_if_needed()
+
+    # pylint: disable=protected-access
+    # If 'g' has an identical function already, do nothing.
+    prev = g._get_function(self.name)
+    if prev and (prev._hash == self._hash):
+      return
+
+    # Adds this function into 'g'.
+    g._add_function(self)
+    # pylint: enable=protected-access
+
+    # Ensures related sub-routines are defined in 'g', too.
+    for f in self._sub_functions.values():
+      f.add_to_graph(g)
+
+    # Adds its gradient function, too.
+    if self._grad_func:
+      self._grad_func.add_to_graph(g)
+
+  def __call__(self, *args, **kwargs):
+    self.add_to_graph(ops.get_default_graph())
+    return call_function(self, *args, **kwargs)
 
 
 class Defun(object):
@@ -440,7 +538,10 @@ class Defun(object):
       def foo(x, y):
         ...
 
-  When you call the decorated function it will add `call` ops to the graph.
+  When you call the decorated function it will add `call` ops to the
+  default graph and adds the definition of the function into the
+  default graph. Because the addition of the function into the graph
+  is deferred, the decorator can be used anywhere in the program.
 
   Example, but also see the [How To on functions](link_needed).
 
@@ -457,6 +558,7 @@ class Defun(object):
   ```
 
   @@__init__
+
   """
 
   def __init__(self, *input_type_list, **input_types):
@@ -491,35 +593,9 @@ def __call__(self, f):
     if self._input_types:
-      func_def = define_function(
-          f, self._input_types,
-          func_name=self._func_name, grad_func=self._grad_func,
-          python_grad_func=self._python_grad_func)
+      inp_types = self._input_types
     else:
-      func_def = define_function(
-          f, self._input_type_list,
-          func_name=self._func_name, grad_func=self._grad_func,
-          python_grad_func=self._python_grad_func)
-
-    return _DefinedFunction(definition=func_def)
-
-
-class _DefinedFunction(object):
-  """Class to store the name and definition of the function defined by Defun.
-
-  This object implements a callable interface that runs `call_function`, and
-  provides a `name` property to look up the name of the `Function`.
-
-  An instance of `_DefinedFunction` may be passed to the `grad_func` parameter
-  of `define_function` and `Defun`.
-  """
-
-  def __init__(self, definition):
-    self._definition = definition
-
-  @property
-  def name(self):
-    return self._definition.signature.name
-
-  def __call__(self, *args, **kwargs):
-    return call_function(self._definition, *args, **kwargs)
+      inp_types = self._input_type_list
+    return _DefinedFunction(
+        f, inp_types, self._func_name, self._grad_func,
+        self._python_grad_func)
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 00b1d58c6c..bb69295499 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import collections
 import time
 
 import numpy as np
@@ -53,7 +54,6 @@ class FunctionTest(tf.test.TestCase):
     #   u = matmul(a, b) + c
     #   v = u^2
     #   w = u + v
-    # TODO(zhifengc): replaces w/ a nicer @decorator sugar.
     foo = tf.Graph()
     with foo.as_default():
       a = tf.placeholder(tf.float32, name="a")
@@ -63,7 +63,17 @@ class FunctionTest(tf.test.TestCase):
       v = tf.square(u, name="v")
       w = tf.add_n([u, v], name="w")
     fdef = function.graph_to_function_def(foo, "foo", [a, b, c], [u, v, w])
-    g._add_function(fdef)
+
+    class Mock(function._DefinedFunction):
+
+      def __init__(self, fdef):
+        self._func_name = "foo"
+        self._definition = fdef
+        self._sub_functions = collections.OrderedDict()
+        self._grad_func = None
+        self._python_grad_func = None
+        self._hash = hash(fdef.SerializeToString())
+    g._add_function(Mock(fdef))
 
     # Compute 2 * 3 + 4 and its square.
     with g.as_default(), tf.Session() as sess:
@@ -105,11 +115,11 @@ class FunctionTest(tf.test.TestCase):
       return a + b * 2
 
     with tf.Graph().as_default():
-      f_def = function.define_function(APlus2B, {"a": tf.float32,
-                                                 "b": tf.float32})
+      f = function.define_function(APlus2B, {"a": tf.float32,
+                                             "b": tf.float32})
       one = tf.constant([1.0])
       two = tf.constant([2.0])
-      call = function.call_function(f_def, one, two)
+      call = function.call_function(f, one, two)
       self.assertEquals("APlus2B", call.op.name)
       with tf.Session() as sess:
         self.assertAllEqual([5.0], sess.run(call))
@@ -146,11 +156,12 @@ class FunctionTest(tf.test.TestCase):
         self.assertAllClose([0.4], sess.run(call_g))
 
   def testTanhSymGrad(self):
+    @function.Defun(tf.float32)
+    def Forward(x):
+      return tf.reduce_sum(tf.tanh(x))
+
     g = tf.Graph()
     with g.as_default():
-      @function.Defun(tf.float32)
-      def Forward(x):
-        return tf.reduce_sum(tf.tanh(x))
       x = tf.placeholder(tf.float32)
       y = Forward(x)
       dx = tf.gradients([y], [x])
@@ -167,21 +178,21 @@ class FunctionTest(tf.test.TestCase):
       self.assertAllClose(1 - np.square(np.tanh(inp)), out)
 
   def testCustomGradient(self):
-    g = tf.Graph()
     dtype = tf.float32
-    with g.as_default():
-      @function.Defun(dtype, dtype, dtype)
-      def XentLossGrad(logits, labels, dloss):
-        dlogits = tf.reshape(dloss, [-1, 1]) * (tf.nn.softmax(logits) - labels)
-        dlabels = tf.zeros_like(labels)
-        # Takes exp(dlogits) to differentiate it from the "correct" gradient.
-        return tf.exp(dlogits), dlabels
+
+    @function.Defun(dtype, dtype, dtype)
+    def XentLossGrad(logits, labels, dloss):
+      dlogits = tf.reshape(dloss, [-1, 1]) * (tf.nn.softmax(logits) - labels)
+      dlabels = tf.zeros_like(labels)
+      # Takes exp(dlogits) to differentiate it from the "correct" gradient.
+      return tf.exp(dlogits), dlabels
 
-      @function.Defun(dtype, dtype, grad_func=XentLossGrad)
-      def XentLoss(logits, labels):
-        return tf.reduce_sum(labels * tf.log(tf.nn.softmax(logits)), 1)
+    @function.Defun(dtype, dtype, grad_func=XentLossGrad)
+    def XentLoss(logits, labels):
+      return tf.reduce_sum(labels * tf.log(tf.nn.softmax(logits)), 1)
 
+    g = tf.Graph()
+    with g.as_default():
       logits = tf.placeholder(dtype)
       labels = tf.placeholder(dtype)
       loss = XentLoss(logits, labels)
@@ -197,19 +208,19 @@ class FunctionTest(tf.test.TestCase):
       self.assertAllClose(out, np.exp(prob - y))
 
   def testCustomGradientError(self):
-    g = tf.Graph()
     dtype = tf.float32
-    with g.as_default():
-      @function.Defun(dtype, dtype, dtype)
-      def Grad(x, dy, dz):
-        # Should have returned 1 result.
-        return x, dy + dz
+
+    @function.Defun(dtype, dtype, dtype)
+    def Grad(x, dy, dz):
+      # Should have returned 1 result.
+      return x, dy + dz
 
-      @function.Defun(dtype, grad_func=Grad)
-      def Forward(x):
-        return x, x
+    @function.Defun(dtype, grad_func=Grad)
+    def Forward(x):
+      return x, x
 
+    g = tf.Graph()
+    with g.as_default():
       inp = tf.placeholder(dtype)
       out = tf.add_n(Forward(inp))
       dinp = tf.gradients(out, [inp])
@@ -238,11 +249,12 @@ class FunctionTest(tf.test.TestCase):
       self.assertEquals(y.get_shape(), dy.get_shape())
 
   def testZNoDepOnY(self):
+    @function.Defun(tf.float32, tf.float32)
+    def Foo(x, y):
+      return x * 2
+
     with tf.Graph().as_default():
       # z = Foo(x, y). z doe
-      @function.Defun(tf.float32, tf.float32)
-      def Foo(x, y):
-        return x * 2
       x = tf.constant(1.0)
       y = tf.constant(2.0)
       z = Foo(x, y)
@@ -297,23 +309,25 @@ class FunctionTest(tf.test.TestCase):
       return a + b, b - a
 
     with tf.Graph().as_default():
+      # pylint: disable=expression-not-assigned
      with self.assertRaisesRegexp(ValueError, "return at least one tensor"):
-        function.define_function(NoResult, {})
+        function.define_function(NoResult, {}).definition
      with self.assertRaisesRegexp(ValueError, "are not supported"):
-        function.define_function(DefaultArg, {})
+        function.define_function(DefaultArg, {}).definition
      with self.assertRaisesRegexp(ValueError, "are not supported"):
-        function.define_function(KwArgs, {})
+        function.define_function(KwArgs, {}).definition
      with self.assertRaisesRegexp(ValueError, "specified input types"):
-        function.define_function(PlusMinus, {})
+        function.define_function(PlusMinus, {}).definition
      with self.assertRaisesRegexp(ValueError, "specified input types"):
-        function.define_function(PlusMinus, {"c": tf.float32})
+        function.define_function(PlusMinus, {"c": tf.float32}).definition
      with self.assertRaisesRegexp(ValueError, "type for argument: b"):
        function.define_function(PlusMinus, {"a": tf.float32,
-                                             "c": tf.float32})
+                                             "c": tf.float32}).definition
      with self.assertRaisesRegexp(ValueError, "specified input types"):
        function.define_function(PlusMinus, {"a": tf.float32,
                                             "b": tf.float32,
-                                             "c": tf.float32})
+                                             "c": tf.float32}).definition
+      # pylint: enable=expression-not-assigned
 
   def testCallErrors(self):
@@ -356,41 +370,89 @@ class FunctionTest(tf.test.TestCase):
      with self.assertRaisesRegexp(ValueError, "Unknown keyword arguments"):
        function.call_function(plus_one, one, device="/gpu:0")
 
-  def testFunctionDecorator(self):
-
+  def testDupDefinition(self):
+    @function.Defun(tf.float32)
+    def Foo(x):
+      return x + 1
+
+    @function.Defun(tf.float32, func_name="Foo")
+    def Bar(x):
+      return x + 1
+
+    @function.Defun(tf.float32, func_name="Foo")
+    def Baz(x):
+      return x + 2
+
     with tf.Graph().as_default():
+      x = tf.constant(100.0)
+      y = Foo(x)
+      z = Bar(x)  # OK.
+      with self.test_session():
+        self.assertAllEqual(y.eval(), z.eval())
+      with self.assertRaisesRegexp(ValueError, "already defined"):
+        z = Baz(x)
 
-      @function.Defun(tf.float32)
-      def Minus1(b):
-        return b - 1.0
+  def testFunctionDecorator(self):
+    @function.Defun(tf.float32)
+    def Minus1(b):
+      return b - 1.0
+
+    with tf.Graph().as_default():
       two = tf.constant([2.])
       call1 = Minus1(two)
       self.assertTrue(isinstance(Minus1, function._DefinedFunction))
       self.assertEqual(Minus1.name, "Minus1")
       # pylint: disable=unexpected-keyword-arg
       call2 = Minus1(call1, name="next")
-      # pylint:enable=unexpected-keyword-arg
+      # pylint: enable=unexpected-keyword-arg
       self.assertEquals("next", call2.op.name)
       with tf.Session() as sess:
         self.assertAllEqual([1], sess.run(call1))
         self.assertAllEqual([0], sess.run(call2))
 
   def testNestedFunction(self):
+    @function.Defun(tf.float32)
+    def Cube(x):
+      return x * x * x
+
+    @function.Defun(tf.float32, tf.float32)
+    def CubeXPlusY(x, y):
+      return Cube(x) + y
+
     with tf.Graph().as_default():
+      z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
+      with self.test_session():
+        self.assertAllEqual(z.eval(), 25.0)
 
+  def testNestedDefinedFunction(self):
+    @function.Defun(tf.float32, tf.float32)
+    def CubeXPlusY(x, y):
       @function.Defun(tf.float32)
       def Cube(x):
         return x * x * x
-
-      @function.Defun(tf.float32, tf.float32)
-      def CubeXPlusY(x, y):
-        return Cube(x) + y
-
+      return Cube(x) + y
+
+    with tf.Graph().as_default():
       z = CubeXPlusY(tf.constant(3.0), tf.constant(-2.0))
       with self.test_session():
         self.assertAllEqual(z.eval(), 25.0)
 
+  def testUnusedFunction(self):
+    invoked = False
+    # pylint: disable=unused-variable
+    @function.Defun()
+    def Unused():
+      invoked = True
+      return tf.constant(42.)
+    self.assertFalse(invoked)
+    g = tf.Graph()
+    with g.as_default():
+      @function.Defun()
+      def Unused2():
+        invoked = True
+        return tf.constant(7.) + tf.constant(3.)
+    # pylint: enable=unused-variable
+    self.assertFalse(invoked)
+    gdef = g.as_graph_def()
+    self.assertEquals(0, len(gdef.library.function))
+
   def testReduction(self):
     g = tf.Graph()
@@ -400,17 +462,19 @@ class FunctionTest(tf.test.TestCase):
       var = tf.reduce_mean(tf.square(x - mean))  # biased var
       rstd = tf.rsqrt(var + 1e-8)
       return (x - mean) * rstd
-    with g.as_default():
-      # Wraps BatchNorm in a tf function.
-      @function.Defun(tf.float32)
-      def BN1(x):
-        return BN0(x)
 
+    # Wraps BatchNorm in a tf function.
+    @function.Defun(tf.float32)
+    def BN1(x):
+      return BN0(x)
+
+    with g.as_default():
       x = tf.placeholder(tf.float32)
       y0 = BN0(x)  # A plain graph
       y1 = BN1(x)  # A tf function
       dx0, = tf.gradients([y0], [x])
       dx1, = tf.gradients([y1], [x])
+    # Both should produce the same result and gradient.
     with self.test_session(graph=g) as sess:
       vals = sess.run([y0, y1, dx0, dx1], {x: np.random.uniform(size=(3, 7))})
@@ -565,21 +629,20 @@ class FunctionInlineControlTest(tf.test.TestCase):
             do_function_inlining=True,
             do_constant_folding=True)))
     for noinline in [False, True]:
-      g = tf.Graph()
-      with g.as_default():
+      @function.Defun(dtype)
+      def Cell(v):
+        # If v is a vector [n, 1], x is a big square matrix.
+        x = tf.tanh(v + tf.transpose(v, [1, 0]))
+        return tf.reduce_sum(x, 1, keep_dims=True)
 
-        @function.Defun(dtype)
-        def Cell(v):
-          # If v is a vector [n, 1], x is a big square matrix.
-          x = tf.tanh(v + tf.transpose(v, [1, 0]))
-          return tf.reduce_sum(x, 1, keep_dims=True)
-
-        @function.Defun(dtype)
-        def Forward(x):
-          for _ in range(10):
-            x = Cell(x, noinline=noinline)
-          return tf.reduce_sum(x, [0, 1])
+      @function.Defun(dtype)
+      def Forward(x):
+        for _ in range(10):
+          x = Cell(x, noinline=noinline)
+        return tf.reduce_sum(x, [0, 1])
 
+      g = tf.Graph()
+      with g.as_default():
         x = tf.placeholder(dtype)
         y = Forward(x)
         dx, = tf.gradients([y], [x])
@@ -593,5 +656,26 @@ class FunctionInlineControlTest(tf.test.TestCase):
       self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
 
 
+@function.Defun(*[tf.float32]*3)
+def Linear(w, b, x):
+  return tf.nn.relu(tf.matmul(x, w) + b)
+
+
+@function.Defun(*[tf.float32]*5)
+def Linear2(w1, b1, w2, b2, x):
+  return Linear(w2, b2, Linear(w1, b1, x))
+
+
+class ModuleFunctionTest(tf.test.TestCase):
+
+  def testBasic(self):
+    with tf.Graph().as_default():
+      a, b, c, d, e = [tf.constant([[_]], dtype=tf.float32) for _ in range(5)]
+      y = Linear(a, b, c)
+      z = Linear2(a, b, c, d, e)
+      with tf.Session() as sess:
+        self.assertAllEqual([[1]], sess.run(y))
+        self.assertAllEqual([[5]], sess.run(z))
+
 if __name__ == "__main__":
   tf.test.main()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index a5e253eabf..06d043ffa2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2033,8 +2033,6 @@ class Graph(object):
     self._finalized = False
     # Functions defined in the graph
     self._functions = collections.OrderedDict()
-    self._function_gradient = collections.OrderedDict()
-    self._function_python_gradient = collections.OrderedDict()
     # Default GraphDef versions
     self._graph_def_versions = versions_pb2.VersionDef(
         producer=versions.GRAPH_DEF_VERSION,
@@ -2200,15 +2198,15 @@ class Graph(object):
       raise ValueError("GraphDef cannot be larger than 2GB.")
     if self._functions:
       for f in self._functions.values():
-        bytesize += f.ByteSize()
+        bytesize += f.definition.ByteSize()
         if bytesize >= (1 << 31) or bytesize < 0:
           raise ValueError("GraphDef cannot be larger than 2GB.")
-      graph.library.function.extend(self._functions.values())
-      for func in self._function_gradient:
-        grad_def = function_pb2.GradientDef()
-        grad_def.function_name = func
-        grad_def.gradient_func = self._function_gradient[func]
-        graph.library.gradient.extend([grad_def])
+        graph.library.function.extend([f.definition])
+        if f.grad_func_name:
+          grad_def = function_pb2.GradientDef()
+          grad_def.function_name = f.name
+          grad_def.gradient_func = f.grad_func_name
+          graph.library.gradient.extend([grad_def])
     return graph, self._version
 
   def as_graph_def(self, from_version=None, add_shapes=False):
@@ -2255,49 +2253,31 @@ class Graph(object):
     Returns:
       The function def proto.
     """
-    return self._functions[name]
+    return self._functions.get(name, None)
 
-  def _add_function(self, function_def, grad_function_name=None,
-                    python_grad_func=None):
+  def _add_function(self, function):
     """Adds a function to the graph.
 
-    The function is specified as a [`FunctionDef`]
-    (https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
-    protocol buffer.
-
     After the function has been added, you can call to the function by
     passing the function name in place of an op name to
     `Graph.create_op()`.
 
     Args:
-      function_def: A `FunctionDef` protocol buffer.
-      grad_function_name: If not None, this specifies the name of a function
-                          that shall be used as the gradient function of
-                          the function being added.
-      python_grad_func: If not None, specifies the gradient function with the
-                        same interface as expected by `tf.RegisterGradient`.
-                        No more than one of {grad_function_name,
-                        python_grad_func} may be specified.
+      function: A `_DefinedFunction` object.
 
     Raises:
       ValueError: if another function is defined with the same name.
     """
-    name = function_def.signature.name
-    previous_def = self._functions.get(name, None)
-    if previous_def:
-      if previous_def != function_def:
-        raise ValueError("Another function is already defined with that name")
-      else:
-        # No need to add again.
-        return
-    self._functions[name] = function_def
-    if grad_function_name is not None and python_grad_func is not None:
+    name = function.name
+    previous = self._functions.get(name, None)
+    if previous:
+      raise ValueError("Another function is already defined with that name")
+    # Sanity checks on gradient definition.
+    if (function.grad_func_name is not None) and (
+        function.python_grad_func is not None):
       raise ValueError("Gradient defined twice for function %s" % name)
-    if grad_function_name is not None:
-      self._function_gradient[name] = grad_function_name
-    if python_grad_func is not None:
-      self._function_python_gradient[name] = python_grad_func
+    self._functions[name] = function
 
   # Helper functions to create operations.
   def create_op(self, op_type, inputs, dtypes,
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 400ec0277a..4553ab4006 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -432,8 +432,8 @@ def gradients(ys,
       has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
       if has_out_grads and (op._id not in stop_ops):
         if is_func_call:
-          grad_fn = ops.get_default_graph()._function_python_gradient.get(
-              op.type, None)
+          grad_fn = ops.get_default_graph()._get_function(
+              op.type).python_grad_func
           # pylint: enable=protected-access
         else:
           # A grad_fn must be defined, either as a function or as None
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 91e69e8bd1..70e6b46d6a 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -376,8 +376,9 @@ class FunctionGradientsTest(test_util.TensorFlowTestCase):
     grad_func = function.Defun(x=tf.float32, b=tf.float32, g=tf.float32)(
         self.XSquarePlusBGradient)
     with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
-      _ = self._GetFunc(grad_func=grad_func,
+      f = self._GetFunc(grad_func=grad_func,
                         python_grad_func=self._PythonGradient)
+      f.add_to_graph(tf.Graph())
 
 
 class StopGradientTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 92700c9414..3f8a65ca1b 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -1520,12 +1520,12 @@ class MetaGraphTest(tf.test.TestCase):
   def testStrippedOpListNestedFunctions(self):
     with self.test_session():
       # Square two levels deep
+      @function.Defun(tf.int32)
       def f0(x):
         return tf.square(x)
-      f0 = function.define_function(f0, {"x": tf.int32})
+      @function.Defun(tf.int32)
       def f1(x):
-        return function.call_function(f0, x)
-      f1 = function.define_function(f1, {"x": tf.int32})
+        return f0(x)
 
       # At this point we've defined two functions but haven't called them, so
      # there should be no used ops.
@@ -1534,7 +1534,7 @@ class MetaGraphTest(tf.test.TestCase):
       self.assertEquals(len(op_list.op), 0)
 
       # If we call the function on a constant, there should be two ops
-      function.call_function(f1, tf.constant(7))
+      _ = f1(tf.constant(7))
       op_list = tf.contrib.util.stripped_op_list_for_graph(
           tf.get_default_graph().as_graph_def())
       self.assertEquals(["Const", "Square"], [op.name for op in op_list.op])
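For completeness, a sketch (not part of the patch) of the caller-visible change to the lower-level API, mirroring the updated testDefineFunction above: define_function now returns a _DefinedFunction object rather than a FunctionDef proto, and the returned object is itself callable.

```python
import tensorflow as tf
from tensorflow.python.framework import function

def APlus2B(a, b):
  return a + b * 2

with tf.Graph().as_default():
  # `f` is a _DefinedFunction; the FunctionDef proto lives at `f.definition`.
  f = function.define_function(APlus2B, {"a": tf.float32, "b": tf.float32})
  one = tf.constant([1.0])
  two = tf.constant([2.0])
  # Either style works: function.call_function(f, ...) or simply f(one, two).
  call = function.call_function(f, one, two)
  with tf.Session() as sess:
    print(sess.run(call))  # [5.0]
```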