aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager/tape.py
blob: c16aa8c2f7eb48002acd354b20f8ca06febcc6f7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gradient tape utilites."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib
import threading

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.util import compat


def tid(tensor):
  """Returns `tensor._id`, the key under which tapes track `tensor`."""
  tensor_id = tensor._id  # pylint: disable=protected-access
  return tensor_id


class TapeEntry(
    collections.namedtuple("TapeEntry", [
        "op_type",
        "output_ids", "input_ids", "backward_function",
        "output_shape_and_dtype",
    ])):
  """Entry in the gradient tape.

  Represents the execution of one op or function, with instructions for doing
  its backward pass and useful information for it.

  Args:
    op_type: identifier of the op (or function) type that was executed;
      presumably a string/bytes name -- `Tape.record_operation` passes it
      through `compat.as_bytes` before handing it to the native tape.
    output_ids: `tid(t)` for each output tensor `t`.
    input_ids: `tid(t)` for each input tensor `t`.
    backward_function: function to be called with the downstream gradients and
      side outputs as arguments which computes the backward pass.
    output_shape_and_dtype: a list of (shape_tuple, dtype) for every output
      tensor id.
  """


def _tensor_shape(t):
  """Returns whatever `t._shape_tuple()` reports as the shape of `t`."""
  shape_tuple_fn = t._shape_tuple  # pylint: disable=protected-access
  return shape_tuple_fn()


class Tape(object):
  """Python wrapper around a tape created by `TFE_Py_NewTape()`.

  All tracing state lives in the wrapped native object; the only state kept
  on the Python side is the set of variables handed to `watch_variable`.
  """

  def __init__(self):
    self._tape = pywrap_tensorflow.TFE_Py_NewTape()
    self._watched_variables = set()

  def should_record(self, tensors):
    """Asks the native tape whether an op over `tensors` should be recorded.

    Args:
      tensors: iterable of tensors, each exposing an `_id` attribute.

    Returns:
      The result of `TFE_Py_TapeShouldRecord` for the tensors' ids; intended
      to be True if any of the tensors is in the tape.
    """
    ids = [tensor._id for tensor in tensors]  # pylint: disable=protected-access
    return pywrap_tensorflow.TFE_Py_TapeShouldRecord(self._tape, ids)

  def watch(self, tensor):
    """Starts tracing `tensor` on this tape, keyed by `tid(tensor)`."""
    key = tid(tensor)
    pywrap_tensorflow.TFE_Py_TapeWatch(self._tape, key)

  def watch_variable(self, v):
    """Remembers `v` in this tape's watched set, then watches `v.handle`."""
    self._watched_variables.add(v)
    self.watch(v.handle)

  def record_operation(self, op_type, output_tensors, input_tensors,
                       backward_function):
    """Records one executed operation on this tape.

    Args:
      op_type: op type identifier; converted with `compat.as_bytes`.
      output_tensors: the op's outputs, forwarded to the native tape as-is.
      input_tensors: the op's inputs; only their `_id`s are forwarded.
      backward_function: callable that computes the op's backward pass.
    """
    # Convert op_type before touching the inputs so any error surfaces in the
    # same order as before.
    op_type_bytes = compat.as_bytes(op_type)
    input_ids = [tensor._id for tensor in input_tensors]  # pylint: disable=protected-access
    pywrap_tensorflow.TFE_Py_TapeRecordOperation(
        self._tape, op_type_bytes, output_tensors, input_ids,
        backward_function)

  def _delete_tensor_id(self, i):
    """Asks the native tape to drop its trace for tensor id `i`."""
    pywrap_tensorflow.TFE_Py_TapeDeleteTrace(self._tape, i)

  def delete_trace(self, tensor_id):
    """Deletes any trace this tape holds for `tensor_id`."""
    self._delete_tensor_id(tensor_id)

  def export(self):
    """Exports the internal state of this tape via `TFE_Py_TapeExport`.

    Returns:
      tensor_tape: a map from tid(tensor) to <identifier for op> responsible
        for generating that tensor.
      op_tape: a map from <identifier for op> to TapeEntry for that op.
    """
    return pywrap_tensorflow.TFE_Py_TapeExport(self._tape)


