aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/static_analysis/activity.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/static_analysis/activity.py')
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py82
1 files changed, 24 insertions, 58 deletions
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 086eda7574..cc159031ff 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -44,7 +44,6 @@ class Scope(object):
Attributes:
modified: identifiers modified in this scope
- created: identifiers created in this scope
used: identifiers referenced in this scope
"""
@@ -54,7 +53,8 @@ class Scope(object):
Args:
parent: A Scope or None.
isolated: Whether the scope is isolated, that is, whether variables
- created in this scope should be visible to the parent scope.
+ modified in this scope should be considered modified in the parent
+ scope.
add_unknown_symbols: Whether to handle attributed and subscripts
without having first seen the base name.
E.g., analyzing the statement 'x.y = z' without first having seen 'x'.
@@ -63,13 +63,11 @@ class Scope(object):
self.parent = parent
self.add_unknown_symbols = add_unknown_symbols
self.modified = set()
- # TODO(mdan): Completely remove this.
- self.created = set()
self.used = set()
self.params = {}
self.returned = set()
- # TODO(mdan): Rename to `locals`
+ # TODO(mdan): Rename to `reserved`
@property
def referenced(self):
if not self.isolated and self.parent is not None:
@@ -77,8 +75,7 @@ class Scope(object):
return self.used
def __repr__(self):
- return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
- tuple(self.modified))
+ return 'Scope{r=%s, w=%s}' % (tuple(self.used), tuple(self.modified))
def copy_from(self, other):
"""Recursively copies the contents of this scope from another scope."""
@@ -88,7 +85,6 @@ class Scope(object):
self.parent.copy_from(other.parent)
self.isolated = other.isolated
self.modified = copy.copy(other.modified)
- self.created = copy.copy(other.created)
self.used = copy.copy(other.used)
self.params = copy.copy(other.params)
self.returned = copy.copy(other.returned)
@@ -109,56 +105,28 @@ class Scope(object):
if other.parent is not None:
self.parent.merge_from(other.parent)
self.modified |= other.modified
- self.created |= other.created
self.used |= other.used
self.params.update(other.params)
self.returned |= other.returned
- def has(self, name):
- if name in self.modified:
- return True
- elif self.parent is not None:
- return self.parent.has(name)
- return False
-
def mark_read(self, name):
self.used.add(name)
- if self.parent is not None and name not in self.created:
+ if self.parent is not None and name not in self.params:
self.parent.mark_read(name)
+ def mark_modified(self, name):
+ """Marks the given symbol as modified in the current scope."""
+ self.modified.add(name)
+ if not self.isolated:
+ if self.parent is not None:
+ self.parent.mark_modified(name)
+
def mark_param(self, name, owner):
# Assumption: all AST nodes have the same life span. This lets us use
# a weak reference to mark the connection between a symbol node and the
# function node whose argument that symbol is.
self.params[name] = weakref.ref(owner)
- def mark_creation(self, name, writes_create_symbol=False):
- """Mark a qualified name as created."""
- if name.is_composite():
- parent = name.parent
- if not writes_create_symbol:
- return
- else:
- if not self.has(parent):
- if self.add_unknown_symbols:
- self.mark_read(parent)
- else:
- raise ValueError('Unknown symbol "%s".' % parent)
- self.created.add(name)
-
- def mark_write(self, name):
- """Marks the given symbol as modified in the current scope."""
- self.modified.add(name)
- if self.isolated:
- self.mark_creation(name)
- else:
- if self.parent is None:
- self.mark_creation(name)
- else:
- if not self.parent.has(name):
- self.mark_creation(name)
- self.parent.mark_write(name)
-
def mark_returned(self, name):
self.returned.add(name)
if not self.isolated and self.parent is not None:
@@ -197,10 +165,7 @@ class ActivityAnalyzer(transformer.Base):
return True
return False
- def _track_symbol(self,
- node,
- composite_writes_alter_parent=False,
- writes_create_symbol=False):
+ def _track_symbol(self, node, composite_writes_alter_parent=False):
# A QN may be missing when we have an attribute (or subscript) on a function
# call. Example: a().b
if not anno.hasanno(node, anno.Basic.QN):
@@ -208,11 +173,9 @@ class ActivityAnalyzer(transformer.Base):
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Store):
- self.scope.mark_write(qn)
+ self.scope.mark_modified(qn)
if qn.is_composite and composite_writes_alter_parent:
- self.scope.mark_write(qn.parent)
- if writes_create_symbol:
- self.scope.mark_creation(qn, writes_create_symbol=True)
+ self.scope.mark_modified(qn.parent)
if self._in_aug_assign:
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Load):
@@ -220,13 +183,11 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Param):
# Param contexts appear in function defs, so they have the meaning of
# defining a variable.
- self.scope.mark_write(qn)
+ self.scope.mark_modified(qn)
self.scope.mark_param(qn, self.enclosing_entities[-1])
else:
raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
- anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
-
if self._in_return_statement:
self.scope.mark_returned(qn)
@@ -243,6 +204,12 @@ class ActivityAnalyzer(transformer.Base):
self._exit_scope()
return node
+ def visit_nonlocal(self, node):
+ raise NotImplementedError()
+
+ def visit_global(self, node):
+ raise NotImplementedError()
+
def visit_Expr(self, node):
return self._process_statement(node)
@@ -271,8 +238,7 @@ class ActivityAnalyzer(transformer.Base):
def visit_Attribute(self, node):
node = self.generic_visit(node)
if self._in_constructor and self._node_sets_self_attribute(node):
- self._track_symbol(
- node, composite_writes_alter_parent=True, writes_create_symbol=True)
+ self._track_symbol(node, composite_writes_alter_parent=True)
else:
self._track_symbol(node)
return node
@@ -336,7 +302,7 @@ class ActivityAnalyzer(transformer.Base):
# of its name, along with the usage of any decorator accompany it.
self._enter_scope(False)
node.decorator_list = self.visit_block(node.decorator_list)
- self.scope.mark_write(qual_names.QN(node.name))
+ self.scope.mark_modified(qual_names.QN(node.name))
anno.setanno(node, anno.Static.SCOPE, self.scope)
self._exit_scope()