diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/static_analysis/activity.py | 82 |
1 files changed, 24 insertions, 58 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index 086eda7574..cc159031ff 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -44,7 +44,6 @@ class Scope(object): Attributes: modified: identifiers modified in this scope - created: identifiers created in this scope used: identifiers referenced in this scope """ @@ -54,7 +53,8 @@ class Scope(object): Args: parent: A Scope or None. isolated: Whether the scope is isolated, that is, whether variables - created in this scope should be visible to the parent scope. + modified in this scope should be considered modified in the parent + scope. add_unknown_symbols: Whether to handle attributed and subscripts without having first seen the base name. E.g., analyzing the statement 'x.y = z' without first having seen 'x'. @@ -63,13 +63,11 @@ class Scope(object): self.parent = parent self.add_unknown_symbols = add_unknown_symbols self.modified = set() - # TODO(mdan): Completely remove this. - self.created = set() self.used = set() self.params = {} self.returned = set() - # TODO(mdan): Rename to `locals` + # TODO(mdan): Rename to `reserved` @property def referenced(self): if not self.isolated and self.parent is not None: @@ -77,8 +75,7 @@ class Scope(object): return self.used def __repr__(self): - return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created), - tuple(self.modified)) + return 'Scope{r=%s, w=%s}' % (tuple(self.used), tuple(self.modified)) def copy_from(self, other): """Recursively copies the contents of this scope from another scope.""" @@ -88,7 +85,6 @@ class Scope(object): self.parent.copy_from(other.parent) self.isolated = other.isolated self.modified = copy.copy(other.modified) - self.created = copy.copy(other.created) self.used = copy.copy(other.used) self.params = copy.copy(other.params) self.returned = copy.copy(other.returned) @@ -109,56 +105,28 @@ class Scope(object): if other.parent is not None: self.parent.merge_from(other.parent) self.modified |= other.modified - self.created |= other.created self.used |= other.used self.params.update(other.params) self.returned |= other.returned - def has(self, name): - if name in self.modified: - return True - elif self.parent is not None: - return self.parent.has(name) - return False - def mark_read(self, name): self.used.add(name) - if self.parent is not None and name not in self.created: + if self.parent is not None and name not in self.params: self.parent.mark_read(name) + def mark_modified(self, name): + """Marks the given symbol as modified in the current scope.""" + self.modified.add(name) + if not self.isolated: + if self.parent is not None: + self.parent.mark_modified(name) + def mark_param(self, name, owner): # Assumption: all AST nodes have the same life span. This lets us use # a weak reference to mark the connection between a symbol node and the # function node whose argument that symbol is. self.params[name] = weakref.ref(owner) - def mark_creation(self, name, writes_create_symbol=False): - """Mark a qualified name as created.""" - if name.is_composite(): - parent = name.parent - if not writes_create_symbol: - return - else: - if not self.has(parent): - if self.add_unknown_symbols: - self.mark_read(parent) - else: - raise ValueError('Unknown symbol "%s".' % parent) - self.created.add(name) - - def mark_write(self, name): - """Marks the given symbol as modified in the current scope.""" - self.modified.add(name) - if self.isolated: - self.mark_creation(name) - else: - if self.parent is None: - self.mark_creation(name) - else: - if not self.parent.has(name): - self.mark_creation(name) - self.parent.mark_write(name) - def mark_returned(self, name): self.returned.add(name) if not self.isolated and self.parent is not None: @@ -197,10 +165,7 @@ class ActivityAnalyzer(transformer.Base): return True return False - def _track_symbol(self, - node, - composite_writes_alter_parent=False, - writes_create_symbol=False): + def _track_symbol(self, node, composite_writes_alter_parent=False): # A QN may be missing when we have an attribute (or subscript) on a function # call. Example: a().b if not anno.hasanno(node, anno.Basic.QN): @@ -208,11 +173,9 @@ class ActivityAnalyzer(transformer.Base): qn = anno.getanno(node, anno.Basic.QN) if isinstance(node.ctx, gast.Store): - self.scope.mark_write(qn) + self.scope.mark_modified(qn) if qn.is_composite and composite_writes_alter_parent: - self.scope.mark_write(qn.parent) - if writes_create_symbol: - self.scope.mark_creation(qn, writes_create_symbol=True) + self.scope.mark_modified(qn.parent) if self._in_aug_assign: self.scope.mark_read(qn) elif isinstance(node.ctx, gast.Load): @@ -220,13 +183,11 @@ class ActivityAnalyzer(transformer.Base): elif isinstance(node.ctx, gast.Param): # Param contexts appear in function defs, so they have the meaning of # defining a variable. - self.scope.mark_write(qn) + self.scope.mark_modified(qn) self.scope.mark_param(qn, self.enclosing_entities[-1]) else: raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn)) - anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn)) - if self._in_return_statement: self.scope.mark_returned(qn) @@ -243,6 +204,12 @@ class ActivityAnalyzer(transformer.Base): self._exit_scope() return node + def visit_nonlocal(self, node): + raise NotImplementedError() + + def visit_global(self, node): + raise NotImplementedError() + def visit_Expr(self, node): return self._process_statement(node) @@ -271,8 +238,7 @@ class ActivityAnalyzer(transformer.Base): def visit_Attribute(self, node): node = self.generic_visit(node) if self._in_constructor and self._node_sets_self_attribute(node): - self._track_symbol( - node, composite_writes_alter_parent=True, writes_create_symbol=True) + self._track_symbol(node, composite_writes_alter_parent=True) else: self._track_symbol(node) return node @@ -336,7 +302,7 @@ class ActivityAnalyzer(transformer.Base): # of its name, along with the usage of any decorator accompany it. self._enter_scope(False) node.decorator_list = self.visit_block(node.decorator_list) - self.scope.mark_write(qual_names.QN(node.name)) + self.scope.mark_modified(qual_names.QN(node.name)) anno.setanno(node, anno.Static.SCOPE, self.scope) self._exit_scope() |