Diffstat (limited to 'tensorflow/python/eager/backprop.py')
-rw-r--r--  tensorflow/python/eager/backprop.py  |  35
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 3e3c82e56a..c59ad09bf1 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -599,15 +599,18 @@ def _fast_fill(value, shape, dtype):
 def _zeros(shape, dtype):
-  """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
-  device = context.context().device_name
+  """Helper to return (possibly cached) zero tensors in eager mode."""
   if dtype == dtypes.variant:
     # TODO(apassos): need to save enough information about variant tensors to do
     # a zeros
     return None
-  # pylint: disable=protected-access
-  cache_key = shape, dtype, device, context.context()._eager_context.mode
-  # pylint: enable=protected-access
+
+  ctx = context.context()
+  if not ctx.executing_eagerly():
+    return array_ops.zeros(shape, dtype)
+
+  device = ctx.device_name
+  cache_key = shape, dtype, device
   cached = _zeros_cache.get(cache_key)
   if cached is None:
     cached = _fast_fill(0, shape, dtype)
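The new cache key drops the execution mode because graph mode now bypasses the cache entirely. For context, a minimal sketch of the (shape, dtype, device)-keyed caching pattern this hunk implements; cached_zeros and fill_fn are hypothetical stand-ins for the real _zeros and _fast_fill:

_zeros_cache = {}  # maps (shape, dtype, device) -> cached zero tensor

def cached_zeros(shape, dtype, device, fill_fn):
  # Only the eager path reaches this cache; graph mode returns a fresh
  # array_ops.zeros op instead (see the hunk above).
  cache_key = (shape, dtype, device)
  cached = _zeros_cache.get(cache_key)
  if cached is None:
    cached = fill_fn(0, shape, dtype)
    _zeros_cache[cache_key] = cached
  return cached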
@@ -616,6 +619,9 @@ def _zeros(shape, dtype):
 def _ones(shape, dtype):
+  if not context.context().executing_eagerly():
+    return array_ops.ones(shape, dtype)
+
   if shape == ():  # pylint: disable=g-explicit-bool-comparison
     return constant_op.constant(1, dtype=dtype)
   return _fast_fill(1, shape, dtype)
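_ones gets the same graph-mode escape hatch as _zeros. A hedged sketch of the dispatch using only the public TF API (tf.executing_eagerly, tf.ones, and tf.fill stand in for the internal context object and _fast_fill):

import tensorflow as tf

def ones_helper(shape, dtype):
  if not tf.executing_eagerly():
    # Graph mode: emit an op; caching concrete tensors would be wrong here.
    return tf.ones(shape, dtype)
  if shape == ():  # scalar fast path, mirroring _ones above
    return tf.constant(1, dtype=dtype)
  return tf.fill(shape, tf.constant(1, dtype=dtype))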
@@ -643,10 +649,10 @@ class GradientTape(object):
   Operations are recorded if they are executed within this context manager and
   at least one of their inputs is being "watched".
 
-  Trainable variables (created by `tf.contrib.eager.Variable` or
-  @{tf.get_variable}, trainable=True is default in both cases) are automatically
-  watched. Tensors can be manually watched by invoking the `watch` method on
-  this context manager.
+  Trainable variables (created by `tf.Variable` or @{tf.get_variable},
+  trainable=True is default in both cases) are automatically watched. Tensors
+  can be manually watched by invoking the `watch` method on this context
+  manager.
 
   For example, consider the function `y = x * x`. The gradient at `x = 3.0` can
   be computed as:
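The code for that example falls outside the diff context. As a hedged illustration of what the docstring describes (shown with the public tf.GradientTape name; at this commit the class also surfaced as tf.contrib.eager.GradientTape):

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as g:
  g.watch(x)  # x is a plain tensor, so it must be watched explicitly
  y = x * x
grad = g.gradient(y, x)  # 6.0, i.e. d(x^2)/dx at x = 3.0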
@@ -713,10 +719,15 @@ class GradientTape(object):
     if self._recording:
       self._pop_tape()
 
-  def _push_tape(self):
+  def _push_tape(self, existing_tape=False):
     if self._recording:
       raise ValueError("Tape is already recording.")
-    self._tape = tape.push_new_tape(persistent=self._persistent)
+    if existing_tape:
+      if self._tape is None:
+        raise ValueError("There is no existing tape.")
+      tape.push_tape(self._tape)
+    else:
+      self._tape = tape.push_new_tape(persistent=self._persistent)
     self._recording = True
 
   def _pop_tape(self):
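The existing_tape flag distinguishes re-activating a previously popped tape from starting a fresh one: re-pushing preserves everything already recorded, while a new tape would silently discard it. A hypothetical stack model of the tape.push_tape / tape.push_new_tape split (the real tape stack lives in C++):

class Tape(object):
  def __init__(self, persistent):
    self.persistent = persistent
    self.records = []  # ops recorded so far

_tape_stack = []

def push_new_tape(persistent=False):
  # Start recording onto a brand-new, empty tape.
  t = Tape(persistent)
  _tape_stack.append(t)
  return t

def push_tape(existing):
  # Re-activate a tape that was popped earlier (e.g. by stop_recording);
  # unlike push_new_tape, its already-recorded ops are kept.
  _tape_stack.append(existing)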
@@ -764,7 +775,7 @@ class GradientTape(object):
     try:
       yield
     finally:
-      self._push_tape()
+      self._push_tape(existing_tape=True)
 
   def reset(self):
     """Clears all information stored in this tape.