aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-09-18 12:59:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-18 13:05:10 -0700
commite8be4d96dd4d3d9d6b12b778a5b8beee592a324a (patch)
tree55f62cf99ad937f2fc88240e8e2975521f7a7075 /tensorflow/python/eager
parent723242c800f237368e238fe03bd50516807e3402 (diff)
Only start_step/end_step on GradientTape if executing eagerly.
This prevents creating a context where none is required. PiperOrigin-RevId: 213500408
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/backprop.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 907234b0f8..50a6ce6324 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -725,7 +725,9 @@ class GradientTape(object):
self._persistent = persistent
self._watch_accessed_variables = watch_accessed_variables
self._recording = False
- context.context().start_step()
+ self._created_eagerly = context.executing_eagerly()
+ if self._created_eagerly:
+ context.context().start_step()
def __enter__(self):
"""Enters a context inside which operations are recorded on this tape."""
@@ -755,7 +757,8 @@ class GradientTape(object):
self._recording = False
def __del__(self):
- context.context().end_step()
+ if self._created_eagerly:
+ context.context().end_step()
def watch(self, tensor):
"""Ensures that `tensor` is being traced by this tape.