aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-05 09:57:58 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-05 10:07:37 -0700
commita0c80b9a54dc9669c0f5d151bee9f0b3a4fd71a0 (patch)
tree4fc8bd56c0c6922bc9de1dce3e5c0feb116fe4e0
parent16b233c43fbfc366a3ca3cebb2c5a5e32354263e (diff)
Expand activity analysis to the test nodes of if and while statements.
PiperOrigin-RevId: 191756234
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity.py18
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py2
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/annos.py1
3 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
index da6a2f6f05..6dd53091fa 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
@@ -265,10 +265,10 @@ class ActivityAnalizer(transformer.Base):
qn = QN(node.name)
self.scope.mark_write(qn)
current_scope = self.scope
- fndef_scope = Scope(current_scope, isolated=True)
- self.scope = fndef_scope
+ body_scope = Scope(current_scope, isolated=True)
+ self.scope = body_scope
self.generic_visit(node)
- anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
self.scope = current_scope
return node
@@ -282,7 +282,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_If(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
@@ -297,7 +303,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_While(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
index 37c28872bb..1e6c686b01 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
@@ -204,6 +204,8 @@ class ActivityAnalizerTest(test.TestCase):
self.assertScopeIsRmc(
anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
('b', 'c'), ('a', 'b', 'c'))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
def test_for(self):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
index 5254b83ca7..d6d9f7e1a6 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
@@ -43,6 +43,7 @@ class NodeAnno(NoValue):
# Scopes
# Scopes are represented by objects of type activity.Scope.
ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ COND_SCOPE = 'The scope for the test node of a conditional statement.'
BODY_SCOPE = (
'The scope for the main body of a statement (True branch for if '
'statements, main body for loops).')