aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 13:20:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-27 13:20:32 -0700
commit229652512f745b9ef110c02deb9dce98fd33b5be (patch)
treef3cb7f977ecbc42d381d4ae20e8fc3f77bcb02fb
parentc09e01232ace0a5657828c9c47de942751c94949 (diff)
parentb07f8211409f2b2e46ab539291e824f2b7865885 (diff)
Merge pull request #21436 from carusyte:dev_latest
PiperOrigin-RevId: 210418993
-rw-r--r--tensorflow/python/ops/nn_grad.py22
1 files changed, 13 insertions, 9 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a648653909..e1a01ab4c3 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -27,7 +27,6 @@ from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
-from tensorflow.python.ops import sparse_ops
@ops.RegisterGradient("Conv2DBackpropInput")
@@ -977,25 +976,30 @@ def _TopKGrad(op, grad, _):
in_shape = array_ops.shape(op.inputs[0])
ind_shape = array_ops.shape(op.outputs[1])
- ind_lastdim = array_ops.gather(ind_shape, array_ops.size(ind_shape) - 1)
+ # int32 is not supported on GPU hence up-casting
+ ind_lastdim = array_ops.gather(math_ops.cast(
+ ind_shape, dtypes.int64), array_ops.size(ind_shape) - 1)
# Flatten indices to 2D.
ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))
- in_lastdim = array_ops.gather(in_shape, array_ops.size(in_shape) - 1)
+ in_lastdim = array_ops.gather(math_ops.cast(
+ in_shape, dtypes.int64), array_ops.size(in_shape) - 1)
outerdim = array_ops.shape(ind_2d)[0]
# Compute linear indices (flattened to 1D).
- ind = array_ops.reshape(ind_2d + array_ops.expand_dims(
- math_ops.range(0, outerdim * in_lastdim, in_lastdim), -1), [-1])
+ ind = array_ops.reshape(ind_2d + math_ops.cast(array_ops.expand_dims(
+ math_ops.range(0, math_ops.cast(outerdim, dtypes.int64)
+ * in_lastdim, in_lastdim), -1), dtypes.int32), [-1])
# Substitute grad to appropriate locations and fill the rest with zeros,
# finally reshaping it to the original input shape.
return [
array_ops.reshape(
- sparse_ops.sparse_to_dense(
- ind,
- array_ops.reshape(math_ops.reduce_prod(in_shape), [1]),
+ array_ops.scatter_nd(
+ array_ops.expand_dims(ind, -1),
array_ops.reshape(grad, [-1]),
- validate_indices=False), in_shape),
+ [math_ops.reduce_prod(in_shape)]
+ ),
+ in_shape),
array_ops.zeros([], dtype=dtypes.int32)
]