diff options
author | 2018-07-26 10:01:50 -0700 | |
---|---|---|
committer | 2018-07-26 10:05:09 -0700 | |
commit | 883e8b6863511b46d7985e9ff8d1809ffe2a1bc0 (patch) | |
tree | 242a7257bfd31687720b7e08a5c7cb833d6c3b79 /tensorflow/python/util | |
parent | 86cffb1d9201f8072cef3eb13ef0dc524e0f4535 (diff) |
Make function_utils.get_func_code more tolerant of strange objects like functool.partial.
PiperOrigin-RevId: 206175973
Diffstat (limited to 'tensorflow/python/util')
-rw-r--r-- | tensorflow/python/util/function_utils.py | 14 | ||||
-rw-r--r-- | tensorflow/python/util/function_utils_test.py | 78 |
2 files changed, 87 insertions, 5 deletions
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py index 61312feafd..4e9b07e20a 100644 --- a/tensorflow/python/util/function_utils.py +++ b/tensorflow/python/util/function_utils.py @@ -69,20 +69,24 @@ def get_func_name(func): return '%s.%s' % (six.get_method_self(func).__class__.__name__, six.get_method_function(func).__name__) else: # Probably a class instance with __call__ - return type(func) + return str(type(func)) else: raise ValueError('Argument must be callable') def get_func_code(func): - """Returns func_code of passed callable.""" + """Returns func_code of passed callable, or None if not available.""" _, func = tf_decorator.unwrap(func) if callable(func): if tf_inspect.isfunction(func) or tf_inspect.ismethod(func): return six.get_function_code(func) - elif hasattr(func, '__call__'): + # Since the object is not a function or method, but is a callable, we will + # try to access the __call__method as a function. This works with callable + # classes but fails with functool.partial objects despite their __call__ + # attribute. + try: return six.get_function_code(func.__call__) - else: - raise ValueError('Unhandled callable, type=%s' % type(func)) + except AttributeError: + return None else: raise ValueError('Argument must be callable') diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py index e78cf6a5b0..1588328c26 100644 --- a/tensorflow/python/util/function_utils_test.py +++ b/tensorflow/python/util/function_utils_test.py @@ -24,6 +24,16 @@ from tensorflow.python.platform import test from tensorflow.python.util import function_utils +def silly_example_function(): + pass + + +class SillyCallableClass(object): + + def __call__(self): + pass + + class FnArgsTest(test.TestCase): def test_simple_function(self): @@ -124,5 +134,73 @@ class FnArgsTest(test.TestCase): self.assertEqual(3, double_wrapped_fn(3)) self.assertEqual(3, double_wrapped_fn(a=3)) + +class GetFuncNameTest(test.TestCase): + + def testWithSimpleFunction(self): + self.assertEqual( + 'silly_example_function', + function_utils.get_func_name(silly_example_function)) + + def testWithClassMethod(self): + self.assertEqual( + 'GetFuncNameTest.testWithClassMethod', + function_utils.get_func_name(self.testWithClassMethod)) + + def testWithCallableClass(self): + callable_instance = SillyCallableClass() + self.assertRegexpMatches( + function_utils.get_func_name(callable_instance), + '<.*SillyCallableClass.*>') + + def testWithFunctoolsPartial(self): + partial = functools.partial(silly_example_function) + self.assertRegexpMatches( + function_utils.get_func_name(partial), + '<.*functools.partial.*>') + + def testWithLambda(self): + anon_fn = lambda x: x + self.assertEqual('<lambda>', function_utils.get_func_name(anon_fn)) + + def testRaisesWithNonCallableObject(self): + with self.assertRaises(ValueError): + function_utils.get_func_name(None) + + +class GetFuncCodeTest(test.TestCase): + + def testWithSimpleFunction(self): + code = function_utils.get_func_code(silly_example_function) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithClassMethod(self): + code = function_utils.get_func_code(self.testWithClassMethod) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithCallableClass(self): + callable_instance = SillyCallableClass() + code = function_utils.get_func_code(callable_instance) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithLambda(self): + anon_fn = lambda x: x + code = function_utils.get_func_code(anon_fn) + self.assertIsNotNone(code) + self.assertRegexpMatches(code.co_filename, 'function_utils_test.py') + + def testWithFunctoolsPartial(self): + partial = functools.partial(silly_example_function) + code = function_utils.get_func_code(partial) + self.assertIsNone(code) + + def testRaisesWithNonCallableObject(self): + with self.assertRaises(ValueError): + function_utils.get_func_code(None) + + if __name__ == '__main__': test.main() |