diff options
author | Shivani Agrawal <shivaniagrawal@google.com> | 2018-10-08 14:44:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 14:59:02 -0700 |
commit | b3bd7b378d00190fef831092836a5df62e39e7ed (patch) | |
tree | efeccf7199fcb76a6fae250b355efd06ebc1fbfb /tensorflow/python/eager | |
parent | 410a83a532a6ffe0be3ad65cd0c84ca77c47f2c5 (diff) |
Ignore args and kwargs for defun's get_concrete_fn if `PolymorphicFunction` was created
with an input_signature.
PiperOrigin-RevId: 216253122
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/function.py | 14 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 9 |
2 files changed, 18 insertions, 5 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 99bf375ea7..ff138cad1e 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -665,6 +665,11 @@ class Function(object): return self._build_call_outputs(outputs) @property + def name(self): + """Function name.""" + return self._inference_function.name + + @property def graph(self): """Returns the graph from which this function was constructed.""" return self._func_graph @@ -721,6 +726,10 @@ class Function(object): return nest.map_structure(lambda x: x.dtype if x is not None else None, self._func_graph.structured_outputs) + def add_to_graph(self, g): + """Adds this function into the graph g.""" + return self._inference_function.add_to_graph(g) + def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" backwards_graph = FuncGraph(_backward_name(self._func_graph.name)) @@ -1133,6 +1142,8 @@ class PolymorphicFunction(object): *args: inputs to specialize on. **kwargs: inputs to specialize on. """ + if self._input_signature: + args, kwargs = None, None graph_function, _ = self._maybe_define_function(args, kwargs) return graph_function @@ -1322,6 +1333,9 @@ def register(func, *args, **kwargs): function definition into graph. Register function with different input param will result into multiple version of functions registered in graph. + Also, `args` and `kwargs` are ignored if this `PolymorphicFunction` was + created with an `input_signature`. + Args: func: the PolymorphicFunction instance that generated by a @defun *args: input arguments for the Python function. diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index e46bde098b..953f4300cf 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1841,11 +1841,10 @@ class FunctionTest(test.TestCase): # pylint: disable=protected-access self.assertEqual(len(graph._functions), 3) - # Test input param shape mismatch - t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - with self.assertRaisesRegexp( - ValueError, 'Python inputs incompatible with input_signature'): - function.register(defun_matmul, t2, t2) + # Test register function with cache, note inputs are ignored. + function.register(defun_matmul) + graph = ops.get_default_graph() + self.assertEqual(len(graph._functions), 3) def testRegisterFunctionWithCache(self): def matmul(x, y): |