diff options
Diffstat (limited to 'tensorflow/python/ops/nn_grad.py')
-rw-r--r-- | tensorflow/python/ops/nn_grad.py | 37 |
1 file changed, 37 insertions, 0 deletions
@ops.RegisterGradient("BiasAddGrad")
def _BiasAddGradGrad(op, received_grad):
  """Gradient for the BiasAddGrad op.

  BiasAddGrad reduces its input over all non-channel dimensions, so its
  gradient broadcasts `received_grad` (shaped like the bias) back up to the
  shape of BiasAddGrad's input: reshape the bias gradient so the bias
  dimension lands on the channel axis, then tile it across every other axis.

  Args:
    op: BiasAddGrad op for which we are calculating gradients.
    received_grad: The gradients passed to the BiasAddGrad op; has the shape
      of the bias term.

  Returns:
    A single gradient Tensor for the input to BiasAddGrad (which
    is the gradient of the bias term in BiasAdd).
  """
  # Older graphs may lack the data_format attr; treat that as the default
  # (channel-last) layout.
  try:
    data_format = op.get_attr("data_format")
  except ValueError:
    data_format = None

  shape = array_ops.shape(op.inputs[0])
  bias_shape = array_ops.shape(received_grad)

  # NOTE(review): get_attr on a string attr may return bytes under Python 3,
  # which would never equal the str "NCHW" — confirm against the runtime.
  if data_format == "NCHW":
    # Channel axis is third-from-last (N, C, H, W): place bias_shape there,
    # with size-1 dims everywhere else, then tile over the non-channel dims.
    expanded_shape = array_ops.concat(
        0,
        [array_ops.ones_like(shape[:-3]), bias_shape,
         array_ops.ones_like(shape[-2:])])
    tile_mults = array_ops.concat(0, [shape[:-3], [1], shape[-2:]])
  else:
    # Channel-last (e.g. NHWC): bias aligns with the last axis.
    expanded_shape = array_ops.concat(
        0, [array_ops.ones_like(shape[:-1]), bias_shape])
    tile_mults = array_ops.concat(0, [shape[:-1], [1]])

  expanded_grad = array_ops.reshape(received_grad, expanded_shape)
  return array_ops.tile(expanded_grad, tile_mults)