aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/static_analysis/activity.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity.py226
1 files changed, 105 insertions, 121 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
index 4d7b0cbb7b..a0182da9d1 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
@@ -12,7 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Activity analysis."""
+"""Activity analysis.
+
+Requires qualified name annotations (see qual_names.py).
+"""
from __future__ import absolute_import
from __future__ import division
@@ -59,9 +62,10 @@ class Scope(object):
self.parent = parent
self.add_unknown_symbols = add_unknown_symbols
self.modified = set()
+ # TODO(mdan): Completely remove this.
self.created = set()
self.used = set()
- self.params = set()
+ self.params = {}
self.returned = set()
# TODO(mdan): Rename to `locals`
@@ -106,37 +110,23 @@ class Scope(object):
self.modified |= other.modified
self.created |= other.created
self.used |= other.used
- self.params |= other.params
+ self.params.update(other.params)
self.returned |= other.returned
def has(self, name):
- if name in self.modified or name in self.params:
+ if name in self.modified:
return True
elif self.parent is not None:
return self.parent.has(name)
return False
- def is_modified_since_entry(self, name):
- if name in self.modified:
- return True
- elif self.parent is not None and not self.isolated:
- return self.parent.is_modified_since_entry(name)
- return False
-
- def is_param(self, name):
- if name in self.params:
- return True
- elif self.parent is not None and not self.isolated:
- return self.parent.is_param(name)
- return False
-
def mark_read(self, name):
self.used.add(name)
if self.parent is not None and name not in self.created:
self.parent.mark_read(name)
- def mark_param(self, name):
- self.params.add(name)
+ def mark_param(self, name, owner):
+ self.params[name] = owner
def mark_creation(self, name, writes_create_symbol=False):
"""Mark a qualified name as created."""
@@ -226,37 +216,56 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Param):
# Param contexts appear in function defs, so they have the meaning of
# defining a variable.
- # TODO(mdan): This may be incorrect with nested functions.
- # For nested functions, we'll have to add the notion of hiding args from
- # the parent scope, not writing to them.
- self.scope.mark_creation(qn)
- self.scope.mark_param(qn)
+ self.scope.mark_write(qn)
+ self.scope.mark_param(qn, self.enclosing_entities[-1])
else:
raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
- anno.setanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY,
- self.scope.is_modified_since_entry(qn))
- anno.setanno(node, NodeAnno.IS_PARAM, self.scope.is_param(qn))
if self._in_return_statement:
self.scope.mark_returned(qn)
+ def _enter_scope(self, isolated):
+ self.scope = Scope(self.scope, isolated=isolated)
+
+ def _exit_scope(self):
+ self.scope = self.scope.parent
+
+ def _process_statement(self, node):
+ self._enter_scope(False)
+ node = self.generic_visit(node)
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+ return node
+
+ def visit_Expr(self, node):
+ return self._process_statement(node)
+
+ def visit_Return(self, node):
+ self._in_return_statement = True
+ node = self._process_statement(node)
+ self._in_return_statement = False
+ return node
+
+ def visit_Assign(self, node):
+ return self._process_statement(node)
+
def visit_AugAssign(self, node):
# Special rules for AugAssign. In Assign, the target is only written,
# but in AugAssig (e.g. a += b), the target is both read and written.
self._in_aug_assign = True
- self.generic_visit(node)
+ node = self._process_statement(node)
self._in_aug_assign = False
return node
def visit_Name(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
self._track_symbol(node)
return node
def visit_Attribute(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
if self._in_constructor and self._node_sets_self_attribute(node):
self._track_symbol(
node, composite_writes_alter_parent=True, writes_create_symbol=True)
@@ -265,44 +274,38 @@ class ActivityAnalyzer(transformer.Base):
return node
def visit_Subscript(self, node):
- self.generic_visit(node)
+ node = self.generic_visit(node)
# Subscript writes (e.g. a[b] = "value") are considered to modify
# both the element itself (a[b]) and its parent (a).
- self._track_symbol(node, composite_writes_alter_parent=True)
+ self._track_symbol(node)
return node
def visit_Print(self, node):
- current_scope = self.scope
- args_scope = Scope(current_scope)
- self.scope = args_scope
- for n in node.values:
- self.visit(n)
- anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
- self.scope = current_scope
+ self._enter_scope(False)
+ node.values = self.visit_block(node.values)
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
+ self._exit_scope()
return node
+ def visit_Assert(self, node):
+ return self._process_statement(node)
+
def visit_Call(self, node):
- current_scope = self.scope
- args_scope = Scope(current_scope, isolated=False)
- self.scope = args_scope
- for n in node.args:
- self.visit(n)
+ self._enter_scope(False)
+ node.args = self.visit_block(node.args)
+ node.keywords = self.visit_block(node.keywords)
# TODO(mdan): Account starargs, kwargs
- for n in node.keywords:
- self.visit(n)
- anno.setanno(node, NodeAnno.ARGS_SCOPE, args_scope)
- self.scope = current_scope
- self.visit(node.func)
+ anno.setanno(node, NodeAnno.ARGS_SCOPE, self.scope)
+ self._exit_scope()
+ node.func = self.visit(node.func)
return node
def _process_block_node(self, node, block, scope_name):
- current_scope = self.scope
- block_scope = Scope(current_scope, isolated=False)
- self.scope = block_scope
- for n in block:
- self.visit(n)
- anno.setanno(node, scope_name, block_scope)
- self.scope = current_scope
+ self._enter_scope(False)
+ block = self.visit_block(block)
+ anno.setanno(node, scope_name, self.scope)
+ self._exit_scope()
return node
def _process_parallel_blocks(self, parent, children):
@@ -321,94 +324,75 @@ class ActivityAnalyzer(transformer.Base):
self.scope.merge_from(after_child)
return parent
+ def visit_arguments(self, node):
+ return self._process_statement(node)
+
def visit_FunctionDef(self, node):
- if self.scope:
- qn = qual_names.QN(node.name)
- self.scope.mark_write(qn)
- current_scope = self.scope
- body_scope = Scope(current_scope, isolated=True)
- self.scope = body_scope
- self.generic_visit(node)
- anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
- self.scope = current_scope
+ # The FunctionDef node itself has a Scope object that tracks the creation
+ # of its name, along with the usage of any decorator accompany it.
+ self._enter_scope(False)
+ node.decorator_list = self.visit_block(node.decorator_list)
+ self.scope.mark_write(qual_names.QN(node.name))
+ anno.setanno(node, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
+
+ # A separate Scope tracks the actual function definition.
+ self._enter_scope(True)
+ node.args = self.visit(node.args)
+
+ # Track the body separately. This is for compatibility reasons, it may not
+ # be strictly needed.
+ self._enter_scope(False)
+ node.body = self.visit_block(node.body)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
+ self._exit_scope()
+
+ self._exit_scope()
return node
def visit_With(self, node):
- current_scope = self.scope
- with_scope = Scope(current_scope, isolated=False)
- self.scope = with_scope
- self.generic_visit(node)
- anno.setanno(node, NodeAnno.BODY_SCOPE, with_scope)
- self.scope = current_scope
+ self._enter_scope(False)
+ node = self.generic_visit(node)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, self.scope)
+ self._exit_scope()
return node
- def visit_If(self, node):
- current_scope = self.scope
- cond_scope = Scope(current_scope, isolated=False)
- self.scope = cond_scope
- self.visit(node.test)
- anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
- self.scope = current_scope
+ def visit_withitem(self, node):
+ return self._process_statement(node)
+ def visit_If(self, node):
+ self._enter_scope(False)
+ node.test = self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
+ anno.setanno(node.test, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
def visit_For(self, node):
- self.visit(node.target)
- self.visit(node.iter)
+ self._enter_scope(False)
+ node.target = self.visit(node.target)
+ node.iter = self.visit(node.iter)
+ anno.setanno(node.iter, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
def visit_While(self, node):
- current_scope = self.scope
- cond_scope = Scope(current_scope, isolated=False)
- self.scope = cond_scope
- self.visit(node.test)
- anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
- self.scope = current_scope
-
+ self._enter_scope(False)
+ node.test = self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, self.scope)
+ anno.setanno(node.test, anno.Static.SCOPE, self.scope)
+ self._exit_scope()
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
return node
- def visit_Return(self, node):
- self._in_return_statement = True
- node = self.generic_visit(node)
- self._in_return_statement = False
- return node
-
-
-def get_read(node, context):
- """Return the variable names as QNs (qual_names.py) read by this statement."""
- analyzer = ActivityAnalyzer(context, None, True)
- analyzer.visit(node)
- return analyzer.scope.used
-
-
-def get_updated(node, context):
- """Return the variable names created or mutated by this statement.
-
- This function considers assign statements, augmented assign statements, and
- the targets of for loops, as well as function arguments.
- For example, `x[0] = 2` will return `x`, `x, y = 3, 4` will return `x` and
- `y`, `for i in range(x)` will return `i`, etc.
- Args:
- node: An AST node
- context: An EntityContext instance
-
- Returns:
- A set of variable names (QNs, see qual_names.py) of all the variables
- created or mutated.
- """
- analyzer = ActivityAnalyzer(context, None, True)
- analyzer.visit(node)
- return analyzer.scope.created | analyzer.scope.modified
-
def resolve(node, context, parent_scope=None):
return ActivityAnalyzer(context, parent_scope).visit(node)