author    Alexandre Passos <apassos@google.com>    2017-09-14 19:14:19 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-09-14 19:19:08 -0700
commit b51cd8cdce89cb93f3502d91cc61bc493422e215
tree   c2e6dd8101809362151e2e309b6f4e53dbc95dad /tensorflow/python/eager/tape.py
parent 39841e6c4fb2861dff7da486386fb3e3e3f8020d
Eager gradient tape doesn't keep tensors alive.
PiperOrigin-RevId: 168782341
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- tensorflow/python/eager/tape.py | 74
1 file changed, 51 insertions(+), 23 deletions(-)
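
The core idea of the change: the tape previously stored tensor objects (inputs, and a shape/dtype map keyed by output tensor), which kept those tensors alive for the tape's lifetime; after this commit it stores only integer tensor IDs plus (shape, dtype) metadata. An illustrative sketch of the before/after, not code from this commit (all names here are hypothetical):

class TapeHoldingTensors(object):
    def __init__(self):
        self._entries = []

    def record(self, inputs, outputs):
        # Every recorded tensor stays alive as long as the tape does.
        self._entries.append((inputs, outputs))


class TapeHoldingIds(object):
    def __init__(self):
        self._entries = []

    def record(self, inputs, outputs):
        # Only IDs and metadata are kept; the tensors themselves can be
        # garbage-collected once user code drops its references.
        self._entries.append(
            ([id(t) for t in inputs],
             [(id(t), getattr(t, "shape", None), getattr(t, "dtype", None))
              for t in outputs]))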
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index e33c52a1b2..899325cb20 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import threading
from tensorflow.python.util import tf_contextlib
@@ -30,7 +31,8 @@ def tid(tensor):
class TapeEntry(
collections.namedtuple("TapeEntry", [
- "output_ids", "inputs", "side_outputs", "backward_function"
+ "output_ids", "input_ids", "side_outputs", "backward_function",
+ "output_shape_and_dtype",
])):
"""Entry in the gradient tape.
@@ -39,11 +41,13 @@ class TapeEntry(
Args:
output_ids: tensor_id(t) for each output tensor T
- inputs: input tensors
- side_outputs: optional tensors which need to be provided to the backward
- function.
+ input_ids: tensor_id(t) for each input tensor T
+ side_outputs: optional tensors (not IDs) which need to be provided to the
+ backward function.
backward_function: function to be called with the downstream gradients and
side outputs as arguments which computes the backward pass.
+ output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
+ tensor_id
"""
@@ -57,8 +61,9 @@ class Tape(object):
def __init__(self):
# _tensor_tape maps from tensor IDs to their operation IDs
self._tensor_tape = {}
- # maps output tensor IDs to their shapes and dtypes
- self._shape_dtype = {}
+ # maps from tensor ID to usage count. Triggers garbage collection when this
+ # goes to zero.
+ self._tensor_usage = {}
# maps from operation ID to TapeEntry
self._op_tape = {}
# next operation ID
@@ -81,8 +86,10 @@ class Tape(object):
def watch(self, tensor):
"""Adds a tensor to the tape."""
- if tid(tensor) not in self._tensor_tape:
- self._tensor_tape[tid(tensor)] = None
+ i = tid(tensor)
+ if i not in self._tensor_tape:
+ self._tensor_tape[i] = None
+ self._tensor_usage[i] = 1
self._watched.append(tensor)
def watch_variable(self, v):
@@ -95,25 +102,38 @@ class Tape(object):
if not self.should_record(input_tensors):
return output_tensors
for t in output_tensors:
- self._tensor_tape[tid(t)] = self._next_op_id
- self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype)
-
+ i = tid(t)
+ self._tensor_tape[i] = self._next_op_id
+ self._tensor_usage[i] = 1
+ for t in input_tensors:
+ i = tid(t)
+ self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
self._op_tape[self._next_op_id] = TapeEntry(
[tid(t) for t in output_tensors],
- input_tensors,
+ [tid(t) for t in input_tensors],
side_outputs,
- backward_function)
+ backward_function,
+ [(_tensor_shape(t), t.dtype) for t in output_tensors])
self._next_op_id += 1
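
A minimal toy model of the usage-count bookkeeping that watch() and record_operation() now do (assumed semantics, mirroring the diff; not code from the module): each output starts at count 1 for its own liveness, and each input gains one count per op that consumes it.

usage = {}

def watch(i):
    # A watched tensor starts with one use: its own liveness.
    usage[i] = 1

def record(input_ids, output_ids):
    # Mirrors record_operation: outputs start at 1, each input gains
    # one use per op that consumes it.
    for i in output_ids:
        usage[i] = 1
    for i in input_ids:
        usage[i] = usage.get(i, 0) + 1

watch(1)           # x: {1: 1}
record([1], [2])   # y = f(x): {1: 2, 2: 1}
assert usage == {1: 2, 2: 1}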
+ def _delete_tensor_id(self, i):
+ if i in self._tensor_usage:
+ self._tensor_usage[i] -= 1
+ if self._tensor_usage[i] == 0:
+ del self._tensor_usage[i]
+ op_id = self._tensor_tape.pop(i)
+ op = self._op_tape[op_id]
+ if not any(tensor_id in self._tensor_usage
+ for tensor_id in op.output_ids):
+ del self._op_tape[op_id]
+ for tensor_id in op.input_ids:
+ # TODO(apassos) this recursion might come to bite us. Consider
+ # adding an explicit stack if this ever gets out of hand
+ self._delete_tensor_id(tensor_id)
+
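
The TODO above suggests an explicit stack if the recursion gets out of hand. A hypothetical iterative rewrite under the same semantics might look like the following; the `op_id is None` guard for directly watched tensors is an added assumption, not part of the commit:

def _delete_tensor_id(self, i):
    # Hypothetical iterative variant of the recursive cleanup above,
    # using the explicit stack the TODO suggests.
    stack = [i]
    while stack:
        i = stack.pop()
        if i not in self._tensor_usage:
            continue
        self._tensor_usage[i] -= 1
        if self._tensor_usage[i] != 0:
            continue
        del self._tensor_usage[i]
        op_id = self._tensor_tape.pop(i)
        if op_id is None:
            continue  # tensor was watch()ed directly; no producing op
        op = self._op_tape[op_id]
        if not any(t in self._tensor_usage for t in op.output_ids):
            del self._op_tape[op_id]
            # Defer the inputs instead of recursing on them.
            stack.extend(op.input_ids)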
def delete_trace(self, tensor):
"""Deletes any trace we have for this tensor."""
- if tid(tensor) in self._tensor_tape:
- op = self._tensor_tape[tid(tensor)]
- del self._tensor_tape[tid(tensor)]
- if op in self._op_tape:
- if not any(
- x in self._tensor_tape for x in self._op_tape[op].output_ids):
- del self._op_tape[op]
+ self._delete_tensor_id(tid(tensor))
def export(self):
"""Exports the internal state of this tape.
@@ -122,10 +142,8 @@ class Tape(object):
tensor_tape: a map from tensor_id(tensor) to <identifier for op>
responsible for generating that tensor.
op_tape: a map from <identifier for op> to TapeEntry for that op.
- output_to_shape_dtype: a map from tensor_id(tensor) to its shape and
- dtype, for tensors which are outputs
"""
- return self._tensor_tape, self._op_tape, self._shape_dtype
+ return self._tensor_tape, self._op_tape
class _TapeStack(threading.local):
@@ -188,6 +206,16 @@ def pop_tape():
return None
+@contextlib.contextmanager
+def stop_recording():
+ old = _tape_stack.stack
+ _tape_stack._stack = [] # pylint: disable=protected-access
+ try:
+ yield
+ finally:
+ _tape_stack._stack = old # pylint: disable=protected-access
+
+
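A possible usage sketch for the new stop_recording() context manager, assuming this module's push_new_tape() and pop_tape() helpers:

push_new_tape()
with stop_recording():
    # The tape stack is empty here, so should_record(...) is False for
    # any tensors and nothing is added to the tape.
    assert not _tape_stack.stack
assert _tape_stack.stack  # the previously pushed tape is active again
pop_tape()
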
def should_record(tensors):
"""Returns true if any tape in the stach watches any of these tensors."""
if not _tape_stack.stack: