aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-09 17:59:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 18:03:16 -0700
commit2db20be49c660a0c475cb57fe0935791d66433ed (patch)
tree4eaf94b9acee144defe5b4e0542a6f89ee849889
parenta8cc3cbdeb1563c05d75043c9901135f8b9be65a (diff)
Enable support for lambda functions in static analyses.
The CFG treats lambdas as ordinary expressions. The activity analysis ensures that variables masked by the lambda's arguments are not being tracked. Note: lambdas do not allow direct modification (we exclude indirect mutation via function or methods). PiperOrigin-RevId: 216456682
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py4
-rw-r--r--tensorflow/python/autograph/pyct/cfg_test.py16
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py44
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py34
4 files changed, 89 insertions, 9 deletions
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index ec733ea38f..fdfcd4dcc1 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -679,10 +679,6 @@ class AstToCfg(gast.NodeVisitor):
self.cfgs[node] = self.builder.build()
self.builder = self.builder_stack.pop()
- def visit_Lambda(self, node):
- # TODO(mdan): Treat like FunctionDef? That would be a separate CFG.
- raise NotImplementedError()
-
def visit_Return(self, node):
self._process_exit_statement(node, gast.FunctionDef)
diff --git a/tensorflow/python/autograph/pyct/cfg_test.py b/tensorflow/python/autograph/pyct/cfg_test.py
index bd82e70f7d..d5870124bc 100644
--- a/tensorflow/python/autograph/pyct/cfg_test.py
+++ b/tensorflow/python/autograph/pyct/cfg_test.py
@@ -964,6 +964,22 @@ class AstToCfgTest(test.TestCase):
),
)
+ def test_lambda_basic(self):
+
+ def test_fn(a):
+ a = lambda b: a + b
+ return a
+
+ graph, = self._build_cfg(test_fn).values()
+
+ self.assertGraphMatches(
+ graph,
+ (
+ ('a', 'a = lambda b: a + b', 'return a'),
+ ('a = lambda b: a + b', 'return a', None),
+ ),
+ )
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index cc159031ff..0ce410d522 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -146,8 +146,15 @@ class ActivityAnalyzer(transformer.Base):
def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
super(ActivityAnalyzer, self).__init__(context)
self.scope = Scope(parent_scope, None, add_unknown_symbols)
+
+ # Note: all these flags crucially rely on the respective nodes are
+ # leaves in the AST, that is, they cannot contain other statements.
self._in_return_statement = False
self._in_aug_assign = False
+ self._in_lambda = False
+ self._in_function_def_args = False
+
+ self._untracked_symbols = None
@property
def _in_constructor(self):
@@ -172,6 +179,13 @@ class ActivityAnalyzer(transformer.Base):
return
qn = anno.getanno(node, anno.Basic.QN)
+ # Ignore any untracked symbols.
+ if self._untracked_symbols:
+ if qn in self._untracked_symbols:
+ return
+ if qn.owner_set & set(self._untracked_symbols):
+ return
+
if isinstance(node.ctx, gast.Store):
self.scope.mark_modified(qn)
if qn.is_composite and composite_writes_alter_parent:
@@ -181,12 +195,20 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Load):
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
- # Param contexts appear in function defs, so they have the meaning of
- # defining a variable.
- self.scope.mark_modified(qn)
- self.scope.mark_param(qn, self.enclosing_entities[-1])
+ if self._in_function_def_args:
+ # In function defs have the meaning of defining a variable.
+ self.scope.mark_modified(qn)
+ self.scope.mark_param(qn, self.enclosing_entities[-1])
+ elif self._in_lambda:
+ assert isinstance(self._untracked_symbols, set)
+ self._untracked_symbols.add(qn)
+ else:
+ # TODO(mdan): Is this case even possible?
+ raise NotImplementedError(
+ 'Param "{}" outside a function arguments or lambda.'.format(qn))
else:
- raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+ raise ValueError('Unknown context {} for node "{}".'.format(
+ type(node.ctx), qn))
if self._in_return_statement:
self.scope.mark_returned(qn)
@@ -294,6 +316,15 @@ class ActivityAnalyzer(transformer.Base):
self.scope.merge_from(after_child)
return parent
+ def visit_Lambda(self, node):
+ assert not self._in_lambda or self._in_function_def_args
+ self._in_lambda = True
+ self._untracked_symbols = set()
+ node = self.generic_visit(node)
+ self._untracked_symbols = None
+ self._in_lambda = False
+ return node
+
def visit_arguments(self, node):
return self._process_statement(node)
@@ -308,7 +339,10 @@ class ActivityAnalyzer(transformer.Base):
# A separate Scope tracks the actual function definition.
self._enter_scope(True)
+ assert not self._in_function_def_args
+ self._in_function_def_args = True
node.args = self.visit(node.args)
+ self._in_function_def_args = False
# Track the body separately. This is for compatibility reasons, it may not
# be strictly needed.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index 9a4f1bf09b..678199970c 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -427,6 +427,40 @@ class ActivityAnalyzerTest(test.TestCase):
args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
+ def test_lambda_captures_reads(self):
+
+ def test_fn(a, b):
+ return lambda: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('a', 'b'), ())
+ # Nothing local to the lambda is tracked.
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_params_are_isolated(self):
+
+ def test_fn(a, b): # pylint: disable=unused-argument
+ return lambda a: a + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b',), ())
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
+ def test_lambda_complex(self):
+
+ def test_fn(a, b, c, d): # pylint: disable=unused-argument
+ a = (lambda a, b, c: a + b + c)(d, 1, 2) + b
+
+ node, _ = self._parse_and_analyze(test_fn)
+ fn_node = node.body[0]
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b', 'd'), ('a',))
+ self.assertSymbolSetsAre((), body_scope.params.keys(), 'params')
+
if __name__ == '__main__':
test.main()