diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index ebf17c8a41..b1f50fd341 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -512,6 +512,16 @@ def _MaxPoolGrad(op, grad): data_format=op.get_attr("data_format")) +@ops.RegisterGradient("MaxPoolWithArgmax") +def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad): + return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0], + grad, + op.outputs[1], + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding")) + + @ops.RegisterGradient("MaxPoolGrad") def _MaxPoolGradGrad(op, grad): return (array_ops.zeros( |