aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Scott Zhu <scottzhu@google.com>2018-09-13 13:54:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 13:58:39 -0700
commitdf46916ab0f8aa9fbf45f6847c9216ecc90515a9 (patch)
tree20771b1fbf41429d7caf22c9c28cebfbf62f07a7 /tensorflow/python/eager
parent2646bf2d2bfb717c828db6391563b431f760a7d3 (diff)
Allow user to the pre register a defun function into graph without calling it.
PiperOrigin-RevId: 212872452
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py28
-rw-r--r--tensorflow/python/eager/function_test.py78
2 files changed, 106 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 348bf4650f..552ed29f65 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1204,6 +1204,34 @@ class PolymorphicFunction(object):
return graph_function, (args, kwds)
+def register(func, *args, **kwargs):
+ """Register the defun function into the graph.
+
+ This won't actually call the function with the inputs, and only put the
+ function definition into graph. Register function with different input param
+ will result into multiple version of functions registered in graph.
+
+ Args:
+ func: the PolymorphicFunction instance that generated by a @defun
+ *args: input arguments for the Python function.
+ **kwargs: input keyword arguments for the Python function.
+
+ Returns:
+ a `Function` object specialized to inputs and execution context.
+
+ Raises:
+ ValueError: When the input function is not a defun wrapped python function.
+ """
+ if not isinstance(func, PolymorphicFunction):
+ raise ValueError("Only defun function is allowed to be registered. "
+ "Got type: %s" % type(func))
+ concrete_func = func.get_concrete_function(*args, **kwargs)
+ graph = ops.get_default_graph()
+ concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
+ # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+ return concrete_func
+
+
def _validate_signature(signature):
if any(not isinstance(arg, tensor_spec.TensorSpec)
for arg in nest.flatten(signature)):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index d2b1d9c8a7..a0abefe666 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1607,6 +1607,84 @@ class FunctionTest(test.TestCase):
t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
add(t, t)
+ def testRegisterFunction(self):
+ @function.defun
+ def add(x, y):
+ return math_ops.add(x, y)
+
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+ function.register(add, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ pre_register_matmul_func_name = functions[0].definition.signature.name
+ self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
+ pre_register_add_func_name = functions[1].definition.signature.name
+ self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+
+ sq = defun_matmul(t, t)
+ double = add(t, t)
+ self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
+ self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
+ # Make sure the pre registered function is used, and no other function
+ # is added.
+ self.assertEqual(len(graph._functions), 2)
+ functions = list(graph._functions.values())
+ called_func_name = functions[0].definition.signature.name
+ self.assertEqual(pre_register_matmul_func_name, called_func_name)
+ called_func_name = functions[1].definition.signature.name
+ self.assertEqual(pre_register_add_func_name, called_func_name)
+
+ def testRegisterFunctionWithInputSignature(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(
+ matmul,
+ input_signature=[
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
+ tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32)
+ ])
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ function.register(defun_matmul, t, t)
+
+ graph = ops.get_default_graph()
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
+ # Test input param shape mismatch
+ t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ with self.assertRaisesRegexp(
+ ValueError, 'Python inputs incompatible with input_signature'):
+ function.register(defun_matmul, t2, t2)
+
+ def testRegisterFunctionWithCache(self):
+ def matmul(x, y):
+ return math_ops.matmul(x, y)
+ defun_matmul = function.defun(matmul)
+
+ with context.graph_mode(), self.cached_session():
+ with ops.get_default_graph().as_default():
+ t = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[2.0, 3.0], [4.0, 5.0]])
+ function.register(defun_matmul, t, t)
+ function.register(defun_matmul, t2, t2)
+
+ graph = ops.get_default_graph()
+ # Only one function is registered since the input param are in same type
+ # pylint: disable=protected-access
+ self.assertEqual(len(graph._functions), 1)
+
@test_util.with_c_shapes
class AutomaticControlDependenciesTest(test.TestCase):