aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/pyct/qual_names.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/pyct/qual_names.py')
-rw-r--r--tensorflow/contrib/autograph/pyct/qual_names.py28
1 files changed, 26 insertions, 2 deletions
diff --git a/tensorflow/contrib/autograph/pyct/qual_names.py b/tensorflow/contrib/autograph/pyct/qual_names.py
index da07013cf4..fb81404edc 100644
--- a/tensorflow/contrib/autograph/pyct/qual_names.py
+++ b/tensorflow/contrib/autograph/pyct/qual_names.py
@@ -30,6 +30,7 @@ import collections
import gast
from tensorflow.contrib.autograph.pyct import anno
+from tensorflow.contrib.autograph.pyct import parser
class Symbol(collections.namedtuple('Symbol', ['name'])):
@@ -89,7 +90,8 @@ class QN(object):
if not isinstance(base, (str, StringLiteral, NumberLiteral)):
# TODO(mdan): Require Symbol instead of string.
raise ValueError(
- 'For simple QNs, base must be a string or a Literal object.')
+ 'for simple QNs, base must be a string or a Literal object;'
+ ' got instead "%s"' % type(base))
assert '.' not in base and '[' not in base and ']' not in base
self._parent = None
self.qn = (base,)
@@ -113,6 +115,22 @@ class QN(object):
return self._parent
@property
+ def owner_set(self):
+ """Returns all the symbols (simple or composite) that own this QN.
+
+ In other words, if this symbol was modified, the symbols in the owner set
+ may also be affected.
+
+ Examples:
+ 'a.b[c.d]' has two owners, 'a' and 'a.b'
+ """
+ owners = set()
+ if self.has_attr() or self.has_subscript():
+ owners.add(self.parent)
+ owners.update(self.parent.owner_set)
+ return owners
+
+ @property
def support_set(self):
"""Returns the set of simple symbols that this QN relies on.
@@ -122,7 +140,7 @@ class QN(object):
Examples:
'a.b' has only one support symbol, 'a'
- 'a[i]' has two roots, 'a' and 'i'
+ 'a[i]' has two support symbols, 'a' and 'i'
"""
# TODO(mdan): This might be the set of Name nodes in the AST. Track those?
roots = set()
@@ -231,3 +249,9 @@ class QnResolver(gast.NodeTransformer):
def resolve(node):
return QnResolver().visit(node)
+
+
+def from_str(qn_str):
+ node = parser.parse_expression(qn_str)
+ node = resolve(node)
+ return anno.getanno(node, anno.Basic.QN)