diff options
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/anno.py')
-rw-r--r-- | tensorflow/contrib/autograph/pyct/anno.py | 91 |
1 files changed, 80 insertions, 11 deletions
diff --git a/tensorflow/contrib/autograph/pyct/anno.py b/tensorflow/contrib/autograph/pyct/anno.py index ae861627fd..1a52110ef3 100644 --- a/tensorflow/contrib/autograph/pyct/anno.py +++ b/tensorflow/contrib/autograph/pyct/anno.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Handling annotations on AST nodes. +"""AST node annotation support. Adapted from Tangent. """ @@ -21,37 +21,90 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from enum import Enum +import enum +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order -class NoValue(Enum): + +# TODO(mdan): Shorten the names. +# These names are heavily used, and anno.blaa +# TODO(mdan): Replace the attr-dict mechanism with a more typed solution. + + +class NoValue(enum.Enum): def __repr__(self): return self.name class Basic(NoValue): - """Container for annotation keys. + """Container for basic annotation keys. The enum values are used strictly for documentation purposes. """ - QN = 'Qualified name, as it appeared in the code.' + QN = 'Qualified name, as it appeared in the code. See qual_names.py.' SKIP_PROCESSING = ( 'This node should be preserved as is and not processed any further.') INDENT_BLOCK_REMAINDER = ( - 'When a node is annotated with this, the remainder of the block should ' - 'be indented below it. The annotation contains a tuple ' - '(new_body, name_map), where `new_body` is the new indented block and ' - '`name_map` allows renaming symbols.') + 'When a node is annotated with this, the remainder of the block should' + ' be indented below it. The annotation contains a tuple' + ' (new_body, name_map), where `new_body` is the new indented block and' + ' `name_map` allows renaming symbols.') + ORIGIN = ('Information about the source code that converted code originated' + ' from. See origin_information.py.') + + +class Static(NoValue): + """Container for static analysis annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + # Deprecated - use reaching definitions instead. + # Symbols + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' + + # Scopes + # Scopes are represented by objects of type activity.Scope. + SCOPE = 'The scope for the annotated node. See activity.py.' + # TODO(mdan): Drop these in favor of accessing the child's SCOPE. + ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') + + # Static analysis annotations. + DEFINITIONS = ( + 'Reaching definition information. See reaching_definitions.py.') + ORIG_DEFINITIONS = ( + 'The value of DEFINITIONS that applied to the original code before any' + ' conversion.') + DEFINED_VARS_IN = ( + 'Symbols defined when entering the node. See reaching_definitions.py.') + LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') FAIL = object() +def keys(node, field_name='___pyct_anno'): + if not hasattr(node, field_name): + return frozenset() + return frozenset(getattr(node, field_name).keys()) + + def getanno(node, key, default=FAIL, field_name='___pyct_anno'): - if (default is FAIL or - (hasattr(node, field_name) and (key in getattr(node, field_name)))): + if (default is FAIL or (hasattr(node, field_name) and + (key in getattr(node, field_name)))): return getattr(node, field_name)[key] else: return default @@ -86,3 +139,19 @@ def copyanno(from_node, to_node, key, field_name='___pyct_anno'): key, getanno(from_node, key, field_name=field_name), field_name=field_name) + + +def dup(node, copy_map, field_name='___pyct_anno'): + """Recursively copies annotations in an AST tree. + + Args: + node: ast.AST + copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination + key. All annotations with the source key will be copied to identical + annotations with the destination key. + field_name: str + """ + for n in gast.walk(node): + for k in copy_map: + if hasanno(n, k, field_name): + setanno(n, copy_map[k], getanno(n, k, field_name), field_name) |