aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg_grad.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/linalg_grad.py')
-rw-r--r--tensorflow/python/ops/linalg_grad.py25
1 files changed, 25 insertions, 0 deletions
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
new file mode 100644
index 0000000000..893618c9dd
--- /dev/null
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -0,0 +1,25 @@
+"""Gradients for operators defined in linalg_ops.py."""
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+@ops.RegisterGradient("MatrixInverse")
+def _MatrixInverseGrad(op, grad):
+ """Gradient for MatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.matmul(
+ ainv,
+ math_ops.matmul(grad, ainv, transpose_b=True),
+ transpose_a=True)
+
+@ops.RegisterGradient("BatchMatrixInverse")
+def _BatchMatrixInverseGrad(op, grad):
+ """Gradient for BatchMatrixInverse."""
+ ainv = op.outputs[0]
+ return -math_ops.batch_matmul(
+ ainv,
+ math_ops.batch_matmul(grad, ainv, adj_y=True),
+ adj_x=True)