diff options
Diffstat (limited to 'tensorflow/python/eager/backprop.py')
-rw-r--r-- | tensorflow/python/eager/backprop.py | 16 |
1 files changed, 2 insertions, 14 deletions
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index 0a92ab38a8..86b3776b8c 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -727,24 +727,12 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_last_shape_dtype = [None, None] -_last_zero = [None] - - -def _zeros(shape, dtype): - """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" - if [shape, dtype] != _last_shape_dtype: - _last_shape_dtype[:] = [shape, dtype] - _last_zero[0] = array_ops.zeros(shape, dtype) - return _last_zero[0] - - _default_vspace = imperative_grad.VSpace( num_elements_fn=_num_elements, aggregate_fn=_aggregate_grads, tensor_id=ops.tensor_id, - zeros=_zeros, - ones=array_ops.ones) + zeros=array_ops.zeros, + ones_like=lambda x: ops.convert_to_tensor(array_ops.ones_like(x))) class GradientTape(object): |