aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py4
-rw-r--r--tensorflow/python/autograph/pyct/cfg_test.py16
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py44
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py34
4 files changed, 89 insertions, 9 deletions
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index ec733ea38f..fdfcd4dcc1 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -679,10 +679,6 @@ class AstToCfg(gast.NodeVisitor):
self.cfgs[node] = self.builder.build()
self.builder = self.builder_stack.pop()
- def visit_Lambda(self, node):
- # TODO(mdan): Treat like FunctionDef? That would be a separate CFG.
- raise NotImplementedError()
-
def visit_Return(self, node):
self._process_exit_statement(node, gast.FunctionDef)
diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
index bd82e70f7d..d5870124bc 100644
--- a/tensorflow/python/autograph/pyct/cfg_test.py
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -964,6 +964,22 @@ class AstToCfgTest(test.TestCase):
),
)
+ def test_lambda_basic(self):
+
+ def test_fn(a):
+ a = lambda b: a + b
+ return a
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'a = lambda b: a + b', 'return a'),
+ ('a = lambda b: a + b', 'return a', None),
+ ),
+ )
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index cc159031ff..0ce410d522 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -146,8 +146,15 @@ class ActivityAnalyzer(transformer.Base):
def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
super(ActivityAnalyzer, self).__init__(context)
self.scope = Scope(parent_scope, None, add_unknown_symbols)
+
+ # Note: all these flags crucially rely on the respective nodes are
+ # leaves in the AST, that is, they cannot contain other statements.
self._in_return_statement = False
self._in_aug_assign = False
+ self._in_lambda = False
+ self._in_function_def_args = False
+
+ self._untracked_symbols = None
@property
def _in_constructor(self):
@@ -172,6 +179,13 @@ class ActivityAnalyzer(transformer.Base):
return
qn = anno.getanno(node, anno.Basic.QN)
+ # Ignore any untracked symbols.
+ if self._untracked_symbols:
+ if qn in self._untracked_symbols:
+ return
+ if qn.owner_set & set(self._untracked_symbols):
+ return
+
if isinstance(node.ctx, gast.Store):
self.scope.mark_modified(qn)
if qn.is_composite and composite_writes_alter_parent:
@@ -181,12 +195,20 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Load):
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
- # Param contexts appear in function defs, so they have the meaning of
- # defining a variable.
- self.scope.mark_modified(qn)
- self.scope.mark_param(qn, self.enclosing_entities[-1])
+ if self._in_function_def_args:
+ # In function defs have the meaning of defining a variable.
+ self.scope.mark_modified(qn)
+ self.scope.mark_param(qn, self.enclosing_entities[-1])
+ elif self._in_lambda:
+ assert isinstance(self._untracked_symbols, set)
+ self._untracked_symbols.add(qn)
+ else:
+ # TODO(mdan): Is this case even possible?
+ raise NotImplementedError(
+ 'Param "{}" outside a function arguments or lambda.'.format(qn))
else:
- raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+ raise ValueError('Unknown context {} for node "{}".'.format(
+ type(node.ctx), qn))
if self._in_return_statement:
self.scope.mark_returned(qn)
@@ -294,6 +316,15 @@ class ActivityAnalyzer(transformer.Base):
self.scope.merge_from(after_child)
return parent
+ def visit_Lambda(self, node):
+ assert not self._in_lambda or self._in_function_def_args
+ self._in_lambda = True
+ self._untracked_symbols = set()
+ node = self.generic_visit(node)
+ self._untracked_symbols = None
+ self._in_lambda = False
+ return node
+
def visit_arguments(self, node):
return self._process_statement(node)
@@ -308,7 +339,10 @@ class ActivityAnalyzer(transformer.Base):
# A separate Scope tracks the actual function definition.
self._enter_scope(True)
+ assert not self._in_function_def_args
+ self._in_function_def_args = True
node.args = self.visit(node.args)
+ self._in_function_def_args = False
# Track the body separately. This is for compatibility reasons, it may not
# be strictly needed.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index 9a4f1bf09b..678199970c 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -427,6 +427,40 @@ class ActivityAnalyzerTest(test.TestCase):
args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
+ def test_lambda_captures_reads(self):
+
+ def test_fn(a, b):
+ return lambda: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('a', 'b'), ())
+ # Nothing local to the lambda is tracked.
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_params_are_isolated(self):
+
+ def test_fn(a, b): # pylint: disable=unused-argument
+ return lambda a: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b',), ())
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_complex(self):
+
+ def test_fn(a, b, c, d): # pylint: disable=unused-argument
+ a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
if __name__ == '__main__':
test.main()