aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--tensorflow/python/ops/nn_grad.py102
1 files changed, 94 insertions, 8 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a01466e1ae..f5e9550b97 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -121,7 +121,7 @@ def _Conv3DBackpropFilterGrad(op, grad):
@ops.RegisterGradient("AvgPool3D")
def _AvgPool3DGrad(op, grad):
- return nn_ops.avg_pool3d_grad(
+ return gen_nn_ops._avg_pool3d_grad(
array_ops.shape(op.inputs[0]),
grad,
ksize=op.get_attr("ksize"),
@@ -130,15 +130,58 @@ def _AvgPool3DGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("AvgPool3DGrad")
+def _AvgPool3DGradGrad(op, grad):
+ return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops.avg_pool3d(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
+
+
@ops.RegisterGradient("MaxPool3D")
def _MaxPool3DGrad(op, grad):
- return nn_ops.max_pool3d_grad(op.inputs[0],
- op.outputs[0],
- grad,
- ksize=op.get_attr("ksize"),
- strides=op.get_attr("strides"),
- padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format"))
+ return gen_nn_ops._max_pool3d_grad(
+ op.inputs[0],
+ op.outputs[0],
+ grad,
+ ksize=op.get_attr("ksize"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format"))
+
+
+@ops.RegisterGradient("MaxPool3DGrad")
+def _MaxPool3DGradGrad(op, grad):
+ return (array_ops.zeros(
+ shape=array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype), array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ gen_nn_ops._max_pool3d_grad_grad(
+ op.inputs[0],
+ op.inputs[1],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
+
+
+@ops.RegisterGradient("MaxPool3DGradGrad")
+def _MaxPool3DGradGradGrad(op, grad):
+ return (array_ops.zeros(
+ shape=array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype), array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ gen_nn_ops._max_pool3d_grad(
+ op.inputs[0],
+ op.inputs[1],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
@ops.RegisterGradient("Softmax")
@@ -214,6 +257,7 @@ def _BiasAddGrad(op, received_grad):
return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad,
data_format=data_format))
+
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
"""Gradient for the BiasAddGrad op.
@@ -438,6 +482,16 @@ def _AvgPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("AvgPoolGrad")
+def _AvgPoolGradGrad(op, grad):
+ return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops._avg_pool(
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
+
+
@ops.RegisterGradient("MaxPool")
def _MaxPoolGrad(op, grad):
return gen_nn_ops._max_pool_grad(op.inputs[0],
@@ -449,6 +503,38 @@ def _MaxPoolGrad(op, grad):
data_format=op.get_attr("data_format"))
+@ops.RegisterGradient("MaxPoolGrad")
+def _MaxPoolGradGrad(op, grad):
+ return (array_ops.zeros(
+ shape=array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype), array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ gen_nn_ops._max_pool_grad_grad(
+ op.inputs[0],
+ op.inputs[1],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
+
+
+@ops.RegisterGradient("MaxPoolGradGrad")
+def _MaxPoolGradGradGrad(op, grad):
+ return (array_ops.zeros(
+ shape=array_ops.shape(op.inputs[0]),
+ dtype=op.inputs[0].dtype), array_ops.zeros(
+ shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype),
+ gen_nn_ops._max_pool_grad(
+ op.inputs[0],
+ op.inputs[1],
+ grad,
+ op.get_attr("ksize"),
+ op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ data_format=op.get_attr("data_format")))
+
+
@ops.RegisterGradient("FractionalMaxPool")
def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2):
"""Returns gradient for FractionalMaxPool.