aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/nn_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r--tensorflow/python/ops/nn_grad.py37
1 files changed, 37 insertions, 0 deletions
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index a0a4570f17..180a396adc 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -203,6 +203,43 @@ def _BiasAddGrad(op, received_grad):
return (received_grad, gen_nn_ops.bias_add_grad(out_backprop=received_grad,
data_format=data_format))
+@ops.RegisterGradient("BiasAddGrad")
+def _BiasAddGradGrad(op, received_grad):
+ """Gradient for the BiasAddGrad op.
+
+ Args:
+ op: BiasAddGrad op for which we are calculating gradients.
+ received_grad: The gradients passed to the BiasAddGrad op.
+
+ Returns:
+ A single gradient Tensor for the input to BiasAddGrad (which
+ is the gradient of the bias term in BiasAdd)
+ """
+
+ try:
+ data_format = op.get_attr("data_format")
+ except ValueError:
+ data_format = None
+
+ shape = array_ops.shape(op.inputs[0])
+ rank = array_ops.rank(op.inputs[0])
+ bias_shape = array_ops.shape(received_grad)
+
+ if data_format == "NCHW":
+ expanded_shape = array_ops.concat(
+ 0,
+ [array_ops.ones_like(shape[:-3]), bias_shape, array_ops.ones_like(shape[-2:])]
+ )
+
+ tile_mults = array_ops.concat(0, [shape[:-3], [1], shape[-2:]])
+
+ else:
+ expanded_shape = array_ops.concat(0, [array_ops.ones_like(shape[:-1]), bias_shape])
+ tile_mults = array_ops.concat(0, [shape[:-1], [1]])
+
+ expanded_grad = array_ops.reshape(received_grad, expanded_shape)
+ return array_ops.tile(expanded_grad, tile_mults)
+
@ops.RegisterGradient("BiasAddV1")
def _BiasAddGradV1(unused_bias_op, received_grad):