class _TapeStack(threading.local):
  """Per-thread holder for the list of active tapes.

  As a `threading.local` subclass, `__init__` runs again the first time each
  new thread touches the instance, so every thread starts with an empty list.
  """

  def __init__(self):
    super(_TapeStack, self).__init__()
    self._stack = []

  # Read-only view of the current thread's list (most recently pushed last).
  stack = property(lambda self: self._stack)


# The global tape stack. Because `_TapeStack` subclasses `threading.local`,
# this single module-level object gives each thread its own independent list
# of tapes, with the most recently pushed tape last.
_tape_stack = _TapeStack()


def push_new_tape():
  """Creates a fresh `Tape` and makes it the innermost tape of this thread."""
  new_tape = Tape()
  _tape_stack.stack.append(new_tape)


def watch(tensor):
  """Asks every tape on this thread's stack to watch `tensor`.

  Args:
    tensor: tensor to be watched.
  """
  active_tapes = _tape_stack.stack
  for active_tape in active_tapes:
    active_tape.watch(tensor)


def watch_variable(variable):
  """Asks every tape on this thread's stack to watch `variable`.

  Args:
    variable: variable to be watched; each tape records it and watches its
      `handle`.
  """
  active_tapes = _tape_stack.stack
  for active_tape in active_tapes:
    active_tape.watch_variable(variable)


def pop_tape():
  """Removes and returns the innermost tape, or None if the stack is empty."""
  try:
    return _tape_stack.stack.pop()
  except IndexError:
    return None


@contextlib.contextmanager
def stop_recording():
  """Context manager that hides this thread's active tapes for its duration.

  Inside the `with` block the module-level helpers (`watch`,
  `record_operation`, ...) see an empty tape stack. On exit, including exit
  via an exception, the previous stack list is put back. Anything pushed
  inside the block is therefore discarded from the stack.
  """
  saved_stack, _tape_stack._stack = _tape_stack.stack, []  # pylint: disable=protected-access
  try:
    yield
  finally:
    _tape_stack._stack = saved_stack  # pylint: disable=protected-access


def should_record(tensors):
  """Returns True if any tape on this thread's stack watches any of `tensors`.

  With no active tapes this short-circuits to False without asking any tape.
  """
  active_tapes = _tape_stack.stack
  return bool(active_tapes) and any(
      active_tape.should_record(tensors) for active_tape in active_tapes)


def record_operation(op_type, output_tensors, input_tensors, backward_function):
  """Records one executed operation on every tape in this thread's stack.

  Args:
    op_type: op type identifier.
    output_tensors: the op's output tensors.
    input_tensors: the op's input tensors.
    backward_function: callable computing the op's backward pass.
  """
  active_tapes = _tape_stack.stack
  for active_tape in active_tapes:
    active_tape.record_operation(
        op_type, output_tensors, input_tensors, backward_function)


def delete_trace(tensor_id):
  """Drops any trace of `tensor_id` from every tape in this thread's stack."""
  active_tapes = _tape_stack.stack
  for active_tape in active_tapes:
    active_tape.delete_trace(tensor_id)


def top_tape_watched_variables():
  """Returns the variables watched by the innermost (most recent) tape.

  Returns:
    The tape's own `set` of watched variables, not a copy, so later
    `watch_variable` calls on that tape are reflected in it.

  Raises:
    IndexError: if no tape is active on the current thread. The exception
      type is the same as before; only the message is clearer.
  """
  active_tapes = _tape_stack.stack
  if not active_tapes:
    raise IndexError(
        "top_tape_watched_variables() called with no active tape; "
        "push_new_tape() must be called first.")
  return active_tapes[-1]._watched_variables  # pylint: disable=protected-access


def could_possibly_record():
  """Returns True if any tape is active on the current thread.

  This is a cheap pre-check: when it is False, no tape can record anything.
  """
  # Truthiness of the list replaces `len(...) > 0` and its pylint suppression.
  return bool(_tape_stack.stack)