aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity.py18
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py2
-rw-r--r--tensorflow/contrib/autograph/pyct/static_analysis/annos.py1
3 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
index da6a2f6f05..6dd53091fa 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity.py
@@ -265,10 +265,10 @@ class ActivityAnalizer(transformer.Base):
qn = QN(node.name)
self.scope.mark_write(qn)
current_scope = self.scope
- fndef_scope = Scope(current_scope, isolated=True)
- self.scope = fndef_scope
+ body_scope = Scope(current_scope, isolated=True)
+ self.scope = body_scope
self.generic_visit(node)
- anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope)
+ anno.setanno(node, NodeAnno.BODY_SCOPE, body_scope)
self.scope = current_scope
return node
@@ -282,7 +282,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_If(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
@@ -297,7 +303,13 @@ class ActivityAnalizer(transformer.Base):
return node
def visit_While(self, node):
+ current_scope = self.scope
+ cond_scope = Scope(current_scope, isolated=False)
+ self.scope = cond_scope
self.visit(node.test)
+ anno.setanno(node, NodeAnno.COND_SCOPE, cond_scope)
+ self.scope = current_scope
+
node = self._process_parallel_blocks(node,
((node.body, NodeAnno.BODY_SCOPE),
(node.orelse, NodeAnno.ORELSE_SCOPE)))
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
index 37c28872bb..1e6c686b01 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/activity_test.py
@@ -204,6 +204,8 @@ class ActivityAnalizerTest(test.TestCase):
self.assertScopeIsRmc(
anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
('b', 'c'), ('a', 'b', 'c'))
+ self.assertScopeIsRmc(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
def test_for(self):
diff --git a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
index 5254b83ca7..d6d9f7e1a6 100644
--- a/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
+++ b/tensorflow/contrib/autograph/pyct/static_analysis/annos.py
@@ -43,6 +43,7 @@ class NodeAnno(NoValue):
# Scopes
# Scopes are represented by objects of type activity.Scope.
ARGS_SCOPE = 'The scope for the argument list of a function call.'
+ COND_SCOPE = 'The scope for the test node of a conditional statement.'
BODY_SCOPE = (
'The scope for the main body of a statement (True branch for if '
'statements, main body for loops).')