aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/script_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/script_ops.py')
-rw-r--r--tensorflow/python/ops/script_ops.py35
1 files changed, 12 insertions, 23 deletions
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py
index f8676ccb5f..219562de5d 100644
--- a/tensorflow/python/ops/script_ops.py
+++ b/tensorflow/python/ops/script_ops.py
@@ -23,6 +23,7 @@ import threading
# Used by py_util.cc to get tracebacks.
import traceback # pylint: disable=unused-import
+import weakref
import numpy as np
import six
@@ -129,11 +130,14 @@ class FuncRegistry(object):
def __init__(self):
self._lock = threading.Lock()
self._unique_id = 0 # GUARDED_BY(self._lock)
- self._funcs = {}
+ # Only store weakrefs to the funtions. The strong reference is stored in
+ # the graph.
+ self._funcs = weakref.WeakValueDictionary()
def insert(self, func):
"""Registers `func` and returns a unique token for this entry."""
token = self._next_unique_token()
+ # Store a weakref to the function
self._funcs[token] = func
return token
@@ -186,7 +190,7 @@ class FuncRegistry(object):
Raises:
ValueError: if no function is registered for `token`.
"""
- func = self._funcs[token]
+ func = self._funcs.get(token, None)
if func is None:
raise ValueError("callback %s is not found" % token)
if isinstance(func, EagerFunc):
@@ -228,19 +232,6 @@ _py_funcs = FuncRegistry()
pywrap_tensorflow.InitializePyTrampoline(_py_funcs)
-class CleanupFunc(object):
- """A helper class to remove a registered function from _py_funcs."""
-
- def __init__(self, token):
- self._token = token
-
- def __del__(self):
- if _py_funcs is not None:
- # If _py_funcs is None, the program is most likely in shutdown, and the
- # _py_funcs object has been destroyed already.
- _py_funcs.remove(self._token)
-
-
def _internal_py_func(func,
inp,
Tout,
@@ -270,17 +261,15 @@ def _internal_py_func(func,
# bound to that of the outer graph instead.
graph = graph._outer_graph
- cleanup = CleanupFunc(token)
-
# TODO(zhifengc): Consider adding a Graph method to collect
# `cleanup` objects in one of its member.
- if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"):
- graph._cleanup_py_funcs_used_in_graph = []
+ if not hasattr(graph, "_py_funcs_used_in_graph"):
+ graph._py_funcs_used_in_graph = []
- # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph
- # will be destroyed and their __del__ will remove the 'token' from
- # the funcs registry.
- graph._cleanup_py_funcs_used_in_graph.append(cleanup)
+ # Store a reference to the function in the graph to ensure it stays alive
+ # as long as the graph lives. When the graph is destroyed, the function
+ # is left to the garbage collector for destruction as well.
+ graph._py_funcs_used_in_graph.append(func)
# pylint: enable=protected-access
if eager: