aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-25 09:13:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:22:24 -0700
commit7cd7a2e3877641da18182424bc7ea114fd7702ba (patch)
treec48f5d192ea34d19a54713218dd8156c817ca62f /tensorflow/python/autograph
parent32140ae87fd86398ac4fa45cb67bd2f29a93090d (diff)
Account for cases when the live value of a function is not hashable, in the built-in functions converter. Example: d.keys() where d is a dict.
PiperOrigin-RevId: 214448772
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions.py9
-rw-r--r--tensorflow/python/autograph/converters/builtin_functions_test.py16
2 files changed, 20 insertions, 5 deletions
diff --git a/tensorflow/python/autograph/converters/builtin_functions.py b/tensorflow/python/autograph/converters/builtin_functions.py
index b8b268d8ce..583c978395 100644
--- a/tensorflow/python/autograph/converters/builtin_functions.py
+++ b/tensorflow/python/autograph/converters/builtin_functions.py
@@ -48,8 +48,13 @@ class BuiltinFunctionTransformer(converter.Base):
node = self.generic_visit(node)
if anno.hasanno(node.func, 'live_val'):
live_val = anno.getanno(node.func, 'live_val')
- if live_val in py_builtins.SUPPORTED_BUILTINS:
- node = self._convert_builtin(live_val, node.args, as_expression=True)
+ try:
+ if live_val in py_builtins.SUPPORTED_BUILTINS:
+ node = self._convert_builtin(live_val, node.args, as_expression=True)
+ except TypeError:
+ # Not everything in Python is hashable. If it isn't then it's definitely
+ # not a supported built-in.
+ return node
return node
def visit_Print(self, node):
diff --git a/tensorflow/python/autograph/converters/builtin_functions_test.py b/tensorflow/python/autograph/converters/builtin_functions_test.py
index c87c304cdb..2ed14c14e7 100644
--- a/tensorflow/python/autograph/converters/builtin_functions_test.py
+++ b/tensorflow/python/autograph/converters/builtin_functions_test.py
@@ -36,7 +36,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return len(a)
with self.converted(test_fn, builtin_functions, {'len': len}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
p = array_ops.placeholder(dtype=dtypes.int32, shape=None)
ops = result.test_fn(p)
self.assertEqual(sess.run(ops, {p: [0, 0, 0]}), 3)
@@ -50,7 +50,7 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a\n'):
sess.run(result.test_fn('a'))
@@ -63,12 +63,22 @@ class BuiltinFunctionsTest(converter_testing.TestCase):
return print(a, b, c)
with self.converted(test_fn, builtin_functions, {'print': print}) as result:
- with self.cached_session() as sess:
+ with self.test_session() as sess:
with self.assertPrints('a 1 [2, 3]\n'):
sess.run(
result.test_fn(
constant_op.constant('a'), constant_op.constant(1), [2, 3]))
+ def test_conversion_robust_to_unhashable_callables(self):
+
+ def test_fn():
+ return foo() # pylint:disable=undefined-variable
+
+ with self.converted(test_fn, builtin_functions, {'foo': {
+ 'a': 'b'
+ }.keys}) as result:
+ self.assertListEqual(list(result.test_fn()), ['a'])
+
if __name__ == '__main__':
test.main()