aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-05 15:39:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-05 16:58:03 -0700
commit046b7b0056340ddd2fff6f7d4552a4d942b9c87b (patch)
tree369b0803532c8d08189dedaa497801094616e2c3
parentdb47f97ca884e0e368e4124cc6971346d1d41345 (diff)
Fix the problem that no enough placeholders for persistent tensor
batch delete The deleter_key is always a device_name, hence there is only one of it. Hence, we cannot delete >1 handles at one time. In the fix, it creates delete placeholder on demand, the max number of placeholders is _DEAD_HANDLES_THRESHOLD. Change: 152322770
-rw-r--r--tensorflow/python/client/session.py3
-rw-r--r--tensorflow/python/ops/session_ops.py14
2 files changed, 6 insertions, 11 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 6900ac9a4f..5b50df3ed3 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -1094,8 +1094,9 @@ class BaseSession(SessionInterface):
if tensors_to_delete:
feeds = {}
fetches = []
- for tensor_handle in tensors_to_delete:
+ for deleter_key, tensor_handle in enumerate(tensors_to_delete):
holder, deleter = session_ops._get_handle_deleter(self.graph,
+ deleter_key,
tensor_handle)
feeds[holder] = tensor_handle
fetches.append(deleter)
diff --git a/tensorflow/python/ops/session_ops.py b/tensorflow/python/ops/session_ops.py
index 0a06982ad7..3d038cfd8a 100644
--- a/tensorflow/python/ops/session_ops.py
+++ b/tensorflow/python/ops/session_ops.py
@@ -116,7 +116,7 @@ class TensorHandle(object):
raise TypeError("Persistent tensor %s may have already been deleted."
% self.handle)
self._auto_gc_enabled = False
- holder, deleter = _get_handle_deleter(self._session.graph, self._handle)
+ holder, deleter = _get_handle_deleter(self._session.graph, 0, self._handle)
self._session.run(deleter, feed_dict={holder: self.handle})
def get_raw_handle(self):
@@ -142,11 +142,6 @@ class TensorHandle(object):
return handle_parts[0] + ";" + handle_parts[-1]
@staticmethod
- def _get_deleter_key(handle):
- """The graph key for deleter."""
- return str(handle).split(";")[-1]
-
- @staticmethod
def _get_mover_key(feeder, handle):
"""The graph key for mover."""
return feeder.op.name + ";" + TensorHandle._get_reader_key(handle)
@@ -302,10 +297,9 @@ def _get_handle_mover(graph, feeder, handle):
return result
-def _get_handle_deleter(graph, handle):
+def _get_handle_deleter(graph, deleter_key, handle):
"""Return a deletion subgraph for this handle."""
- graph_key = TensorHandle._get_deleter_key(handle)
- result = graph._handle_deleters.get(graph_key)
+ result = graph._handle_deleters.get(deleter_key)
if result is None:
# Create deleter if we haven't done it.
handle_device = TensorHandle._get_device_name(handle)
@@ -313,5 +307,5 @@ def _get_handle_deleter(graph, handle):
holder = array_ops.placeholder(dtypes.string)
deleter = gen_data_flow_ops._delete_session_tensor(holder)
result = (holder, deleter)
- graph._handle_deleters[graph_key] = result
+ graph._handle_deleters[deleter_key] = result
return result