diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 102 |
1 files changed, 94 insertions, 8 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py index a01466e1ae..f5e9550b97 100644 --- a/tensorflow/python/ops/nn_grad.py +++ b/tensorflow/python/ops/nn_grad.py @@ -121,7 +121,7 @@ def _Conv3DBackpropFilterGrad(op, grad): @ops.RegisterGradient("AvgPool3D") def _AvgPool3DGrad(op, grad): - return nn_ops.avg_pool3d_grad( + return gen_nn_ops._avg_pool3d_grad( array_ops.shape(op.inputs[0]), grad, ksize=op.get_attr("ksize"), @@ -130,15 +130,58 @@ def _AvgPool3DGrad(op, grad): data_format=op.get_attr("data_format")) +@ops.RegisterGradient("AvgPool3DGrad") +def _AvgPool3DGradGrad(op, grad): + return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops.avg_pool3d( + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + op.get_attr("padding"), + data_format=op.get_attr("data_format"))) + + @ops.RegisterGradient("MaxPool3D") def _MaxPool3DGrad(op, grad): - return nn_ops.max_pool3d_grad(op.inputs[0], - op.outputs[0], - grad, - ksize=op.get_attr("ksize"), - strides=op.get_attr("strides"), - padding=op.get_attr("padding"), - data_format=op.get_attr("data_format")) + return gen_nn_ops._max_pool3d_grad( + op.inputs[0], + op.outputs[0], + grad, + ksize=op.get_attr("ksize"), + strides=op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format")) + + +@ops.RegisterGradient("MaxPool3DGrad") +def _MaxPool3DGradGrad(op, grad): + return (array_ops.zeros( + shape=array_ops.shape(op.inputs[0]), + dtype=op.inputs[0].dtype), array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + gen_nn_ops._max_pool3d_grad_grad( + op.inputs[0], + op.inputs[1], + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format"))) + + +@ops.RegisterGradient("MaxPool3DGradGrad") +def _MaxPool3DGradGradGrad(op, grad): + return (array_ops.zeros( + shape=array_ops.shape(op.inputs[0]), + dtype=op.inputs[0].dtype), array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + gen_nn_ops._max_pool3d_grad( + op.inputs[0], + op.inputs[1], + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format"))) @ops.RegisterGradient("Softmax") @@ -214,6 +257,7 @@ def _BiasAddGrad(op, received_grad): return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad, data_format=data_format)) + @ops.RegisterGradient("BiasAddGrad") def _BiasAddGradGrad(op, received_grad): """Gradient for the BiasAddGrad op. @@ -438,6 +482,16 @@ def _AvgPoolGrad(op, grad): data_format=op.get_attr("data_format")) +@ops.RegisterGradient("AvgPoolGrad") +def _AvgPoolGradGrad(op, grad): + return (array_ops.stop_gradient(op.inputs[0]), gen_nn_ops._avg_pool( + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + op.get_attr("padding"), + data_format=op.get_attr("data_format"))) + + @ops.RegisterGradient("MaxPool") def _MaxPoolGrad(op, grad): return gen_nn_ops._max_pool_grad(op.inputs[0], @@ -449,6 +503,38 @@ def _MaxPoolGrad(op, grad): data_format=op.get_attr("data_format")) +@ops.RegisterGradient("MaxPoolGrad") +def _MaxPoolGradGrad(op, grad): + return (array_ops.zeros( + shape=array_ops.shape(op.inputs[0]), + dtype=op.inputs[0].dtype), array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + gen_nn_ops._max_pool_grad_grad( + op.inputs[0], + op.inputs[1], + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format"))) + + +@ops.RegisterGradient("MaxPoolGradGrad") +def _MaxPoolGradGradGrad(op, grad): + return (array_ops.zeros( + shape=array_ops.shape(op.inputs[0]), + dtype=op.inputs[0].dtype), array_ops.zeros( + shape=array_ops.shape(op.inputs[1]), dtype=op.inputs[1].dtype), + gen_nn_ops._max_pool_grad( + op.inputs[0], + op.inputs[1], + grad, + op.get_attr("ksize"), + op.get_attr("strides"), + padding=op.get_attr("padding"), + data_format=op.get_attr("data_format"))) + + @ops.RegisterGradient("FractionalMaxPool") def _FractionalMaxPoolGrad(op, grad_0, unused_grad_1, unused_grad_2): """Returns gradient for FractionalMaxPool. |