diff options
Diffstat (limited to 'tensorflow/python/ops/script_ops.py')
-rw-r--r-- | tensorflow/python/ops/script_ops.py | 35 |
1 files changed, 12 insertions, 23 deletions
diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index f8676ccb5f..219562de5d 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -23,6 +23,7 @@ import threading # Used by py_util.cc to get tracebacks. import traceback # pylint: disable=unused-import +import weakref import numpy as np import six @@ -129,11 +130,14 @@ class FuncRegistry(object): def __init__(self): self._lock = threading.Lock() self._unique_id = 0 # GUARDED_BY(self._lock) - self._funcs = {} + # Only store weakrefs to the funtions. The strong reference is stored in + # the graph. + self._funcs = weakref.WeakValueDictionary() def insert(self, func): """Registers `func` and returns a unique token for this entry.""" token = self._next_unique_token() + # Store a weakref to the function self._funcs[token] = func return token @@ -186,7 +190,7 @@ class FuncRegistry(object): Raises: ValueError: if no function is registered for `token`. """ - func = self._funcs[token] + func = self._funcs.get(token, None) if func is None: raise ValueError("callback %s is not found" % token) if isinstance(func, EagerFunc): @@ -228,19 +232,6 @@ _py_funcs = FuncRegistry() pywrap_tensorflow.InitializePyTrampoline(_py_funcs) -class CleanupFunc(object): - """A helper class to remove a registered function from _py_funcs.""" - - def __init__(self, token): - self._token = token - - def __del__(self): - if _py_funcs is not None: - # If _py_funcs is None, the program is most likely in shutdown, and the - # _py_funcs object has been destroyed already. - _py_funcs.remove(self._token) - - def _internal_py_func(func, inp, Tout, @@ -270,17 +261,15 @@ def _internal_py_func(func, # bound to that of the outer graph instead. graph = graph._outer_graph - cleanup = CleanupFunc(token) - # TODO(zhifengc): Consider adding a Graph method to collect # `cleanup` objects in one of its member. - if not hasattr(graph, "_cleanup_py_funcs_used_in_graph"): - graph._cleanup_py_funcs_used_in_graph = [] + if not hasattr(graph, "_py_funcs_used_in_graph"): + graph._py_funcs_used_in_graph = [] - # When `graph` is destroyed, elements in _cleanup_py_funcs_used_in_graph - # will be destroyed and their __del__ will remove the 'token' from - # the funcs registry. - graph._cleanup_py_funcs_used_in_graph.append(cleanup) + # Store a reference to the function in the graph to ensure it stays alive + # as long as the graph lives. When the graph is destroyed, the function + # is left to the garbage collector for destruction as well. + graph._py_funcs_used_in_graph.append(func) # pylint: enable=protected-access if eager: |