diff options
author | Dan Moldovan <mdan@google.com> | 2018-09-25 09:13:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 09:22:24 -0700 |
commit | 7cd7a2e3877641da18182424bc7ea114fd7702ba (patch) | |
tree | c48f5d192ea34d19a54713218dd8156c817ca62f /tensorflow/python/autograph | |
parent | 32140ae87fd86398ac4fa45cb67bd2f29a93090d (diff) |
Account for cases when the live value of a function is not hashable, in the built-in functions converter. Example: d.keys() where d is a dict.
PiperOrigin-RevId: 214448772
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r-- | tensorflow/python/autograph/converters/builtin_functions.py | 9 | ||||
-rw-r--r-- | tensorflow/python/autograph/converters/builtin_functions_test.py | 16 |
2 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py index b8b268d8ce..583c978395 100644 --- a/tensorflow/python/autograph/converters/builtin_functions.py +++ b/tensorflow/python/autograph/converters/builtin_functions.py @@ -48,8 +48,13 @@ class BuiltinFunctionTransformer(converter.Base): node = self.generic_visit(node) if anno.hasanno(node.func, 'live_val'): live_val = anno.getanno(node.func, 'live_val') - if live_val in py_builtins.SUPPORTED_BUILTINS: - node = self._convert_builtin(live_val, node.args, as_expression=True) + try: + if live_val in py_builtins.SUPPORTED_BUILTINS: + node = self._convert_builtin(live_val, node.args, as_expression=True) + except TypeError: + # Not everything in Python is hashable. If it isn't then it's definitely + # not a supported built-in. + return node return node def visit_Print(self, node): diff --git a/tensorflow/python/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py index c87c304cdb..2ed14c14e7 100644 --- a/tensorflow/python/autograph/converters/builtin_functions_test.py +++ b/tensorflow/python/autograph/converters/builtin_functions_test.py @@ -36,7 +36,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase): return len(a) with self.converted(test_fn, builtin_functions, {'len': len}) as result: - with self.cached_session() as sess: + with self.test_session() as sess: p = array_ops.placeholder(dtype=dtypes.int32, shape=None) ops = result.test_fn(p) self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3) @@ -50,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase): return print(a) with self.converted(test_fn, builtin_functions, {'print': print}) as result: - with self.cached_session() as sess: + with self.test_session() as sess: with self.assertPrints('a\n'): sess.run(result.test_fn('a')) @@ -63,12 +63,22 @@ class BuiltinFunctionsTest(converter_testing.TestCase): return print(a, b, c) with self.converted(test_fn, builtin_functions, {'print': print}) as result: - with self.cached_session() as sess: + with self.test_session() as sess: with self.assertPrints('a 1 [2, 3]\n'): sess.run( result.test_fn( constant_op.constant('a'), constant_op.constant(1), [2, 3])) + def test_conversion_robust_to_unhashable_callables(self): + + def test_fn(): + return foo() # pylint:disable=undefined-variable + + with self.converted(test_fn, builtin_functions, {'foo': { + 'a': 'b' + }.keys}) as result: + self.assertListEqual(list(result.test_fn()), ['a']) + if __name__ == '__main__': test.main() |