diff options
4 files changed, 89 insertions, 9 deletions
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index ec733ea38f..fdfcd4dcc1 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -679,10 +679,6 @@ class AstToCfg(gast.NodeVisitor): self.cfgs[node] = self.builder.build() self.builder = self.builder_stack.pop() - def visit_Lambda(self, node): - # TODO(mdan): Treat like FunctionDef? That would be a separate CFG. - raise NotImplementedError() - def visit_Return(self, node): self._process_exit_statement(node, gast.FunctionDef) diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py index bd82e70f7d..d5870124bc 100644 --- a/tensorflow/python/autograph/pyct/cfg_test.py +++ b/tensorflow/python/autograph/pyct/cfg_test.py @@ -964,6 +964,22 @@ class AstToCfgTest(test.TestCase): ), ) + def test_lambda_basic(self): + + def test_fn(a): + a = lambda b: a + b + return a + + graph, = self._build_cfg(test_fn).values() + + self.assertGraphMatches( + graph, + ( + ('a', 'a = lambda b: a + b', 'return a'), + ('a = lambda b: a + b', 'return a', None), + ), + ) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index cc159031ff..0ce410d522 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -146,8 +146,15 @@ class ActivityAnalyzer(transformer.Base): def __init__(self, context, parent_scope=None, add_unknown_symbols=False): super(ActivityAnalyzer, self).__init__(context) self.scope = Scope(parent_scope, None, add_unknown_symbols) + + # Note: all these flags crucially rely on the respective nodes are + # leaves in the AST, that is, they cannot contain other statements. self._in_return_statement = False self._in_aug_assign = False + self._in_lambda = False + self._in_function_def_args = False + + self._untracked_symbols = None @property def _in_constructor(self): @@ -172,6 +179,13 @@ class ActivityAnalyzer(transformer.Base): return qn = anno.getanno(node, anno.Basic.QN) + # Ignore any untracked symbols. + if self._untracked_symbols: + if qn in self._untracked_symbols: + return + if qn.owner_set & set(self._untracked_symbols): + return + if isinstance(node.ctx, gast.Store): self.scope.mark_modified(qn) if qn.is_composite and composite_writes_alter_parent: @@ -181,12 +195,20 @@ class ActivityAnalyzer(transformer.Base): elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): - # Param contexts appear in function defs, so they have the meaning of - # defining a variable. - self.scope.mark_modified(qn) - self.scope.mark_param(qn, self.enclosing_entities[-1]) + if self._in_function_def_args: + # In function defs have the meaning of defining a variable. + self.scope.mark_modified(qn) + self.scope.mark_param(qn, self.enclosing_entities[-1]) + elif self._in_lambda: + assert isinstance(self._untracked_symbols, set) + self._untracked_symbols.add(qn) + else: + # TODO(mdan): Is this case even possible? + raise NotImplementedError( + 'Param "{}" outside a function arguments or lambda.'.format(qn)) else: - raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) + raise ValueError('Unknown context {} for node "{}".'.format( + type(node.ctx), qn)) if self._in_return_statement: self.scope.mark_returned(qn) @@ -294,6 +316,15 @@ class ActivityAnalyzer(transformer.Base): self.scope.merge_from(after_child) return parent + def visit_Lambda(self, node): + assert not self._in_lambda or self._in_function_def_args + self._in_lambda = True + self._untracked_symbols = set() + node = self.generic_visit(node) + self._untracked_symbols = None + self._in_lambda = False + return node + def visit_arguments(self, node): return self._process_statement(node) @@ -308,7 +339,10 @@ class ActivityAnalyzer(transformer.Base): # A separate Scope tracks the actual function definition. self._enter_scope(True) + assert not self._in_function_def_args + self._in_function_def_args = True node.args = self.visit(node.args) + self._in_function_def_args = False # Track the body separately. This is for compatibility reasons, it may not # be strictly needed. diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py index 9a4f1bf09b..678199970c 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py @@ -427,6 +427,40 @@ class ActivityAnalyzerTest(test.TestCase): args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE) self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params') + def test_lambda_captures_reads(self): + + def test_fn(a, b): + return lambda: a + b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + self.assertScopeIs(body_scope, ('a', 'b'), ()) + # Nothing local to the lambda is tracked. + self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') + + def test_lambda_params_are_isolated(self): + + def test_fn(a, b): # pylint: disable=unused-argument + return lambda a: a + b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + self.assertScopeIs(body_scope, ('b',), ()) + self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') + + def test_lambda_complex(self): + + def test_fn(a, b, c, d): # pylint: disable=unused-argument + a = (lambda a, b, c: a + b + c)(d, 1, 2) + b + + node, _ = self._parse_and_analyze(test_fn) + fn_node = node.body[0] + body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE) + self.assertScopeIs(body_scope, ('b', 'd'), ('a',)) + self.assertSymbolSetsAre((), body_scope.params.keys(), 'params') + if __name__ == '__main__': test.main() |