diff options
author | Alexandre Passos <apassos@google.com> | 2017-09-14 19:14:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-14 19:19:08 -0700 |
commit | b51cd8cdce89cb93f3502d91cc61bc493422e215 (patch) | |
tree | c2e6dd8101809362151e2e309b6f4e53dbc95dad /tensorflow/python/eager/tape.py | |
parent | 39841e6c4fb2861dff7da486386fb3e3e3f8020d (diff) |
Eager gradient tape doesn't keep tensors alive.
PiperOrigin-RevId: 168782341
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 74 |
1 files changed, 51 insertions, 23 deletions
```diff
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py
index e33c52a1b2..899325cb20 100644
--- a/tensorflow/python/eager/tape.py
+++ b/tensorflow/python/eager/tape.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
+import contextlib
 import threading
 
 from tensorflow.python.util import tf_contextlib
@@ -30,7 +31,8 @@ def tid(tensor):
 
 class TapeEntry(
     collections.namedtuple("TapeEntry", [
-        "output_ids", "inputs", "side_outputs", "backward_function"
+        "output_ids", "input_ids", "side_outputs", "backward_function",
+        "output_shape_and_dtype",
     ])):
   """Entry in the gradient tape.
 
@@ -39,11 +41,13 @@ class TapeEntry(
 
   Args:
    output_ids: tensor_id(t) for each output tensor T
-   inputs: input tensors
-   side_outputs: optional tensors which need to be provided to the backward
-    function.
+   input_ids: tensor_id(t) for each input tensor T
+   side_outputs: optional tensors (not IDs) which need to be provided to the
+    backward function.
    backward_function: function to be called with the downstream gradients and
     side outputs as arguments which computes the backward pass.
+   output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
+    tensor_id
   """
 
 
@@ -57,8 +61,9 @@ class Tape(object):
   def __init__(self):
     # _tensor_tape maps from tensor IDs to their operation IDs
     self._tensor_tape = {}
-    # maps output tensor IDs to their shapes and dtypes
-    self._shape_dtype = {}
+    # maps from tensor ID to usage count. Triggers garbage collection when this
+    # goes to zero.
+    self._tensor_usage = {}
     # maps from operation ID to TapeEntry
     self._op_tape = {}
     # next operation ID
@@ -81,8 +86,10 @@ class Tape(object):
 
   def watch(self, tensor):
     """Adds a tensor to the tape."""
-    if tid(tensor) not in self._tensor_tape:
-      self._tensor_tape[tid(tensor)] = None
+    i = tid(tensor)
+    if i not in self._tensor_tape:
+      self._tensor_tape[i] = None
+      self._tensor_usage[i] = 1
       self._watched.append(tensor)
 
   def watch_variable(self, v):
@@ -95,25 +102,38 @@ class Tape(object):
     if not self.should_record(input_tensors):
       return output_tensors
     for t in output_tensors:
-      self._tensor_tape[tid(t)] = self._next_op_id
-      self._shape_dtype[tid(t)] = (_tensor_shape(t), t.dtype)
-
+      i = tid(t)
+      self._tensor_tape[i] = self._next_op_id
+      self._tensor_usage[i] = 1
+    for t in input_tensors:
+      i = tid(t)
+      self._tensor_usage[i] = self._tensor_usage.get(i, 0) + 1
     self._op_tape[self._next_op_id] = TapeEntry(
         [tid(t) for t in output_tensors],
-        input_tensors,
+        [tid(t) for t in input_tensors],
         side_outputs,
-        backward_function)
+        backward_function,
+        [(_tensor_shape(t), t.dtype) for t in output_tensors])
     self._next_op_id += 1
 
+  def _delete_tensor_id(self, i):
+    if i in self._tensor_usage:
+      self._tensor_usage[i] -= 1
+      if self._tensor_usage[i] == 0:
+        del self._tensor_usage[i]
+        op_id = self._tensor_tape.pop(i)
+        op = self._op_tape[op_id]
+        if not any(tensor_id in self._tensor_usage
+                   for tensor_id in op.output_ids):
+          del self._op_tape[op_id]
+          for tensor_id in op.input_ids:
+            # TODO(apassos) this recursion might come to bite us. Consider
+            # adding an explicit stack if this ever gets out of hand
+            self._delete_tensor_id(tensor_id)
+
   def delete_trace(self, tensor):
     """Deletes any trace we have for this tensor."""
-    if tid(tensor) in self._tensor_tape:
-      op = self._tensor_tape[tid(tensor)]
-      del self._tensor_tape[tid(tensor)]
-      if op in self._op_tape:
-        if not any(
-            x in self._tensor_tape for x in self._op_tape[op].output_ids):
-          del self._op_tape[op]
+    self._delete_tensor_id(tid(tensor))
 
   def export(self):
     """Exports the internal state of this tape.
@@ -122,10 +142,8 @@ class Tape(object):
       tensor_tape: a map from tensor_id(tensor) to <identifier for op>
        responsible for generating that tensor.
       op_tape: a map from <identifier for op> to TapeEntry for that op.
-      output_to_shape_dtype: a map from tensor_id(tensor) to its shape and
-       dtype, for tensors which are outputs
     """
-    return self._tensor_tape, self._op_tape, self._shape_dtype
+    return self._tensor_tape, self._op_tape
 
 
 class _TapeStack(threading.local):
@@ -188,6 +206,16 @@ def pop_tape():
   return None
 
 
+@contextlib.contextmanager
+def stop_recording():
+  old = _tape_stack.stack
+  _tape_stack._stack = []  # pylint: disable=protected-access
+  try:
+    yield
+  finally:
+    _tape_stack._stack = old  # pylint: disable=protected-access
+
+
 def should_record(tensors):
   """Returns true if any tape in the stach watches any of these tensors."""
   if not _tape_stack.stack:
```