aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/transformer.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/transformer.py173
1 files changed, 168 insertions, 5 deletions
diff --git a/tensorflow/contrib/autograph/pyct/transformer.py b/tensorflow/contrib/autograph/pyct/transformer.py
index 7655811830..969ca12244 100644
--- a/tensorflow/contrib/autograph/pyct/transformer.py
+++ b/tensorflow/contrib/autograph/pyct/transformer.py
@@ -59,6 +59,103 @@ class EntityInfo(object):
self.owner_type = owner_type
+class _StateStack(object):
+ """Typed stack abstraction.
+
+ This class provides syntactic sugar for a stack of objects of known
+ type. It allows accessing attributes of the object at the top of the stack
+ directly against this object, which allows for very terse syntax.
+
+ For example, this code:
+
+ stack = _StateStack(Foo)
+ stack.enter()
+ stack.bar
+
+ Is equivalent to:
+
+ stack = []
+ stack.append(Foo())
+ foo = stack[-1]
+ foo.bar
+
+ See _State for more on how this is used.
+
+ Attributes:
+ type: Any, the type of objects that this stack holds
+ level: int, the current stack depth
+ value: Any, the instance of the object at the top of the stack
+ """
+
+ def __init__(self, type_):
+ # Because we override __setattr__, we need to attach these attributes using
+ # the superclass' setattr.
+ object.__setattr__(self, 'type', type_)
+ object.__setattr__(self, '_stack', [])
+ self.enter()
+
+ def enter(self):
+ self._stack.append(self.type())
+
+ def exit(self):
+ return self._stack.pop()
+
+ @property
+ def level(self):
+ return len(self._stack)
+
+ @property
+ def value(self):
+ return self._stack[-1]
+
+ def __getattr__(self, key):
+ return getattr(self._stack[-1], key)
+
+ def __setattr__(self, key, value):
+ setattr(self._stack[-1], key, value)
+
+
+class _State(object):
+ """Supporting class for nested scope variable space for converter.Base.
+
+ This structure offers syntactic sugar over a dict of stacks of objects
+ of known type. These structures are useful to keep state during AST walks.
+ Multiple different scopes can be tracked in parallel. For example:
+
+ s = _State()
+
+ s[foo].enter()
+ s[bar].enter() # this will not affect s[foo]
+
+ Element access has special semantics:
+ * keys are a data type
+ * element values are _StateStack(type=key) objects
+ * missing elements are automatically added, similarly to defaultdict
+
+ For example, the following block :
+
+ _State s
+ s[Foo]
+
+ Is equivalent to:
+
+ s = {}
+ if Foo not in s:
+ s[Foo] = Foo()
+ s[Foo]
+
+ See Base for how it's used.
+ """
+
+ def __init__(self):
+ self._value = {}
+
+ def __getitem__(self, key):
+ if key not in self._value:
+ self._value[key] = _StateStack(key)
+ return self._value[key]
+
+
class Base(gast.NodeTransformer):
"""Base class for general-purpose code transformers transformers.
@@ -71,6 +168,27 @@ class Base(gast.NodeTransformer):
(possibly nested) scopes, use enter/exit_local_scope and set/get_local.
You must call enter/exit_local_scope manually, but the transformer detects
when they are not properly paired.
+
+ The transformer allows keeping state across calls to visit_* that is local to
+ arbitrary nodes and their descendants, using the self.state attribute.
+ Multiple independent scopes are allowed and automatically constructed.
+
+ For example, to keep track of the If node that encloses any Name node, one can
+ write:
+
+ class FooType(object):
+
+ def __init__(self):
+ self.foo_property = None
+
+ class DummyTransformer(Base):
+
+ def visit_If(self, node):
+ self.state[FooType].enter()
+ self.state[FooType].foo_property = node
+
+ def visit_Name(self, node):
+ self.state[FooType].foo_property # will hold the innermost enclosing if
"""
# TODO(mdan): Document all extra features.
@@ -92,6 +210,12 @@ class Base(gast.NodeTransformer):
self._local_scope_state = []
self.enter_local_scope()
+ # Allows scoping of local variables to keep state across calls to visit_*
+ # methods. Multiple scope hierchies may exist and are keyed by tag. A scope
+ # is valid at one or more nodes and all its children. Scopes created in
+ # child nodes supersede their parent. Scopes are isolated from one another.
+ self.state = _State()
+
@property
def enclosing_entities(self):
return tuple(self._enclosing_entities)
@@ -101,7 +225,9 @@ class Base(gast.NodeTransformer):
return len(self._local_scope_state)
def enter_local_scope(self, inherit=None):
- """Marks entry into a new local scope.
+ """Deprecated. Use self.state instead.
+
+ Marks entry into a new local scope.
Args:
inherit: Optional enumerable of variable names to copy from the
@@ -116,7 +242,9 @@ class Base(gast.NodeTransformer):
self._local_scope_state.append(scope_entered)
def exit_local_scope(self, keep=None):
- """Marks exit from the current local scope.
+ """Deprecated. Use self.state instead.
+
+ Marks exit from the current local scope.
Args:
keep: Optional enumerable of variable names to copy into the
@@ -133,9 +261,11 @@ class Base(gast.NodeTransformer):
return scope_left
def set_local(self, name, value):
+ """Deprecated. Use self.state instead."""
self._local_scope_state[-1][name] = value
def get_local(self, name, default=None):
+ """Deprecated. Use self.state instead."""
return self._local_scope_state[-1].get(name, default)
def debug_print(self, node):
@@ -216,7 +346,7 @@ class Base(gast.NodeTransformer):
node_destination = new_destination
return results
- # TODO(mdan): Once we have error tracing, we may be able to just go to SSA.
+ # TODO(mdan): Remove.
def apply_to_single_assignments(self, targets, values, apply_fn):
"""Applies a function to each individual assignment.
@@ -266,19 +396,38 @@ class Base(gast.NodeTransformer):
def _get_source(self, node):
try:
- return compiler.ast_to_source(node)
- except AssertionError:
+ source, _ = compiler.ast_to_source(node)
+ return source
+ # pylint: disable=broad-except
+ # This function is used for error reporting. If an exception occurs here,
+ # it should be suppressed, in favor of emitting as informative a message
+ # about the original error as possible.
+ except Exception:
return '<could not convert AST to source>'
def visit(self, node):
+ if not isinstance(node, gast.AST):
+ # This is not that uncommon a mistake: various node bodies are lists, for
+ # example, posing a land mine for transformers that need to recursively
+ # call `visit`. The error needs to be raised before the exception handler
+ # below is installed, because said handler will mess up if `node` is not,
+ # in fact, a node.
+ msg = (
+ 'invalid value for "node": expected "ast.AST", got "{}"; to'
+ ' visit lists of nodes, use "visit_block" instead').format(type(node))
+ raise ValueError(msg)
+
source_code = self.entity_info.source_code
source_file = self.entity_info.source_file
did_enter_function = False
local_scope_size_at_entry = len(self._local_scope_state)
+ processing_expr_node = False
try:
if isinstance(node, (gast.FunctionDef, gast.ClassDef, gast.Lambda)):
did_enter_function = True
+ elif isinstance(node, gast.Expr):
+ processing_expr_node = True
if did_enter_function:
self._enclosing_entities.append(node)
@@ -287,9 +436,23 @@ class Base(gast.NodeTransformer):
self._lineno = node.lineno
self._col_offset = node.col_offset
+ if processing_expr_node:
+ entry_expr_value = node.value
+
if not anno.hasanno(node, anno.Basic.SKIP_PROCESSING):
result = super(Base, self).visit(node)
+ # Adjust for consistency: replacing the value of an Expr with
+ # an Assign node removes the need for the Expr node.
+ if processing_expr_node:
+ if isinstance(result, gast.Expr) and result.value != entry_expr_value:
+ # When the replacement is a list, it is assumed that the list came
+ # from a template that contained a number of statements, which
+ # themselves are standalone and don't require an enclosing Expr.
+ if isinstance(result.value,
+ (list, tuple, gast.Assign, gast.AugAssign)):
+ result = result.value
+
# On exception, the local scope integrity is not guaranteed.
if did_enter_function:
self._enclosing_entities.pop()