blob: 893618c9dd792cb02ceb6b5b98bbdbd50f469e2c (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
|
"""Gradients for operators defined in linalg_ops.py."""
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@ops.RegisterGradient("MatrixInverse")
def _MatrixInverseGrad(op, grad):
  """Backpropagates through MatrixInverse.

  With C denoting the inverse produced by the forward op, the gradient
  returned for the input matrix is -C^T * grad * C^T.

  Args:
    op: The MatrixInverse operation; its first output is the inverse C.
    grad: The incoming gradient with respect to that output.

  Returns:
    The gradient with respect to the op's input matrix.
  """
  inverse = op.outputs[0]
  # grad * C^T -- built first so the inner product is created before the
  # outer one, matching the original graph construction order.
  grad_times_inv_t = math_ops.matmul(grad, inverse, transpose_b=True)
  # C^T * (grad * C^T), negated.
  return -math_ops.matmul(inverse, grad_times_inv_t, transpose_a=True)
@ops.RegisterGradient("BatchMatrixInverse")
def _BatchMatrixInverseGrad(op, grad):
  """Backpropagates through BatchMatrixInverse.

  With C denoting the (batched) inverse produced by the forward op, the
  gradient returned for the input is -adj(C) * grad * adj(C), where adj is
  the adjoint requested via batch_matmul's adj_x / adj_y flags.

  NOTE(review): presumably the input is a stack of square matrices over
  leading batch dimensions -- confirm against the op definition.

  Args:
    op: The BatchMatrixInverse operation; its first output is the inverse C.
    grad: The incoming gradient with respect to that output.

  Returns:
    The gradient with respect to the op's input.
  """
  inverses = op.outputs[0]
  # grad * adj(C) -- the inner product is created first, as in the original.
  grad_times_inv_adj = math_ops.batch_matmul(grad, inverses, adj_y=True)
  # adj(C) * (grad * adj(C)), negated.
  return -math_ops.batch_matmul(inverses, grad_times_inv_adj, adj_x=True)
|