aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-21 07:19:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-21 07:23:15 -0700
commit035a84769de2921667677b5530011bbd558ddf0c (patch)
tree3ff81e7ff0a6e50bd5b04fa486d173e0231ade54 /tensorflow/python/autograph
parent200b89761a4665e3de6d0efc4e3e10ab287ad81b (diff)
Use weakrefs where absolutely safe to do so, in order to reduce the number of circular references. Replace unnecessary OrderedDict with a regular dict.
PiperOrigin-RevId: 213982097
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/pyct/cfg.py13
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/activity.py6
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py3
3 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index 1433f9ac83..fca0eb62e4 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -27,6 +27,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import weakref
from enum import Enum
# pylint:disable=g-bad-import-order
@@ -61,7 +62,10 @@ class Node(object):
def freeze(self):
self.next = frozenset(self.next)
- self.prev = frozenset(self.prev)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ self.prev = weakref.WeakSet(self.prev)
def __repr__(self):
if isinstance(self.ast_node, gast.FunctionDef):
@@ -256,7 +260,7 @@ class GraphBuilder(object):
"""Resets the state of this factory."""
self.head = None
self.errors = set()
- self.node_index = collections.OrderedDict()
+ self.node_index = {}
# TODO(mdan): Too many primitives. Use classes.
self.leaves = set()
@@ -309,7 +313,10 @@ class GraphBuilder(object):
"""Grows the graph by adding a CFG node following the current leaves."""
if ast_node is self.node_index:
raise ValueError('%s added twice' % ast_node)
- node = Node(next_=set(), prev=set(), ast_node=ast_node)
+ # Assumption: All CFG nodes have identical life spans, because the graph
+ # owns them. Nodes should never be used outside the context of an existing
+ # graph.
+ node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node)
self.node_index[ast_node] = node
self.owners[node] = frozenset(self.active_stmts)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 9cb5991322..086eda7574 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
import copy
+import weakref
import gast
@@ -126,7 +127,10 @@ class Scope(object):
self.parent.mark_read(name)
def mark_param(self, name, owner):
- self.params[name] = owner
+ # Assumption: all AST nodes have the same life span. This lets us use
+ # a weak reference to mark the connection between a symbol node and the
+ # function node whose argument that symbol is.
+ self.params[name] = weakref.ref(owner)
def mark_creation(self, name, writes_create_symbol=False):
"""Mark a qualified name as created."""
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 3963772dad..36b9e7074d 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -89,7 +89,8 @@ class LiveValueResolver(transformer.Base):
if has_single_def:
def_, = defs
- if def_.param_of is self.enclosing_entities[0]:
+ # Note: param_of is a weakref.
+ if def_.param_of and def_.param_of() is self.enclosing_entities[0]:
if node.id in self.entity_info.arg_values:
obj = self.entity_info.arg_values[node.id]
anno.setanno(node, 'live_val', obj)