From 035a84769de2921667677b5530011bbd558ddf0c Mon Sep 17 00:00:00 2001 From: Dan Moldovan Date: Fri, 21 Sep 2018 07:19:09 -0700 Subject: Use weakrefs where absolutely safe to do so, in order to reduce the number of circular references. Replace unnecessary OrderedDict with a regular dict. PiperOrigin-RevId: 213982097 --- tensorflow/python/autograph/pyct/cfg.py | 13 ++++++++++--- .../python/autograph/pyct/static_analysis/activity.py | 6 +++++- .../python/autograph/pyct/static_analysis/live_values.py | 3 ++- 3 files changed, 17 insertions(+), 5 deletions(-) (limited to 'tensorflow/python/autograph') diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py index 1433f9ac83..fca0eb62e4 100644 --- a/tensorflow/python/autograph/pyct/cfg.py +++ b/tensorflow/python/autograph/pyct/cfg.py @@ -27,6 +27,7 @@ from __future__ import division from __future__ import print_function import collections +import weakref from enum import Enum # pylint:disable=g-bad-import-order @@ -61,7 +62,10 @@ class Node(object): def freeze(self): self.next = frozenset(self.next) - self.prev = frozenset(self.prev) + # Assumption: All CFG nodes have identical life spans, because the graph + # owns them. Nodes should never be used outside the context of an existing + # graph. + self.prev = weakref.WeakSet(self.prev) def __repr__(self): if isinstance(self.ast_node, gast.FunctionDef): @@ -256,7 +260,7 @@ class GraphBuilder(object): """Resets the state of this factory.""" self.head = None self.errors = set() - self.node_index = collections.OrderedDict() + self.node_index = {} # TODO(mdan): Too many primitives. Use classes. self.leaves = set() @@ -309,7 +313,10 @@ class GraphBuilder(object): """Grows the graph by adding a CFG node following the current leaves.""" if ast_node is self.node_index: raise ValueError('%s added twice' % ast_node) - node = Node(next_=set(), prev=set(), ast_node=ast_node) + # Assumption: All CFG nodes have identical life spans, because the graph + # owns them. Nodes should never be used outside the context of an existing + # graph. + node = Node(next_=set(), prev=weakref.WeakSet(), ast_node=ast_node) self.node_index[ast_node] = node self.owners[node] = frozenset(self.active_stmts) diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py index 9cb5991322..086eda7574 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/activity.py +++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function import copy +import weakref import gast @@ -126,7 +127,10 @@ class Scope(object): self.parent.mark_read(name) def mark_param(self, name, owner): - self.params[name] = owner + # Assumption: all AST nodes have the same life span. This lets us use + # a weak reference to mark the connection between a symbol node and the + # function node whose argument that symbol is. + self.params[name] = weakref.ref(owner) def mark_creation(self, name, writes_create_symbol=False): """Mark a qualified name as created.""" diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py index 3963772dad..36b9e7074d 100644 --- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py +++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py @@ -89,7 +89,8 @@ class LiveValueResolver(transformer.Base): if has_single_def: def_, = defs - if def_.param_of is self.enclosing_entities[0]: + # Note: param_of is a weakref. + if def_.param_of and def_.param_of() is self.enclosing_entities[0]: if node.id in self.entity_info.arg_values: obj = self.entity_info.arg_values[node.id] anno.setanno(node, 'live_val', obj) -- cgit v1.2.3