diff options
author | 2017-04-05 15:39:39 -0800 | |
---|---|---|
committer | 2017-04-05 16:58:03 -0700 | |
commit | 046b7b0056340ddd2fff6f7d4552a4d942b9c87b (patch) | |
tree | 369b0803532c8d08189dedaa497801094616e2c3 | |
parent | db47f97ca884e0e368e4124cc6971346d1d41345 (diff) |
Fix the problem that no enough placeholders for persistent tensor
batch delete
The deleter_key is always a device_name, hence there is only one
of it. Hence, we cannot delete >1 handles at one time.
In the fix, it creates delete placeholder on demand, the max
number of placeholders is _DEAD_HANDLES_THRESHOLD.
Change: 152322770
-rw-r--r-- | tensorflow/python/client/session.py | 3 | ||||
-rw-r--r-- | tensorflow/python/ops/session_ops.py | 14 |
2 files changed, 6 insertions, 11 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 6900ac9a4f..5b50df3ed3 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -1094,8 +1094,9 @@ class BaseSession(SessionInterface): if tensors_to_delete: feeds = {} fetches = [] - for tensor_handle in tensors_to_delete: + for deleter_key, tensor_handle in enumerate(tensors_to_delete): holder, deleter = session_ops._get_handle_deleter(self.graph, + deleter_key, tensor_handle) feeds[holder] = tensor_handle fetches.append(deleter) diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py index 0a06982ad7..3d038cfd8a 100644 --- a/tensorflow/python/ops/session_ops.py +++ b/tensorflow/python/ops/session_ops.py @@ -116,7 +116,7 @@ class TensorHandle(object): raise TypeError("Persistent tensor %s may have already been deleted." % self.handle) self._auto_gc_enabled = False - holder, deleter = _get_handle_deleter(self._session.graph, self._handle) + holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle) self._session.run(deleter, feed_dict={holder: self.handle}) def get_raw_handle(self): @@ -142,11 +142,6 @@ class TensorHandle(object): return handle_parts[0] + ";" + handle_parts[-1] @staticmethod - def _get_deleter_key(handle): - """The graph key for deleter.""" - return str(handle).split(";")[-1] - - @staticmethod def _get_mover_key(feeder, handle): """The graph key for mover.""" return feeder.op.name + ";" + TensorHandle._get_reader_key(handle) @@ -302,10 +297,9 @@ def _get_handle_mover(graph, feeder, handle): return result -def _get_handle_deleter(graph, handle): +def _get_handle_deleter(graph, deleter_key, handle): """Return a deletion subgraph for this handle.""" - graph_key = TensorHandle._get_deleter_key(handle) - result = graph._handle_deleters.get(graph_key) + result = graph._handle_deleters.get(deleter_key) if result is None: # Create deleter if we haven't done it. handle_device = TensorHandle._get_device_name(handle) @@ -313,5 +307,5 @@ def _get_handle_deleter(graph, handle): holder = array_ops.placeholder(dtypes.string) deleter = gen_data_flow_ops._delete_session_tensor(holder) result = (holder, deleter) - graph._handle_deleters[graph_key] = result + graph._handle_deleters[deleter_key] = result return result |