diff options
Diffstat (limited to 'tensorflow/python/ops/array_grad.py')
-rw-r--r-- | tensorflow/python/ops/array_grad.py | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py index 3c6a5c9e56..57d2657838 100644 --- a/tensorflow/python/ops/array_grad.py +++ b/tensorflow/python/ops/array_grad.py @@ -255,10 +255,15 @@ def _SliceGrad(op, grad): @ops.RegisterGradient("StridedSlice") def _StridedSliceGrad(op, grad): """Gradient for StridedSlice op.""" - x = array_ops.shape(op.inputs[0]) begin = op.inputs[1] end = op.inputs[2] strides = op.inputs[3] + # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the + # same dtype so we build a shape of the same type as other args. + # Note that the choice of `begin` for specifying `out_type` is arbitrary. + # We could choose any of {begin|end|strides}.dtype since they are required to + # be the same. + x = array_ops.shape(op.inputs[0], out_type=begin.dtype) return array_ops.strided_slice_grad( x, |