Diffstat (limited to 'tensorflow/python/eager/backprop.py')
-rw-r--r--  tensorflow/python/eager/backprop.py  35
1 file changed, 23 insertions(+), 12 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 3e3c82e56a..c59ad09bf1 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -599,15 +599,18 @@ def _fast_fill(value, shape, dtype):
 
 
 def _zeros(shape, dtype):
-  """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
-  device = context.context().device_name
+  """Helper to return (possibly cached) zero tensors in eager mode."""
   if dtype == dtypes.variant:
     # TODO(apassos): need to save enough information about variant tensors to do
     # a zeros
     return None
-  # pylint: disable=protected-access
-  cache_key = shape, dtype, device, context.context()._eager_context.mode
-  # pylint: enable=protected-access
+
+  ctx = context.context()
+  if not ctx.executing_eagerly():
+    return array_ops.zeros(shape, dtype)
+
+  device = ctx.device_name
+  cache_key = shape, dtype, device
   cached = _zeros_cache.get(cache_key)
   if cached is None:
     cached = _fast_fill(0, shape, dtype)
@@ -616,6 +619,9 @@
 
 
 def _ones(shape, dtype):
+  if not context.context().executing_eagerly():
+    return array_ops.ones(shape, dtype)
+
   if shape == ():  # pylint: disable=g-explicit-bool-comparison
     return constant_op.constant(1, dtype=dtype)
   return _fast_fill(1, shape, dtype)
@@ -643,10 +649,10 @@ class GradientTape(object):
   Operations are recorded if they are executed within this context manager and
   at least one of their inputs is being "watched".
 
-  Trainable variables (created by `tf.contrib.eager.Variable` or
-  @{tf.get_variable}, trainable=True is default in both cases) are automatically
-  watched. Tensors can be manually watched by invoking the `watch` method on
-  this context manager.
+  Trainable variables (created by `tf.Variable` or @{tf.get_variable},
+  trainable=True is default in both cases) are automatically watched. Tensors
+  can be manually watched by invoking the `watch` method on this context
+  manager.
 
   For example, consider the function `y = x * x`. The gradient at `x = 3.0`
   can be computed as:
@@ -713,10 +719,15 @@ class GradientTape(object):
     if self._recording:
       self._pop_tape()
 
-  def _push_tape(self):
+  def _push_tape(self, existing_tape=False):
     if self._recording:
       raise ValueError("Tape is already recording.")
-    self._tape = tape.push_new_tape(persistent=self._persistent)
+    if existing_tape:
+      if self._tape is None:
+        raise ValueError("There is no existing tape.")
+      tape.push_tape(self._tape)
+    else:
+      self._tape = tape.push_new_tape(persistent=self._persistent)
     self._recording = True
 
   def _pop_tape(self):
@@ -764,7 +775,7 @@ class GradientTape(object):
     try:
       yield
     finally:
-      self._push_tape(existing_tape=True)
 
   def reset(self):
     """Clears all information stored in this tape.
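
The `_zeros`/`_ones` changes are what allow the tape machinery to run during graph construction: outside eager execution they now fall back to plain `array_ops.zeros`/`array_ops.ones` instead of the eager-only fill cache. A minimal sketch of the behavior this enables, assuming a TensorFlow 2.x install where `tf.function` traces in graph mode (this usage example is not part of the commit itself):

import tensorflow as tf

@tf.function  # traced in graph mode, where the eager fill cache is bypassed
def grad_fn(x):
  with tf.GradientTape() as tape:
    tape.watch(x)  # x is a plain tensor, not a variable, so watch it explicitly
    y = x * x
  return tape.gradient(y, x)

print(grad_fn(tf.constant(3.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)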
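The `_push_tape(existing_tape=True)` path is what `stop_recording` relies on: the tape is popped on entry and the same tape object, with its accumulated state, is pushed back on exit, rather than a fresh tape being created. A small usage sketch of the resulting behavior (again an illustration under the modern `tf.GradientTape` API, not code from the commit):

import tensorflow as tf

x = tf.constant(3.0)
with tf.GradientTape() as tape:
  tape.watch(x)                # watch a plain tensor
  y = x * x                    # recorded
  with tape.stop_recording():
    z = y * y                  # not recorded; the existing tape is re-pushed on exit

# Ops run under stop_recording are invisible to the tape: tape.gradient(z, x)
# would return None, while d(y)/dx is still available.
print(tape.gradient(y, x))     # tf.Tensor(6.0, shape=(), dtype=float32)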