diff options
Diffstat (limited to 'tensorflow/python/autograph/pyct/anno.py')
-rw-r--r-- | tensorflow/python/autograph/pyct/anno.py | 157 |
1 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py new file mode 100644 index 0000000000..1a52110ef3 --- /dev/null +++ b/tensorflow/python/autograph/pyct/anno.py @@ -0,0 +1,157 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AST node annotation support. + +Adapted from Tangent. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum + +# pylint:disable=g-bad-import-order +import gast +# pylint:enable=g-bad-import-order + + +# TODO(mdan): Shorten the names. +# These names are heavily used, and anno.blaa +# TODO(mdan): Replace the attr-dict mechanism with a more typed solution. + + +class NoValue(enum.Enum): + + def __repr__(self): + return self.name + + +class Basic(NoValue): + """Container for basic annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + QN = 'Qualified name, as it appeared in the code. See qual_names.py.' + SKIP_PROCESSING = ( + 'This node should be preserved as is and not processed any further.') + INDENT_BLOCK_REMAINDER = ( + 'When a node is annotated with this, the remainder of the block should' + ' be indented below it. The annotation contains a tuple' + ' (new_body, name_map), where `new_body` is the new indented block and' + ' `name_map` allows renaming symbols.') + ORIGIN = ('Information about the source code that converted code originated' + ' from. See origin_information.py.') + + +class Static(NoValue): + """Container for static analysis annotation keys. + + The enum values are used strictly for documentation purposes. + """ + + # Deprecated - use reaching definitions instead. + # Symbols + # These flags are boolean. + IS_LOCAL = 'Symbol is local to the function scope being analyzed.' + IS_PARAM = 'Symbol is a parameter to the function being analyzed.' + + # Scopes + # Scopes are represented by objects of type activity.Scope. + SCOPE = 'The scope for the annotated node. See activity.py.' + # TODO(mdan): Drop these in favor of accessing the child's SCOPE. + ARGS_SCOPE = 'The scope for the argument list of a function call.' + COND_SCOPE = 'The scope for the test node of a conditional statement.' + BODY_SCOPE = ( + 'The scope for the main body of a statement (True branch for if ' + 'statements, main body for loops).') + ORELSE_SCOPE = ( + 'The scope for the orelse body of a statement (False branch for if ' + 'statements, orelse body for loops).') + + # Static analysis annotations. + DEFINITIONS = ( + 'Reaching definition information. See reaching_definitions.py.') + ORIG_DEFINITIONS = ( + 'The value of DEFINITIONS that applied to the original code before any' + ' conversion.') + DEFINED_VARS_IN = ( + 'Symbols defined when entering the node. See reaching_definitions.py.') + LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.') + + +FAIL = object() + + +def keys(node, field_name='___pyct_anno'): + if not hasattr(node, field_name): + return frozenset() + return frozenset(getattr(node, field_name).keys()) + + +def getanno(node, key, default=FAIL, field_name='___pyct_anno'): + if (default is FAIL or (hasattr(node, field_name) and + (key in getattr(node, field_name)))): + return getattr(node, field_name)[key] + else: + return default + + +def hasanno(node, key, field_name='___pyct_anno'): + return hasattr(node, field_name) and key in getattr(node, field_name) + + +def setanno(node, key, value, field_name='___pyct_anno'): + annotations = getattr(node, field_name, {}) + setattr(node, field_name, annotations) + annotations[key] = value + + # So that the annotations survive gast_to_ast() and ast_to_gast() + if field_name not in node._fields: + node._fields += (field_name,) + + +def delanno(node, key, field_name='___pyct_anno'): + annotations = getattr(node, field_name) + del annotations[key] + if not annotations: + delattr(node, field_name) + node._fields = tuple(f for f in node._fields if f != field_name) + + +def copyanno(from_node, to_node, key, field_name='___pyct_anno'): + if hasanno(from_node, key, field_name=field_name): + setanno( + to_node, + key, + getanno(from_node, key, field_name=field_name), + field_name=field_name) + + +def dup(node, copy_map, field_name='___pyct_anno'): + """Recursively copies annotations in an AST tree. + + Args: + node: ast.AST + copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination + key. All annotations with the source key will be copied to identical + annotations with the destination key. + field_name: str + """ + for n in gast.walk(node): + for k in copy_map: + if hasanno(n, k, field_name): + setanno(n, copy_map[k], getanno(n, k, field_name), field_name) |