diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/activity.py | 44 |
1 files changed, 39 insertions, 5 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index cc159031ff..0ce410d522 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -146,8 +146,15 @@ class ActivityAnalyzer(transformer.Base): def __init__(self, context, parent_scope=None, add_unknown_symbols=False): super(ActivityAnalyzer, self).__init__(context) self.scope = Scope(parent_scope, None, add_unknown_symbols) + + # Note: all these flags crucially rely on the respective nodes are + # leaves in the AST, that is, they cannot contain other statements. self._in_return_statement = False self._in_aug_assign = False + self._in_lambda = False + self._in_function_def_args = False + + self._untracked_symbols = None @property def _in_constructor(self): @@ -172,6 +179,13 @@ class ActivityAnalyzer(transformer.Base): return qn = anno.getanno(node, anno.Basic.QN) + # Ignore any untracked symbols. + if self._untracked_symbols: + if qn in self._untracked_symbols: + return + if qn.owner_set & set(self._untracked_symbols): + return + if isinstance(node.ctx, gast.Store): self.scope.mark_modified(qn) if qn.is_composite and composite_writes_alter_parent: @@ -181,12 +195,20 @@ class ActivityAnalyzer(transformer.Base): elif isinstance(node.ctx, gast.Load): self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Param): - # Param contexts appear in function defs, so they have the meaning of - # defining a variable. - self.scope.mark_modified(qn) - self.scope.mark_param(qn, self.enclosing_entities[-1]) + if self._in_function_def_args: + # In function defs have the meaning of defining a variable. + self.scope.mark_modified(qn) + self.scope.mark_param(qn, self.enclosing_entities[-1]) + elif self._in_lambda: + assert isinstance(self._untracked_symbols, set) + self._untracked_symbols.add(qn) + else: + # TODO(mdan): Is this case even possible? + raise NotImplementedError( + 'Param "{}" outside a function arguments or lambda.'.format(qn)) else: - raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) + raise ValueError('Unknown context {} for node "{}".'.format( + type(node.ctx), qn)) if self._in_return_statement: self.scope.mark_returned(qn) @@ -294,6 +316,15 @@ class ActivityAnalyzer(transformer.Base): self.scope.merge_from(after_child) return parent + def visit_Lambda(self, node): + assert not self._in_lambda or self._in_function_def_args + self._in_lambda = True + self._untracked_symbols = set() + node = self.generic_visit(node) + self._untracked_symbols = None + self._in_lambda = False + return node + def visit_arguments(self, node): return self._process_statement(node) @@ -308,7 +339,10 @@ class ActivityAnalyzer(transformer.Base): # A separate Scope tracks the actual function definition. self._enter_scope(True) + assert not self._in_function_def_args + self._in_function_def_args = True node.args = self.visit(node.args) + self._in_function_def_args = False # Track the body separately. This is for compatibility reasons, it may not # be strictly needed. |