aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-10-09 16:44:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 16:50:04 -0700
commitef9d2e7be9ae9fbcd4720d46e1f8a8cac902a1cd (patch)
tree5e4f31d61d8ef62d3583af2b068e8402d8e9c4a7
parentd78c747e9177fc93d43a580acef2b62eb1420859 (diff)
Remove the deprecated created and IS_LOCAL abstractions from activity analysis.
PiperOrigin-RevId: 216446750
-rw-r--r--tensorflow/python/autograph/pyct/anno.py2
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py82
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity_test.py268
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py5
4 files changed, 121 insertions, 236 deletions
diff --git a/tensorflow/python/autograph/pyct/anno.py b/tensorflow/python/autograph/pyct/anno.py
index 5392e6ea03..e1f4af46cd 100644
--- a/tensorflow/python/autograph/pyct/anno.py
+++ b/tensorflow/python/autograph/pyct/anno.py
@@ -63,10 +63,8 @@ class Static(NoValue):
The enum values are used strictly for documentation purposes.
"""
- # Deprecated - use reaching definitions instead.
# Symbols
# These flags are boolean.
- IS_LOCAL = 'Symbol is local to the function scope being analyzed.'
IS_PARAM = 'Symbol is a parameter to the function being analyzed.'
# Scopes
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 086eda7574..cc159031ff 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -44,7 +44,6 @@ class Scope(object):
Attributes:
modified: identifiers modified in this scope
- created: identifiers created in this scope
used: identifiers referenced in this scope
"""
@@ -54,7 +53,8 @@ class Scope(object):
Args:
parent: A Scope or None.
isolated: Whether the scope is isolated, that is, whether variables
- created in this scope should be visible to the parent scope.
+ modified in this scope should be considered modified in the parent
+ scope.
add_unknown_symbols: Whether to handle attributed and subscripts
without having first seen the base name.
E.g., analyzing the statement 'x.y = z' without first having seen 'x'.
@@ -63,13 +63,11 @@ class Scope(object):
self.parent = parent
self.add_unknown_symbols = add_unknown_symbols
self.modified = set()
- # TODO(mdan): Completely remove this.
- self.created = set()
self.used = set()
self.params = {}
self.returned = set()
- # TODO(mdan): Rename to `locals`
+ # TODO(mdan): Rename to `reserved`
@property
def referenced(self):
if not self.isolated and self.parent is not None:
@@ -77,8 +75,7 @@ class Scope(object):
return self.used
def __repr__(self):
- return 'Scope{r=%s, c=%s, w=%s}' % (tuple(self.used), tuple(self.created),
- tuple(self.modified))
+ return 'Scope{r=%s, w=%s}' % (tuple(self.used), tuple(self.modified))
def copy_from(self, other):
"""Recursively copies the contents of this scope from another scope."""
@@ -88,7 +85,6 @@ class Scope(object):
self.parent.copy_from(other.parent)
self.isolated = other.isolated
self.modified = copy.copy(other.modified)
- self.created = copy.copy(other.created)
self.used = copy.copy(other.used)
self.params = copy.copy(other.params)
self.returned = copy.copy(other.returned)
@@ -109,56 +105,28 @@ class Scope(object):
if other.parent is not None:
self.parent.merge_from(other.parent)
self.modified |= other.modified
- self.created |= other.created
self.used |= other.used
self.params.update(other.params)
self.returned |= other.returned
- def has(self, name):
- if name in self.modified:
- return True
- elif self.parent is not None:
- return self.parent.has(name)
- return False
-
def mark_read(self, name):
self.used.add(name)
- if self.parent is not None and name not in self.created:
+ if self.parent is not None and name not in self.params:
self.parent.mark_read(name)
+ def mark_modified(self, name):
+ """Marks the given symbol as modified in the current scope."""
+ self.modified.add(name)
+ if not self.isolated:
+ if self.parent is not None:
+ self.parent.mark_modified(name)
+
def mark_param(self, name, owner):
# Assumption: all AST nodes have the same life span. This lets us use
# a weak reference to mark the connection between a symbol node and the
# function node whose argument that symbol is.
self.params[name] = weakref.ref(owner)
- def mark_creation(self, name, writes_create_symbol=False):
- """Mark a qualified name as created."""
- if name.is_composite():
- parent = name.parent
- if not writes_create_symbol:
- return
- else:
- if not self.has(parent):
- if self.add_unknown_symbols:
- self.mark_read(parent)
- else:
- raise ValueError('Unknown symbol "%s".' % parent)
- self.created.add(name)
-
- def mark_write(self, name):
- """Marks the given symbol as modified in the current scope."""
- self.modified.add(name)
- if self.isolated:
- self.mark_creation(name)
- else:
- if self.parent is None:
- self.mark_creation(name)
- else:
- if not self.parent.has(name):
- self.mark_creation(name)
- self.parent.mark_write(name)
-
def mark_returned(self, name):
self.returned.add(name)
if not self.isolated and self.parent is not None:
@@ -197,10 +165,7 @@ class ActivityAnalyzer(transformer.Base):
return True
return False
- def _track_symbol(self,
- node,
- composite_writes_alter_parent=False,
- writes_create_symbol=False):
+ def _track_symbol(self, node, composite_writes_alter_parent=False):
# A QN may be missing when we have an attribute (or subscript) on a function
# call. Example: a().b
if not anno.hasanno(node, anno.Basic.QN):
@@ -208,11 +173,9 @@ class ActivityAnalyzer(transformer.Base):
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Store):
- self.scope.mark_write(qn)
+ self.scope.mark_modified(qn)
if qn.is_composite and composite_writes_alter_parent:
- self.scope.mark_write(qn.parent)
- if writes_create_symbol:
- self.scope.mark_creation(qn, writes_create_symbol=True)
+ self.scope.mark_modified(qn.parent)
if self._in_aug_assign:
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Load):
@@ -220,13 +183,11 @@ class ActivityAnalyzer(transformer.Base):
elif isinstance(node.ctx, gast.Param):
# Param contexts appear in function defs, so they have the meaning of
# defining a variable.
- self.scope.mark_write(qn)
+ self.scope.mark_modified(qn)
self.scope.mark_param(qn, self.enclosing_entities[-1])
else:
raise ValueError('Unknown context %s for node %s.' % (type(node.ctx), qn))
- anno.setanno(node, NodeAnno.IS_LOCAL, self.scope.has(qn))
-
if self._in_return_statement:
self.scope.mark_returned(qn)
@@ -243,6 +204,12 @@ class ActivityAnalyzer(transformer.Base):
self._exit_scope()
return node
+ def visit_nonlocal(self, node):
+ raise NotImplementedError()
+
+ def visit_global(self, node):
+ raise NotImplementedError()
+
def visit_Expr(self, node):
return self._process_statement(node)
@@ -271,8 +238,7 @@ class ActivityAnalyzer(transformer.Base):
def visit_Attribute(self, node):
node = self.generic_visit(node)
if self._in_constructor and self._node_sets_self_attribute(node):
- self._track_symbol(
- node, composite_writes_alter_parent=True, writes_create_symbol=True)
+ self._track_symbol(node, composite_writes_alter_parent=True)
else:
self._track_symbol(node)
return node
@@ -336,7 +302,7 @@ class ActivityAnalyzer(transformer.Base):
# of its name, along with the usage of any decorator accompany it.
self._enter_scope(False)
node.decorator_list = self.visit_block(node.decorator_list)
- self.scope.mark_write(qual_names.QN(node.name))
+ self.scope.mark_modified(qual_names.QN(node.name))
anno.setanno(node, anno.Static.SCOPE, self.scope)
self._exit_scope()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
index d4a6ce8ac3..9a4f1bf09b 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity_test.py
@@ -32,62 +32,63 @@ from tensorflow.python.platform import test
class ScopeTest(test.TestCase):
+ def assertMissing(self, qn, scope):
+ self.assertNotIn(qn, scope.used)
+ self.assertNotIn(qn, scope.modified)
+
+ def assertReadOnly(self, qn, scope):
+ self.assertIn(qn, scope.used)
+ self.assertNotIn(qn, scope.modified)
+
+ def assertWriteOnly(self, qn, scope):
+ self.assertNotIn(qn, scope.used)
+ self.assertIn(qn, scope.modified)
+
+ def assertReadWrite(self, qn, scope):
+ self.assertIn(qn, scope.used)
+ self.assertIn(qn, scope.modified)
+
def test_basic(self):
scope = activity.Scope(None)
- self.assertFalse(scope.has(QN('foo')))
+ self.assertMissing(QN('foo'), scope)
scope.mark_read(QN('foo'))
- self.assertFalse(scope.has(QN('foo')))
-
- scope.mark_write(QN('foo'))
- self.assertTrue(scope.has(QN('foo')))
+ self.assertReadOnly(QN('foo'), scope)
- scope.mark_read(QN('bar'))
- self.assertFalse(scope.has(QN('bar')))
+ scope.mark_modified(QN('foo'))
+ self.assertReadWrite(QN('foo'), scope)
def test_copy_from(self):
scope = activity.Scope(None)
- scope.mark_write(QN('foo'))
-
+ scope.mark_modified(QN('foo'))
other = activity.Scope(None)
other.copy_from(scope)
- self.assertTrue(QN('foo') in other.modified)
+ self.assertWriteOnly(QN('foo'), other)
- scope.mark_write(QN('bar'))
+ scope.mark_modified(QN('bar'))
scope.copy_from(other)
- self.assertFalse(QN('bar') in scope.modified)
+ self.assertMissing(QN('bar'), scope)
- scope.mark_write(QN('bar'))
+ scope.mark_modified(QN('bar'))
scope.merge_from(other)
- self.assertTrue(QN('bar') in scope.modified)
- self.assertFalse(QN('bar') in other.modified)
+ self.assertWriteOnly(QN('bar'), scope)
+ self.assertMissing(QN('bar'), other)
def test_copy_of(self):
scope = activity.Scope(None)
scope.mark_read(QN('foo'))
+ other = activity.Scope.copy_of(scope)
- self.assertTrue(QN('foo') in activity.Scope.copy_of(scope).used)
+ self.assertReadOnly(QN('foo'), other)
child_scope = activity.Scope(scope)
child_scope.mark_read(QN('bar'))
+ other = activity.Scope.copy_of(child_scope)
- self.assertTrue(QN('bar') in activity.Scope.copy_of(child_scope).used)
-
- def test_nesting(self):
- scope = activity.Scope(None)
- scope.mark_write(QN('foo'))
- scope.mark_read(QN('bar'))
-
- child = activity.Scope(scope)
- self.assertTrue(child.has(QN('foo')))
- self.assertTrue(scope.has(QN('foo')))
-
- child.mark_write(QN('bar'))
- self.assertTrue(child.has(QN('bar')))
- self.assertFalse(scope.has(QN('bar')))
+ self.assertReadOnly(QN('bar'), other)
def test_referenced(self):
scope = activity.Scope(None)
@@ -123,25 +124,6 @@ class ActivityAnalyzerTest(test.TestCase):
node = activity.resolve(node, entity_info)
return node, entity_info
- def test_local_markers(self):
-
- def test_fn(a): # pylint:disable=unused-argument
- b = c # pylint:disable=undefined-variable
- while b > 0:
- b -= 1
- return b
-
- node, _ = self._parse_and_analyze(test_fn)
- self.assertFalse(
- anno.getanno(node.body[0].body[0].value,
- NodeAnno.IS_LOCAL)) # c in b = c
- self.assertTrue(
- anno.getanno(node.body[0].body[1].test.left,
- NodeAnno.IS_LOCAL)) # b in b > 0
- self.assertTrue(
- anno.getanno(node.body[0].body[2].value,
- NodeAnno.IS_LOCAL)) # b in return b
-
def assertSymbolSetsAre(self, expected, actual, name):
expected = set(expected)
actual = set(str(s) for s in actual)
@@ -153,12 +135,10 @@ class ActivityAnalyzerTest(test.TestCase):
' Extra: %s\n' % (name.upper(), expected, actual,
expected - actual, actual - expected))
- def assertScopeIsRmc(self, scope, used, modified, created):
+ def assertScopeIs(self, scope, used, modified):
"""Assert the scope contains specific used, modified & created variables."""
self.assertSymbolSetsAre(used, scope.used, 'read')
self.assertSymbolSetsAre(modified, scope.modified, 'modified')
- # Created is deprecated, we're no longer verifying it.
- # self.assertSymbolSetsAre(created, scope.created, 'created')
def test_print_statement(self):
@@ -181,7 +161,7 @@ class ActivityAnalyzerTest(test.TestCase):
print_args_scope = anno.getanno(print_node, NodeAnno.ARGS_SCOPE)
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertScopeIsRmc(print_args_scope, ('a', 'b'), (), ())
+ self.assertScopeIs(print_args_scope, ('a', 'b'), ())
def test_call_args(self):
@@ -195,8 +175,8 @@ class ActivityAnalyzerTest(test.TestCase):
call_node = node.body[0].body[2].value
# We basically need to detect which variables are captured by the call
# arguments.
- self.assertScopeIsRmc(
- anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), (), ())
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'b'), ())
def test_call_args_attributes(self):
@@ -210,12 +190,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[1].value
- self.assertScopeIsRmc(
- anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
- ('a', 'a.b', 'a.c'),
- (),
- (),
- )
+ self.assertScopeIs(
+ anno.getanno(call_node, NodeAnno.ARGS_SCOPE), ('a', 'a.b', 'a.c'), ())
def test_call_args_subscripts(self):
@@ -230,12 +206,9 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
call_node = node.body[0].body[2].value
- self.assertScopeIsRmc(
+ self.assertScopeIs(
anno.getanno(call_node, NodeAnno.ARGS_SCOPE),
- ('a', 'a[0]', 'a[b]', 'b'),
- (),
- (),
- )
+ ('a', 'a[0]', 'a[b]', 'b'), ())
def test_while(self):
@@ -248,14 +221,13 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
while_node = node.body[0].body[1]
- self.assertScopeIsRmc(
- anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'),
- ('c',))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
+ self.assertScopeIs(
anno.getanno(while_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
- ('b', 'c'), ('a', 'b', 'c'))
- self.assertScopeIsRmc(
- anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), (), ())
+ ('b', 'c'))
+ self.assertScopeIs(
+ anno.getanno(while_node, NodeAnno.COND_SCOPE), ('b',), ())
def test_for(self):
@@ -268,11 +240,11 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
for_node = node.body[0].body[1]
- self.assertScopeIsRmc(
- anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'), ('c',))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(for_node, NodeAnno.BODY_SCOPE), ('b',), ('b', 'c'))
+ self.assertScopeIs(
anno.getanno(for_node, NodeAnno.BODY_SCOPE).parent, ('a', 'b', 'c'),
- ('b', 'c', '_'), ('a', 'b', 'c', '_'))
+ ('b', 'c', '_'))
def test_if(self):
@@ -289,18 +261,16 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'),
- ('y', 'z'))
- # TODO(mdan): Double check: is it ok to not mark a local symbol as not read?
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'z', 'u'),
- ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('x', 'y', 'z'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('x', 'y', 'z', 'u'),
+ ('x', 'y', 'z', 'u'))
+ self.assertScopeIs(
anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('x', 'y'),
- ('x', 'y', 'u'), ('y', 'u'))
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'),
+ ('x', 'y', 'u'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u'))
def test_if_attributes(self):
@@ -316,24 +286,14 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE),
- ('a', 'a.c'),
- ('a.b', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
- ('a', 'a.c'),
- ('a.b', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent,
- ('a', 'a.c', 'd'),
- ('a.b', 'd'),
- ('a', 'd'),
- )
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'a.c'), ('a.b', 'd'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'a.c'),
+ ('a.b', 'd'))
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE).parent, ('a', 'a.c', 'd'),
+ ('a.b', 'd'))
def test_if_subscripts(self):
@@ -348,25 +308,15 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
if_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.BODY_SCOPE),
- ('a', 'b', 'c', 'a[c]'),
- ('a[b]', 'd'),
- ('d',),
- )
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.BODY_SCOPE), ('a', 'b', 'c', 'a[c]'),
+ ('a[b]', 'd'))
# TODO(mdan): Should subscript writes (a[0] = 1) be considered to read "a"?
- self.assertScopeIsRmc(
- anno.getanno(if_node, NodeAnno.ORELSE_SCOPE),
- ('a', 'e'),
- ('a[0]', 'd'),
- ('d',),
- )
- self.assertScopeIsRmc(
+ self.assertScopeIs(
+ anno.getanno(if_node, NodeAnno.ORELSE_SCOPE), ('a', 'e'), ('a[0]', 'd'))
+ self.assertScopeIs(
anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent,
- ('a', 'b', 'c', 'd', 'e', 'a[c]'),
- ('d', 'a[b]', 'a[0]'),
- ('a', 'b', 'c', 'd', 'e'),
- )
+ ('a', 'b', 'c', 'd', 'e', 'a[c]'), ('d', 'a[b]', 'a[0]'))
def test_nested_if(self):
@@ -380,12 +330,10 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
inner_if_node = node.body[0].body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',),
- ('a',))
- self.assertScopeIsRmc(
- anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',),
- ('a',))
+ self.assertScopeIs(
+ anno.getanno(inner_if_node, NodeAnno.BODY_SCOPE), ('b',), ('a',))
+ self.assertScopeIs(
+ anno.getanno(inner_if_node, NodeAnno.ORELSE_SCOPE), ('b',), ('a',))
def test_nested_function(self):
@@ -404,11 +352,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_def_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), (
- 'x',
- 'y',
- ))
+ self.assertScopeIs(
+ anno.getanno(fn_def_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',))
def test_constructor_attributes(self):
@@ -420,12 +365,9 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(TestClass)
init_node = node.body[0].body[0]
- self.assertScopeIsRmc(
- anno.getanno(init_node, NodeAnno.BODY_SCOPE),
- ('self', 'a', 'self.b'),
- ('self', 'self.b', 'self.b.c'),
- ('self', 'a', 'self.b'),
- )
+ self.assertScopeIs(
+ anno.getanno(init_node, NodeAnno.BODY_SCOPE), ('self', 'a', 'self.b'),
+ ('self', 'self.b', 'self.b.c'))
def test_aug_assign_subscripts(self):
@@ -434,12 +376,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('a', 'a[0]'),
- ('a[0]',),
- ('a',),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'a[0]'), ('a[0]',))
def test_return_vars_are_read(self):
@@ -448,16 +386,7 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('c',),
- (),
- (
- 'a',
- 'b',
- 'c',
- ),
- )
+ self.assertScopeIs(anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('c',), ())
def test_aug_assign(self):
@@ -466,12 +395,8 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('a', 'b'),
- ('a'),
- ('a', 'b'),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('a', 'b'), ('a'))
def test_aug_assign_rvalues(self):
@@ -485,23 +410,22 @@ class ActivityAnalyzerTest(test.TestCase):
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE),
- ('foo', 'x'),
- (),
- ('x',),
- )
+ self.assertScopeIs(
+ anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('foo', 'x'), ())
- def test_params_created(self):
+ def test_params(self):
def test_fn(a, b): # pylint: disable=unused-argument
return b
node, _ = self._parse_and_analyze(test_fn)
fn_node = node.body[0]
- self.assertScopeIsRmc(
- anno.getanno(fn_node, NodeAnno.BODY_SCOPE), ('b',), (('')),
- (('a', 'b')))
+ body_scope = anno.getanno(fn_node, NodeAnno.BODY_SCOPE)
+ self.assertScopeIs(body_scope, ('b',), ())
+ self.assertScopeIs(body_scope.parent, ('b',), ('a', 'b'))
+
+ args_scope = anno.getanno(fn_node.args, anno.Static.SCOPE)
+ self.assertSymbolSetsAre(('a', 'b'), args_scope.params.keys(), 'params')
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 4ceddce53b..dc363f9a47 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -28,7 +28,6 @@ import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import transformer
-from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# TODO(aqj): Do we need this? Do other builtins fail in similar ways
@@ -133,11 +132,9 @@ class LiveValueResolver(transformer.Base):
anno.setanno(node, 'fqn',
anno.getanno(node.value, 'type_fqn') + (node.attr,))
elif isinstance(node.value, gast.Name):
- stem_name = node.value
- # All nonlocal symbols should be fully resolved.
- assert anno.hasanno(stem_name, NodeAnno.IS_LOCAL), stem_name
# TODO(mdan): Figure out what to do when calling attribute on local object
# Maybe just leave as-is?
+ pass
return node