diff options
author | Asim Shankar <ashankar@google.com> | 2017-10-13 14:04:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-13 14:07:40 -0700 |
commit | 0c2a50e951bb840e84b0bc643b85f104c59a10ef (patch) | |
tree | 374f2fe7cb387c5f509c83e0c3b8fc1de05adf02 /tensorflow/python/eager/tape.py | |
parent | f688c35681623f38acdd9ba3a4db73fd092e13f3 (diff) |
eager: Fix issue with custom_gradients and implicit_gradients.
While at it, clean up some dead code/comments in tape.py
PiperOrigin-RevId: 172143125
Diffstat (limited to 'tensorflow/python/eager/tape.py')
-rw-r--r-- | tensorflow/python/eager/tape.py | 17 |
1 files changed, 1 insertions, 16 deletions
diff --git a/tensorflow/python/eager/tape.py b/tensorflow/python/eager/tape.py index 4578a7190d..76c6fa5ad8 100644 --- a/tensorflow/python/eager/tape.py +++ b/tensorflow/python/eager/tape.py @@ -22,8 +22,6 @@ import collections import contextlib import threading -from tensorflow.python.util import tf_contextlib - def tid(tensor): return tensor._id # pylint: disable=protected-access @@ -154,13 +152,6 @@ class _TapeStack(threading.local): def stack(self): return self._stack - @tf_contextlib.contextmanager - def replace_stack(self, new_stack): - old = self._stack - self._stack = new_stack - yield - self._stack = old - # The global tape stack. _tape_stack = _TapeStack() @@ -176,9 +167,6 @@ def watch(tensor): Args: tensor: tensor to be watched. - - Returns: - The tensor, potentially wrapped by all tapes in the stack. """ for t in _tape_stack.stack: t.watch(tensor) @@ -189,9 +177,6 @@ def watch_variable(variable): Args: variable: variable to be watched. - - Returns: - The tensor, potentially wrapped by all tapes in the stack. """ for t in _tape_stack.stack: t.watch_variable(variable) @@ -215,7 +200,7 @@ def stop_recording(): def should_record(tensors): - """Returns true if any tape in the stach watches any of these tensors.""" + """Returns true if any tape in the stack watches any of these tensors.""" if not _tape_stack.stack: return False return any(x.should_record(tensors) for x in _tape_stack.stack) |