-rw-r--r--  tensorflow/python/kernel_tests/qr_op_test.py  66
-rw-r--r--  tensorflow/python/ops/linalg_grad.py          42
2 files changed, 98 insertions, 10 deletions
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index b4fd89bd03..8848c15e76 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
@@ -140,11 +141,11 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
      x_reshape = np.reshape(x_np, (-1, x_np.shape[-2], x_np.shape[-1]))
      for i in range(new_first_dim):
        if full_matrices_:
-          np_q_reshape[i,:,:], _ = \
-              np.linalg.qr(x_reshape[i,:,:], mode="complete")
+          np_q_reshape[i, :, :], _ = np.linalg.qr(
+              x_reshape[i, :, :], mode="complete")
        else:
-          np_q_reshape[i,:,:], _ = \
-              np.linalg.qr(x_reshape[i,:,:], mode="reduced")
+          np_q_reshape[i, :, :], _ = np.linalg.qr(
+              x_reshape[i, :, :], mode="reduced")
      np_q = np.reshape(np_q_reshape, q_dims)
      CompareOrthogonal(self, np_q, q_tf_val, min(shape_[-2:]))
      CheckApproximation(self, x_np, q_tf_val, r_tf_val)
@@ -153,6 +154,46 @@ def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
  return Test


+class QrGradOpTest(test.TestCase):
+  pass
+
+
+def _GetQrGradOpTest(dtype_, shape_, full_matrices_):
+
+  def Test(self):
+    np.random.seed(42)
+    a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
+    if dtype_ in [np.complex64, np.complex128]:
+      a += 1j * np.random.uniform(
+          low=-1.0, high=1.0, size=shape_).astype(dtype_)
+    # Optimal stepsize for central difference is O(epsilon^{1/3}).
+    epsilon = np.finfo(dtype_).eps
+    delta = 0.1 * epsilon**(1.0 / 3.0)
+    if dtype_ in [np.float32, np.complex64]:
+      tol = 3e-2
+    else:
+      tol = 1e-6
+    with self.test_session(use_gpu=True):
+      tf_a = constant_op.constant(a)
+      tf_b = linalg_ops.qr(tf_a, full_matrices=full_matrices_)
+      for b in tf_b:
+        x_init = np.random.uniform(
+            low=-1.0, high=1.0, size=shape_).astype(dtype_)
+        if dtype_ in [np.complex64, np.complex128]:
+          x_init += 1j * np.random.uniform(
+              low=-1.0, high=1.0, size=shape_).astype(dtype_)
+        theoretical, numerical = gradient_checker.compute_gradient(
+            tf_a,
+            tf_a.get_shape().as_list(),
+            b,
+            b.get_shape().as_list(),
+            x_init_value=x_init,
+            delta=delta)
+        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+
+  return Test
+
+
if __name__ == "__main__":
  for dtype in np.float32, np.float64, np.complex64, np.complex128:
    for rows in 1, 2, 5, 10, 32, 100:
@@ -168,4 +209,21 @@ if __name__ == "__main__":
              _AddTest(QrOpTest, "Qr", name,
                       _GetQrOpTest(dtype, shape, full_matrices,
                                    use_static_shape))
+
+  # TODO(pfau): Get working with complex types.
+  # TODO(pfau): Get working with full_matrices when rows != cols
+  # TODO(pfau): Get working when rows < cols
+  # TODO(pfau): Get working with placeholders (dynamic shapes)
+  for full_matrices in False, True:
+    for dtype in np.float32, np.float64:
+      for rows in 1, 2, 5, 10:
+        for cols in 1, 2, 5, 10:
+          if rows == cols or (not full_matrices and rows > cols):
+            for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
+              shape = batch_dims + (rows, cols)
+              name = "%s_%s_full_%s" % (dtype.__name__,
+                                        "_".join(map(str, shape)),
+                                        full_matrices)
+              _AddTest(QrGradOpTest, "QrGrad", name,
+                       _GetQrGradOpTest(dtype, shape, full_matrices))
  test.main()
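The stepsize in the test above, delta = 0.1 * epsilon**(1.0 / 3.0), follows from the standard error analysis of central differences: truncation error grows like O(delta**2) while floating-point roundoff grows like O(epsilon / delta), so the total error is minimized near delta ~ epsilon**(1/3). A standalone NumPy sketch, not part of the diff; the test function and evaluation point here are arbitrary choices of mine:

    import numpy as np

    # Central-difference error for f(x) = sin(x) at x = 1.0, whose exact
    # derivative is cos(1.0). Truncation error ~ delta**2 / 6; roundoff
    # error ~ eps / delta; their sum bottoms out near delta ~ eps**(1/3).
    f, x, exact = np.sin, 1.0, np.cos(1.0)
    eps = np.finfo(np.float64).eps

    for p in (0.25, 1.0 / 3.0, 0.5):
      delta = eps**p
      approx = (f(x + delta) - f(x - delta)) / (2.0 * delta)
      print("delta = eps**%.2f = %.1e  error = %.1e" %
            (p, delta, abs(approx - exact)))
    # eps**(1/3) (~6e-6 for float64) gives the smallest error of the three,
    # matching the 0.1 * epsilon**(1.0 / 3.0) scaling used in the test.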
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index ec263591e1..8a76fe3ce5 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -81,6 +81,36 @@ def _CholeskyGrad(op, grad):
  return grad_a * 0.5


+@ops.RegisterGradient("Qr")
+def _QrGrad(op, dq, dr):
+  """Gradient for Qr."""
+  q, r = op.outputs
+  if q.dtype.is_complex:
+    raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype)
+  if (r.shape.ndims is None or r.shape.as_list()[-2] is None or
+      r.shape.as_list()[-1] is None):
+    raise NotImplementedError("QrGrad not implemented with dynamic shapes.")
+  if r.shape[-2].value != r.shape[-1].value:
+    raise NotImplementedError("QrGrad not implemented when ncols > nrows "
+                              "or full_matrices is true and ncols != nrows.")
+
+  qdq = math_ops.matmul(q, dq, adjoint_a=True)
+  qdq_ = qdq - _linalg.adjoint(qdq)
+  rdr = math_ops.matmul(r, dr, adjoint_b=True)
+  rdr_ = rdr - _linalg.adjoint(rdr)
+  tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0)
+
+  def _TriangularSolve(x, r):
+    """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri."""
+    return _linalg.adjoint(
+        linalg_ops.matrix_triangular_solve(
+            r, _linalg.adjoint(x), lower=False, adjoint=False))
+
+  grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r))
+  grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r)
+  return grad_a + grad_b
+
+
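The identity implemented above can be sanity-checked outside TensorFlow. Below is a standalone NumPy sketch, not part of the diff (the helper names qr_grad, tsolve, and loss are mine), that mirrors _QrGrad for a single square real matrix and compares it against central differences of an arbitrary scalar loss sum(wq * Q) + sum(wr * R), whose upstream gradients are exactly dq = wq, dr = wr:

    import numpy as np

    def qr_grad(q, r, dq, dr):
      """NumPy mirror of _QrGrad for one square real matrix."""
      qdq = q.T @ dq
      qdq_ = qdq - qdq.T
      rdr = r @ dr.T
      rdr_ = rdr - rdr.T
      tril = np.tril(qdq_ + rdr_)
      # Mirrors _TriangularSolve: x @ inv(r).T with r upper-triangular.
      tsolve = lambda x: np.linalg.solve(r, x.T).T
      return q @ (dr + tsolve(tril)) + tsolve(dq - q @ qdq)

    np.random.seed(42)
    n = 5
    a = np.random.uniform(-1.0, 1.0, size=(n, n))
    wq, wr = np.random.randn(n, n), np.random.randn(n, n)  # upstream grads

    q, r = np.linalg.qr(a, mode="reduced")
    analytic = qr_grad(q, r, wq, wr)

    # Central differences of the scalar loss, with the same O(eps**(1/3))
    # stepsize used in the new test.
    def loss(x):
      qx, rx = np.linalg.qr(x, mode="reduced")
      return np.sum(wq * qx) + np.sum(wr * rx)

    delta = 0.1 * np.finfo(np.float64).eps**(1.0 / 3.0)
    numeric = np.zeros_like(a)
    for i in range(n):
      for j in range(n):
        e = np.zeros_like(a)
        e[i, j] = delta
        numeric[i, j] = (loss(a + e) - loss(a - e)) / (2.0 * delta)

    print(np.max(np.abs(analytic - numeric)))  # small, roughly 1e-9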
@ops.RegisterGradient("MatrixSolve")
def _MatrixSolveGrad(op, grad):
"""Gradient for MatrixSolve."""
@@ -105,7 +135,7 @@ def _MatrixSolveLsGrad(op, grad):
# b) Implement a symmetric rank-k update op instead of computing
# x*z + transpose(x*z). This pattern occurs other places in TensorFlow.
- def _overdetermined(op, grad):
+ def _Overdetermined(op, grad):
"""Gradients for the overdetermined case of MatrixSolveLs.
This is the backprop for the solution to the normal equations of the first
@@ -130,7 +160,7 @@ def _MatrixSolveLsGrad(op, grad):
grad_b = math_ops.matmul(a, z)
return (grad_a, grad_b, None)
- def _underdetermined(op, grad):
+ def _Underdetermined(op, grad):
"""Gradients for the underdetermined case of MatrixSolveLs.
This is the backprop for the solution to the normal equations of the second
@@ -162,16 +192,16 @@ def _MatrixSolveLsGrad(op, grad):
matrix_shape = op.inputs[0].get_shape()[-2:]
if matrix_shape.is_fully_defined():
if matrix_shape[-2] >= matrix_shape[-1]:
- return _overdetermined(op, grad)
+ return _Overdetermined(op, grad)
else:
- return _underdetermined(op, grad)
+ return _Underdetermined(op, grad)
else:
# We have to defer determining the shape to runtime and use
# conditional execution of the appropriate graph.
matrix_shape = array_ops.shape(op.inputs[0])[-2:]
return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1],
- lambda: _overdetermined(op, grad),
- lambda: _underdetermined(op, grad))
+ lambda: _Overdetermined(op, grad),
+ lambda: _Underdetermined(op, grad))
@ops.RegisterGradient("MatrixTriangularSolve")