From b3bd7b378d00190fef831092836a5df62e39e7ed Mon Sep 17 00:00:00 2001 From: Shivani Agrawal Date: Mon, 8 Oct 2018 14:44:37 -0700 Subject: Ignore args and kwargs for defun's get_concrete_fn if `PolymorphicFunction` was created with an input_signature. PiperOrigin-RevId: 216253122 --- tensorflow/python/eager/function.py | 14 ++++++++++++++ tensorflow/python/eager/function_test.py | 9 ++++----- 2 files changed, 18 insertions(+), 5 deletions(-) (limited to 'tensorflow/python') diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 99bf375ea7..ff138cad1e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -664,6 +664,11 @@ class Function(object): return self._build_call_outputs(outputs) + @property + def name(self): + """Function name.""" + return self._inference_function.name + @property def graph(self): """Returns the graph from which this function was constructed.""" @@ -721,6 +726,10 @@ class Function(object): return nest.map_structure(lambda x: x.dtype if x is not None else None, self._func_graph.structured_outputs) + def add_to_graph(self, g): + """Adds this function into the graph g.""" + return self._inference_function.add_to_graph(g) + def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) @@ -1133,6 +1142,8 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ + if self._input_signature: + args, kwargs = None, None graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function @@ -1322,6 +1333,9 @@ def register(func, *args, **kwargs): function definition into graph. Register function with different input param will result into multiple version of functions registered in graph. + Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was + created with an `input_signature`. + Args: func: the PolymorphicFunction instance that generated by a @defun *args: input arguments for the Python function. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index e46bde098b..953f4300cf 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1841,11 +1841,10 @@ class FunctionTest(test.TestCase): # pylint: disable=protected-access self.assertEqual(len(graph._functions), 3) - # Test input param shape mismatch - t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - with self.assertRaisesRegexp( - ValueError, 'Python inputs incompatible with input_signature'): - function.register(defun_matmul, t2, t2) + # Test register function with cache, note inputs are ignored. + function.register(defun_matmul) + graph = ops.get_default_graph() + self.assertEqual(len(graph._functions), 3) def testRegisterFunctionWithCache(self): def matmul(x, y): -- cgit v1.2.3