Diffstat (limited to 'tensorflow/python/eager/backprop.py')
-rw-r--r--  tensorflow/python/eager/backprop.py  16
1 file changed, 2 insertions(+), 14 deletions(-)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 0a92ab38a8..86b3776b8c 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -727,24 +727,12 @@ def _num_elements(grad):
   raise ValueError("`grad` not a Tensor or IndexedSlices.")
 
 
-_last_shape_dtype = [None, None]
-_last_zero = [None]
-
-
-def _zeros(shape, dtype):
-  """Wraps array_ops.zeros to cache last zero for a given shape and dtype."""
-  if [shape, dtype] != _last_shape_dtype:
-    _last_shape_dtype[:] = [shape, dtype]
-    _last_zero[0] = array_ops.zeros(shape, dtype)
-  return _last_zero[0]
-
-
 _default_vspace = imperative_grad.VSpace(
     num_elements_fn=_num_elements,
     aggregate_fn=_aggregate_grads,
     tensor_id=ops.tensor_id,
-    zeros=_zeros,
-    ones=array_ops.ones)
+    zeros=array_ops.zeros,
+    ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x)))
 
 
 class GradientTape(object):
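
The functional change is in the last two lines of the hunk: the VSpace now builds zero gradients with array_ops.zeros directly, dropping the removed single-entry cache (_last_shape_dtype / _last_zero, which memoized only the most recently requested shape/dtype pair), and it takes a ones_like callback that derives shape and dtype from an existing tensor instead of a ones callback that receives them explicitly. Below is a minimal sketch of that callback-signature difference; it uses the public tf.* API rather than the internal array_ops/ops modules, and the tensor x is illustrative, not part of the commit.

# Sketch (not part of the commit): old-style "ones" takes an explicit shape
# and dtype, while the new "ones_like" derives both from an existing tensor.
import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])

# Old-style callback shape: ones(shape, dtype) -- caller supplies both.
ones_from_shape = tf.ones(x.shape, x.dtype)

# New-style callback shape: ones_like(tensor), mirroring
# lambda x: ops.convert_to_tensor(array_ops.ones_like(x)) in the diff.
ones_from_tensor = tf.convert_to_tensor(tf.ones_like(x))

# Both produce a tensor of ones matching x's shape and dtype.
assert ones_from_shape.shape == ones_from_tensor.shape
assert ones_from_shape.dtype == ones_from_tensor.dtype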