diff options
author | Alexandre Passos <apassos@google.com> | 2018-09-18 12:59:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-18 13:05:10 -0700 |
commit | e8be4d96dd4d3d9d6b12b778a5b8beee592a324a (patch) | |
tree | 55f62cf99ad937f2fc88240e8e2975521f7a7075 /tensorflow/python/eager | |
parent | 723242c800f237368e238fe03bd50516807e3402 (diff) |
Only start_step/end_step on GradientTape if executing eagerly.
This prevents creating a context where none is required.
PiperOrigin-RevId: 213500408
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r-- | tensorflow/python/eager/backprop.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 907234b0f8..50a6ce6324 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -725,7 +725,9 @@ class GradientTape(object): self._persistent = persistent self._watch_accessed_variables = watch_accessed_variables self._recording = False - context.context().start_step() + self._created_eagerly = context.executing_eagerly() + if self._created_eagerly: + context.context().start_step() def __enter__(self): """Enters a context inside which operations are recorded on this tape.""" @@ -755,7 +757,8 @@ class GradientTape(object): self._recording = False def __del__(self): - context.context().end_step() + if self._created_eagerly: + context.context().end_step() def watch(self, tensor): """Ensures that `tensor` is being traced by this tape. |