diff options
author | Scott Zhu <scottzhu@google.com> | 2018-09-13 13:54:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 13:58:39 -0700 |
commit | df46916ab0f8aa9fbf45f6847c9216ecc90515a9 (patch) | |
tree | 20771b1fbf41429d7caf22c9c28cebfbf62f07a7 /tensorflow/python/eager | |
parent | 2646bf2d2bfb717c828db6391563b431f760a7d3 (diff) |
Allow user to the pre register a defun function into graph without calling it.
PiperOrigin-RevId: 212872452
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 28 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 78 |
2 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 348bf4650f..552ed29f65 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1204,6 +1204,34 @@ class PolymorphicFunction(object): return graph_function, (args, kwds) +def register(func, *args, **kwargs): + """Register the defun function into the graph. + + This won't actually call the function with the inputs, and only put the + function definition into graph. Register function with different input param + will result into multiple version of functions registered in graph. + + Args: + func: the PolymorphicFunction instance that generated by a @defun + *args: input arguments for the Python function. + **kwargs: input keyword arguments for the Python function. + + Returns: + a `Function` object specialized to inputs and execution context. + + Raises: + ValueError: When the input function is not a defun wrapped python function. + """ + if not isinstance(func, PolymorphicFunction): + raise ValueError("Only defun function is allowed to be registered. " + "Got type: %s" % type(func)) + concrete_func = func.get_concrete_function(*args, **kwargs) + graph = ops.get_default_graph() + concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access + # TODO(scottzhu): support concrete_func._backward_graph_function in future. + return concrete_func + + def _validate_signature(signature): if any(not isinstance(arg, tensor_spec.TensorSpec) for arg in nest.flatten(signature)): diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index d2b1d9c8a7..a0abefe666 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1607,6 +1607,84 @@ class FunctionTest(test.TestCase): t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) add(t, t) + def testRegisterFunction(self): + @function.defun + def add(x, y): + return math_ops.add(x, y) + + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + function.register(add, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + pre_register_matmul_func_name = functions[0].definition.signature.name + self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*') + pre_register_add_func_name = functions[1].definition.signature.name + self.assertRegexpMatches(pre_register_add_func_name, '.*add.*') + + sq = defun_matmul(t, t) + double = add(t, t) + self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22]) + self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8]) + # Make sure the pre registered function is used, and no other function + # is added. + self.assertEqual(len(graph._functions), 2) + functions = list(graph._functions.values()) + called_func_name = functions[0].definition.signature.name + self.assertEqual(pre_register_matmul_func_name, called_func_name) + called_func_name = functions[1].definition.signature.name + self.assertEqual(pre_register_add_func_name, called_func_name) + + def testRegisterFunctionWithInputSignature(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun( + matmul, + input_signature=[ + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32), + tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32) + ]) + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + function.register(defun_matmul, t, t) + + graph = ops.get_default_graph() + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + + # Test input param shape mismatch + t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + with self.assertRaisesRegexp( + ValueError, 'Python inputs incompatible with input_signature'): + function.register(defun_matmul, t2, t2) + + def testRegisterFunctionWithCache(self): + def matmul(x, y): + return math_ops.matmul(x, y) + defun_matmul = function.defun(matmul) + + with context.graph_mode(), self.cached_session(): + with ops.get_default_graph().as_default(): + t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]]) + t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]]) + function.register(defun_matmul, t, t) + function.register(defun_matmul, t2, t2) + + graph = ops.get_default_graph() + # Only one function is registered since the input param are in same type + # pylint: disable=protected-access + self.assertEqual(len(graph._functions), 1) + @test_util.with_c_shapes class AutomaticControlDependenciesTest(test.TestCase): |