aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity_test.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index 9a4f1bf09b..678199970c 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -427,6 +427,40 @@ class ActivityAnalyzerTest(test.TestCase):
args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
+ def test_lambda_captures_reads(self):
+
+ def test_fn(a, b):
+ return lambda: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('a', 'b'), ())
+ # Nothing local to the lambda is tracked.
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_params_are_isolated(self):
+
+ def test_fn(a, b): # pylint: disable=unused-argument
+ return lambda a: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b',), ())
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_complex(self):
+
+ def test_fn(a, b, c, d): # pylint: disable=unused-argument
+ a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
if __name__ == '__main__':
test.main()