aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/array_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/array_grad.py')
-rw-r--r--tensorflow/python/ops/array_grad.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/tensorflow/python/ops/array_grad.py b/tensorflow/python/ops/array_grad.py
index 3c6a5c9e56..57d2657838 100644
--- a/tensorflow/python/ops/array_grad.py
+++ b/tensorflow/python/ops/array_grad.py
@@ -255,10 +255,15 @@ def _SliceGrad(op, grad):
@ops.RegisterGradient("StridedSlice")
def _StridedSliceGrad(op, grad):
"""Gradient for StridedSlice op."""
- x = array_ops.shape(op.inputs[0])
begin = op.inputs[1]
end = op.inputs[2]
strides = op.inputs[3]
+ # StridedSliceGrad requires `x`, `begin`, `end` and `strides` to be of the
+ # same dtype so we build a shape of the same type as other args.
+ # Note that the choice of `begin` for specifying `out_type` is arbitrary.
+ # We could choose any of {begin|end|strides}.dtype since they are required to
+ # be the same.
+ x = array_ops.shape(op.inputs[0], out_type=begin.dtype)
return array_ops.strided_slice_grad(
x,