From 046b7b0056340ddd2fff6f7d4552a4d942b9c87b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 5 Apr 2017 15:39:39 -0800 Subject: Fix the problem that no enough placeholders for persistent tensor batch delete The deleter_key is always a device_name, hence there is only one of it. Hence, we cannot delete >1 handles at one time. In the fix, it creates delete placeholder on demand, the max number of placeholders is _DEAD_HANDLES_THRESHOLD. Change: 152322770 --- tensorflow/python/client/session.py | 3 ++- tensorflow/python/ops/session_ops.py | 14 ++++---------- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 6900ac9a4f..5b50df3ed3 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1094,8 +1094,9 @@ class BaseSession(SessionInterface): if tensors_to_delete: feeds = {} fetches = [] - for tensor_handle in tensors_to_delete: + for deleter_key, tensor_handle in enumerate(tensors_to_delete): holder, deleter = session_ops._get_handle_deleter(self.graph, + deleter_key, tensor_handle) feeds[holder] = tensor_handle fetches.append(deleter) diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index 0a06982ad7..3d038cfd8a 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -116,7 +116,7 @@ class TensorHandle(object): raise TypeError("Persistent tensor %s may have already been deleted." % self.handle) self._auto_gc_enabled = False - holder, deleter = _get_handle_deleter(self._session.graph, self._handle) + holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle) self._session.run(deleter, feed_dict={holder: self.handle}) def get_raw_handle(self): @@ -141,11 +141,6 @@ class TensorHandle(object): handle_parts = str(handle).split(";") return handle_parts[0] + ";" + handle_parts[-1] - @staticmethod - def _get_deleter_key(handle): - """The graph key for deleter.""" - return str(handle).split(";")[-1] - @staticmethod def _get_mover_key(feeder, handle): """The graph key for mover.""" @@ -302,10 +297,9 @@ def _get_handle_mover(graph, feeder, handle): return result -def _get_handle_deleter(graph, handle): +def _get_handle_deleter(graph, deleter_key, handle): """Return a deletion subgraph for this handle.""" - graph_key = TensorHandle._get_deleter_key(handle) - result = graph._handle_deleters.get(graph_key) + result = graph._handle_deleters.get(deleter_key) if result is None: # Create deleter if we haven't done it. handle_device = TensorHandle._get_device_name(handle) @@ -313,5 +307,5 @@ def _get_handle_deleter(graph, handle): holder = array_ops.placeholder(dtypes.string) deleter = gen_data_flow_ops._delete_session_tensor(holder) result = (holder, deleter) - graph._handle_deleters[graph_key] = result + graph._handle_deleters[deleter_key] = result return result -- cgit v1.2.3