aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/pyct/anno.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/pyct/anno.py')
-rw-r--r--tensorflow/python/autograph/pyct/anno.py157
1 files changed, 157 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
new file mode 100644
index 0000000000..1a52110ef3
--- /dev/null
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -0,0 +1,157 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""AST node annotation support.
+
+Adapted from Tangent.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import enum
+
+# pylint:disable=g-bad-import-order
+import gast
+# pylint:enable=g-bad-import-order
+
+
+# TODO(mdan): Shorten the names.
+# These names are heavily used, and anno.blaa
+# TODO(mdan): Replace the attr-dict mechanism with a more typed solution.
+
+
+class NoValue(enum.Enum):
+
+ def __repr__(self):
+ return self.name
+
+
+class Basic(NoValue):
+ """Container for basic annotation keys.
+
+ The enum values are used strictly for documentation purposes.
+ """
+
+ QN = 'Qualified name, as it appeared in the code. See qual_names.py.'
+ SKIP_PROCESSING = (
+ 'This node should be preserved as is and not processed any further.')
+ INDENT_BLOCK_REMAINDER = (
+ 'When a node is annotated with this, the remainder of the block should'
+ ' be indented below it. The annotation contains a tuple'
+ ' (new_body, name_map), where `new_body` is the new indented block and'
+ ' `name_map` allows renaming symbols.')
+ ORIGIN = ('Information about the source code that converted code originated'
+ ' from. See origin_information.py.')
+
+
+class Static(NoValue):
+ """Container for static analysis annotation keys.
+
+ The enum values are used strictly for documentation purposes.
+ """
+
+ # Deprecated - use reaching definitions instead.
+ # Symbols
+ # These flags are boolean.
+ IS_LOCAL = 'Symbol is local to the function scope being analyzed.'
+ IS_PARAM = 'Symbol is a parameter to the function being analyzed.'
+
+ # Scopes
+ # Scopes are represented by objects of type activity.Scope.
+ SCOPE = 'The scope for the annotated node. See activity.py.'
+ # TODO(mdan): Drop these in favor of accessing the child's SCOPE.
+ ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ COND_SCOPE = 'The scope for the test node of a conditional statement.'
+ BODY_SCOPE = (
+ 'The scope for the main body of a statement (True branch for if '
+ 'statements, main body for loops).')
+ ORELSE_SCOPE = (
+ 'The scope for the orelse body of a statement (False branch for if '
+ 'statements, orelse body for loops).')
+
+ # Static analysis annotations.
+ DEFINITIONS = (
+ 'Reaching definition information. See reaching_definitions.py.')
+ ORIG_DEFINITIONS = (
+ 'The value of DEFINITIONS that applied to the original code before any'
+ ' conversion.')
+ DEFINED_VARS_IN = (
+ 'Symbols defined when entering the node. See reaching_definitions.py.')
+ LIVE_VARS_OUT = ('Symbols live when exiting the node. See liveness.py.')
+
+
+FAIL = object()
+
+
+def keys(node, field_name='___pyct_anno'):
+ if not hasattr(node, field_name):
+ return frozenset()
+ return frozenset(getattr(node, field_name).keys())
+
+
+def getanno(node, key, default=FAIL, field_name='___pyct_anno'):
+ if (default is FAIL or (hasattr(node, field_name) and
+ (key in getattr(node, field_name)))):
+ return getattr(node, field_name)[key]
+ else:
+ return default
+
+
+def hasanno(node, key, field_name='___pyct_anno'):
+ return hasattr(node, field_name) and key in getattr(node, field_name)
+
+
+def setanno(node, key, value, field_name='___pyct_anno'):
+ annotations = getattr(node, field_name, {})
+ setattr(node, field_name, annotations)
+ annotations[key] = value
+
+ # So that the annotations survive gast_to_ast() and ast_to_gast()
+ if field_name not in node._fields:
+ node._fields += (field_name,)
+
+
+def delanno(node, key, field_name='___pyct_anno'):
+ annotations = getattr(node, field_name)
+ del annotations[key]
+ if not annotations:
+ delattr(node, field_name)
+ node._fields = tuple(f for f in node._fields if f != field_name)
+
+
+def copyanno(from_node, to_node, key, field_name='___pyct_anno'):
+ if hasanno(from_node, key, field_name=field_name):
+ setanno(
+ to_node,
+ key,
+ getanno(from_node, key, field_name=field_name),
+ field_name=field_name)
+
+
+def dup(node, copy_map, field_name='___pyct_anno'):
+ """Recursively copies annotations in an AST tree.
+
+ Args:
+ node: ast.AST
+ copy_map: Dict[Hashable, Hashable], maps a source anno key to a destination
+ key. All annotations with the source key will be copied to identical
+ annotations with the destination key.
+ field_name: str
+ """
+ for n in gast.walk(node):
+ for k in copy_map:
+ if hasanno(n, k, field_name):
+ setanno(n, copy_map[k], getanno(n, k, field_name), field_name)