about summary refs log tree commit diff homepage
path: root/tensorflow/python/ops/linalg_grad.py
blob: 893618c9dd792cb02ceb6b5b98bbdbd50f469e2c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""Gradients for operators defined in linalg_ops.py."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops

@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Gradient for MatrixInverse.

  Writing Ainv for the op's output, the differential of the inverse is
  d(Ainv) = -Ainv * dA * Ainv, so the gradient sent back to the input is
  -Ainv^T * grad * Ainv^T.

  Args:
    op: The MatrixInverse operation; op.outputs[0] is the inverse it produced.
    grad: Incoming gradient with respect to op.outputs[0].

  Returns:
    The gradient with respect to the op's single input.
  """
  inverse = op.outputs[0]
  # Right factor first: grad * Ainv^T.
  right = math_ops.matmul(grad, inverse, transpose_b=True)
  # Then left-multiply by Ainv^T: Ainv^T * (grad * Ainv^T).
  product = math_ops.matmul(inverse, right, transpose_a=True)
  return -product

@ops.RegisterGradient("BatchMatrixInverse")
def _BatchMatrixInverseGrad(op, grad):
  """Gradient for BatchMatrixInverse.

  This is the batched analogue of the MatrixInverse gradient. For each matrix
  in the batch, with Ainv its inverse, the gradient sent back to the input is
  -adj(Ainv) * grad * adj(Ainv). Here adj(.) is batch_matmul's adjoint
  (adj_x/adj_y), which reduces to a plain transpose for real-valued inputs.

  Args:
    op: The BatchMatrixInverse operation; op.outputs[0] holds the inverses.
    grad: Incoming gradient with respect to op.outputs[0].

  Returns:
    The gradient with respect to the op's single input.
  """
  inverses = op.outputs[0]
  # Right factor first: grad * adj(Ainv), computed slice by slice.
  right = math_ops.batch_matmul(grad, inverses, adj_y=True)
  # Then left-multiply by adj(Ainv): adj(Ainv) * (grad * adj(Ainv)).
  product = math_ops.batch_matmul(inverses, right, adj_x=True)
  return -product