aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--tensorflow/python/ops/nn_grad.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index ebf17c8a41..b1f50fd341 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -512,6 +512,16 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("MaxPoolWithArgmax")
+def _MaxPoolGradWithArgmax(op, grad, unused_argmax_grad):
+ return gen_nn_ops._max_pool_grad_with_argmax(op.inputs[0],
+ grad,
+ op.outputs[1],
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"))
+
+
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
return (array_ops.zeros(