diff options
-rw-r--r-- | tensorflow/python/eager/function.py | 10 | ||||
-rw-r--r-- | tensorflow/python/eager/function_test.py | 13 |
2 files changed, 21 insertions, 2 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 54363ffcba..f315fa296c 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -960,11 +960,17 @@ class _PolymorphicFunction(object): self._lock = threading.Lock() fullargspec = tf_inspect.getfullargspec(self._python_function) + if tf_inspect.ismethod(self._python_function): + # Remove `self`: default arguments shouldn't be matched to it. + args = fullargspec.args[1:] + else: + args = fullargspec.args + # A cache mapping from argument name to index, for canonicalizing # arguments that are called in a keyword-like fashion. - self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)} + self._args_to_indices = {arg: i for i, arg in enumerate(args)} # A cache mapping from arg index to default value, for canonicalization. - offset = len(fullargspec.args) - len(fullargspec.defaults or []) + offset = len(args) - len(fullargspec.defaults or []) self._arg_indices_to_default_values = { offset + index: default for index, default in enumerate(fullargspec.defaults or []) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index b9e29635f8..b7c9334c33 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -1233,6 +1233,19 @@ class FunctionTest(test.TestCase): self.assertEqual(one.numpy(), 1.0) self.assertEqual(two.numpy(), 2) + def testDefuningInstanceMethodWithDefaultArgument(self): + + integer = constant_op.constant(2, dtypes.int64) + + class Foo(object): + + @function.defun + def func(self, other=integer): + return other + + foo = Foo() + self.assertEqual(foo.func().numpy(), int(integer)) + def testPythonCallWithSideEffects(self): state = [] |