aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/static_analysis/live_values.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/live_values.py44
1 files changed, 26 insertions, 18 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
index 9ccb98f79a..2d8f922a45 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/live_values.py
@@ -16,7 +16,7 @@
Live values are extracted from the known execution context.
-Requires activity analysis annotations.
+Requires activity and reaching definitions analyses.
"""
from __future__ import absolute_import
@@ -45,14 +45,12 @@ class LiveValueResolver(transformer.Base):
def visit_Name(self, node):
self.generic_visit(node)
if isinstance(node.ctx, gast.Load):
- assert anno.hasanno(node, NodeAnno.IS_LOCAL), node
- symbol_is_local = anno.getanno(node, NodeAnno.IS_LOCAL)
- assert anno.hasanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY), node
- symbol_is_modified = anno.getanno(node, NodeAnno.IS_MODIFIED_SINCE_ENTRY)
- assert anno.hasanno(node, NodeAnno.IS_PARAM), node
- symbol_is_param = anno.getanno(node, NodeAnno.IS_PARAM)
-
- if not symbol_is_local and not symbol_is_param:
+ defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
+
+ is_defined = bool(defs)
+ has_single_def = len(defs) == 1
+
+ if not is_defined:
if node.id in self.literals:
anno.setanno(node, 'live_val', self.literals[node.id])
elif node.id in self.entity_info.namespace:
@@ -79,11 +77,13 @@ class LiveValueResolver(transformer.Base):
# TODO(mdan): Attempt to trace its value through the local chain.
# TODO(mdan): Use type annotations as fallback.
- if not symbol_is_modified:
- if node.id in self.entity_info.arg_values:
- obj = self.entity_info.arg_values[node.id]
- anno.setanno(node, 'live_val', obj)
- anno.setanno(node, 'fqn', (obj.__class__.__name__,))
+ if has_single_def:
+ def_, = defs
+ if def_.param_of is self.enclosing_entities[0]:
+ if node.id in self.entity_info.arg_values:
+ obj = self.entity_info.arg_values[node.id]
+ anno.setanno(node, 'live_val', obj)
+ anno.setanno(node, 'fqn', (obj.__class__.__name__,))
return node
def visit_Attribute(self, node):
@@ -91,12 +91,20 @@ class LiveValueResolver(transformer.Base):
if anno.hasanno(node.value, 'live_val'):
assert anno.hasanno(node.value, 'fqn')
parent_object = anno.getanno(node.value, 'live_val')
- if not hasattr(parent_object, node.attr):
- raise AttributeError('%s has no attribute %s' % (parent_object,
- node.attr))
+
anno.setanno(node, 'parent_type', type(parent_object))
- anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,))
+ if hasattr(parent_object, node.attr):
+ # This can happen when the attribute's creation and use depend on the
+ # same static condition, for example:
+ #
+ # if cond:
+ # foo.bar = baz
+ # if cond:
+ # x = foo.bar
+ #
+ anno.setanno(node, 'live_val', getattr(parent_object, node.attr))
+
# TODO(mdan): Investigate the role built-in annotations can play here.
elif anno.hasanno(node.value, 'type'):
parent_type = anno.getanno(node.value, 'type')