diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 13:20:32 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 13:20:32 -0700 |
commit | 229652512f745b9ef110c02deb9dce98fd33b5be (patch) | |
tree | f3cb7f977ecbc42d381d4ae20e8fc3f77bcb02fb | |
parent | c09e01232ace0a5657828c9c47de942751c94949 (diff) | |
parent | b07f8211409f2b2e46ab539291e824f2b7865885 (diff) |
Merge pull request #21436 from carusyte:dev_latest
PiperOrigin-RevId: 210418993
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index a648653909..e1a01ab4c3 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -27,7 +27,6 @@ from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import sparse_ops @ops.RegisterGradient("Conv2DBackpropInput") @@ -977,25 +976,30 @@ def _TopKGrad(op, grad, _): in_shape = array_ops.shape(op.inputs[0]) ind_shape = array_ops.shape(op.outputs[1]) - ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1) + # int32 is not supported on GPU hence up-casting + ind_lastdim = array_ops.gather(math_ops.cast( + ind_shape, dtypes.int64), array_ops.size(ind_shape) - 1) # Flatten indices to 2D. ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim])) - in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1) + in_lastdim = array_ops.gather(math_ops.cast( + in_shape, dtypes.int64), array_ops.size(in_shape) - 1) outerdim = array_ops.shape(ind_2d)[0] # Compute linear indices (flattened to 1D). - ind = array_ops.reshape(ind_2d + array_ops.expand_dims( - math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1]) + ind = array_ops.reshape(ind_2d + math_ops.cast(array_ops.expand_dims( + math_ops.range(0, math_ops.cast(outerdim, dtypes.int64) + * in_lastdim, in_lastdim), -1), dtypes.int32), [-1]) # Substitute grad to appropriate locations and fill the rest with zeros, # finally reshaping it to the original input shape. return [ array_ops.reshape( - sparse_ops.sparse_to_dense( - ind, - array_ops.reshape(math_ops.reduce_prod(in_shape), [1]), + array_ops.scatter_nd( + array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]), - validate_indices=False), in_shape), + [math_ops.reduce_prod(in_shape)] + ), + in_shape), array_ops.zeros([], dtype=dtypes.int32) ] |