aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py44
1 files changed, 39 insertions, 5 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index cc159031ff..0ce410d522 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -146,8 +146,15 @@ class ActivityAnalyzer(transformer.Base):
def __init__(self, context, parent_scope=None, add_unknown_symbols=False):
super(ActivityAnalyzer, self).__init__(context)
self.scope = Scope(parent_scope, None, add_unknown_symbols)
+
+ # Note: all these flags crucially rely on the respective nodes are
+ # leaves in the AST, that is, they cannot contain other statements.
self._in_return_statement = False
self._in_aug_assign = False
+ self._in_lambda = False
+ self._in_function_def_args = False
+
+ self._untracked_symbols = None
@property
def _in_constructor(self):
@@ -172,6 +179,13 @@ class ActivityAnalyzer(transformer.Base):
return
qn = anno.getanno(node, anno.Basic.QN)
+ # Ignore any untracked symbols.
+ if self._untracked_symbols:
+ if qn in self._untracked_symbols:
+ return
+ if qn.owner_set & set(self._untracked_symbols):
+ return
+
if isinstance(node.ctx, gast.Store):
self.scope.mark_modified(qn)
if qn.is_composite and composite_writes_alter_parent:
@@ -181,12 +195,20 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Load):
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
- # Param contexts appear in function defs, so they have the meaning of
- # defining a variable.
- self.scope.mark_modified(qn)
- self.scope.mark_param(qn, self.enclosing_entities[-1])
+ if self._in_function_def_args:
+ # In function defs have the meaning of defining a variable.
+ self.scope.mark_modified(qn)
+ self.scope.mark_param(qn, self.enclosing_entities[-1])
+ elif self._in_lambda:
+ assert isinstance(self._untracked_symbols, set)
+ self._untracked_symbols.add(qn)
+ else:
+ # TODO(mdan): Is this case even possible?
+ raise NotImplementedError(
+ 'Param "{}" outside a function arguments or lambda.'.format(qn))
else:
- raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
+ raise ValueError('Unknown context {} for node "{}".'.format(
+ type(node.ctx), qn))
if self._in_return_statement:
self.scope.mark_returned(qn)
@@ -294,6 +316,15 @@ class ActivityAnalyzer(transformer.Base):
self.scope.merge_from(after_child)
return parent
+ def visit_Lambda(self, node):
+ assert not self._in_lambda or self._in_function_def_args
+ self._in_lambda = True
+ self._untracked_symbols = set()
+ node = self.generic_visit(node)
+ self._untracked_symbols = None
+ self._in_lambda = False
+ return node
+
def visit_arguments(self, node):
return self._process_statement(node)
@@ -308,7 +339,10 @@ class ActivityAnalyzer(transformer.Base):
# A separate Scope tracks the actual function definition.
self._enter_scope(True)
+ assert not self._in_function_def_args
+ self._in_function_def_args = True
node.args = self.visit(node.args)
+ self._in_function_def_args = False
# Track the body separately. This is for compatibility reasons, it may not
# be strictly needed.