From df46916ab0f8aa9fbf45f6847c9216ecc90515a9 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Thu, 13 Sep 2018 13:54:44 -0700 Subject: Allow user to the pre register a defun function into graph without calling it. PiperOrigin-RevId: 212872452 --- tensorflow/python/eager/function.py | 28 ++++++++++++ tensorflow/python/eager/function_test.py | 78 ++++++++++++++++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 348bf4650f..552ed29f65 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1204,6 +1204,34 @@ class PolymorphicFunction(object): return graph_function, (args, kwds) +def register(func, *args, **kwargs): + """Register the defun function into the graph. + + This won't actually call the function with the inputs, and only put the + function definition into graph. Register function with different input param + will result into multiple version of functions registered in graph. + + Args: + func: the PolymorphicFunction instance that generated by a @defun + *args: input arguments for the Python function. + **kwargs: input keyword arguments for the Python function. + + Returns: + a `Function` object specialized to inputs and execution context. + + Raises: + ValueError: When the input function is not a defun wrapped python function. + """ + if not isinstance(func, PolymorphicFunction): + raise ValueError("Only defun function is allowed to be registered. " + "Got type: %s" % type(func)) + concrete_func = func.get_concrete_function(*args, **kwargs) + graph = ops.get_default_graph() + concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access + # TODO(scottzhu): support concrete_func._backward_graph_function in future. + return concrete_func + + def _validate_signature(signature): if any(not isinstance(arg, tensor_spec.TensorSpec) for arg in nest.flatten(signature)): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index d2b1d9c8a7..a0abefe666 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1607,6 +1607,84 @@ class FunctionTest(test.TestCase): t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) add(t, t) + def testRegisterFunction(self): + @function.defun + def add(x, y): + return math_ops.add(x, y) + + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + function.register(add, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + pre_register_matmul_func_name = functions[0].definition.signature.name + self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*') + pre_register_add_func_name = functions[1].definition.signature.name + self.assertRegexpMatches(pre_register_add_func_name, '.*add.*') + + sq = defun_matmul(t, t) + double = add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + # Make sure the pre registered function is used, and no other function + # is added. + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + called_func_name = functions[0].definition.signature.name + self.assertEqual(pre_register_matmul_func_name, called_func_name) + called_func_name = functions[1].definition.signature.name + self.assertEqual(pre_register_add_func_name, called_func_name) + + def testRegisterFunctionWithInputSignature(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun( + matmul, + input_signature=[ + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) + ]) + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + + # Test input param shape mismatch + t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with self.assertRaisesRegexp( + ValueError, 'Python inputs incompatible with input_signature'): + function.register(defun_matmul, t2, t2) + + def testRegisterFunctionWithCache(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) + function.register(defun_matmul, t, t) + function.register(defun_matmul, t2, t2) + + graph = ops.get_default_graph() + # Only one function is registered since the input param are in same type + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): -- cgit v1.2.3