# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Gradients for operators defined in linalg_ops.py. Useful reference for derivative formulas is An extended collection of matrix derivative results for forward and reverse mode algorithmic differentiation by Mike Giles: http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf A detailed derivation of formulas for backpropagating through spectral layers (SVD and Eig) by Ionescu, Vantzos & Sminchisescu: https://arxiv.org/pdf/1509.07838v4.pdf """ from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linalg_impl as _linalg @ops.RegisterGradient("MatrixInverse") def _MatrixInverseGrad(op, grad): """Gradient for MatrixInverse.""" ainv = op.outputs[0] return -math_ops.matmul( ainv, math_ops.matmul(grad, ainv, adjoint_b=True), adjoint_a=True) @ops.RegisterGradient("MatrixDeterminant") def _MatrixDeterminantGrad(op, grad): """Gradient for MatrixDeterminant.""" a = op.inputs[0] c = op.outputs[0] a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True) multipliers = array_ops.reshape(grad * c, array_ops.concat([array_ops.shape(c), [1, 1]], 0)) return multipliers * a_adj_inv @ops.RegisterGradient("LogMatrixDeterminant") def _LogMatrixDeterminantGrad(op, _, grad_b): """Gradient for LogMatrixDeterminant.""" a = op.inputs[0] c = op.outputs[1] a_adj_inv = linalg_ops.matrix_inverse(a, adjoint=True) multipliers = array_ops.reshape( grad_b, array_ops.concat([array_ops.shape(c), [1, 1]], 0)) return multipliers * a_adj_inv @ops.RegisterGradient("Cholesky") def _CholeskyGrad(op, grad): """Gradient for Cholesky.""" # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1} l = op.outputs[0] num_rows = array_ops.shape(l)[-1] batch_shape = array_ops.shape(l)[:-2] l_inverse = linalg_ops.matrix_triangular_solve(l, linalg_ops.eye( num_rows, batch_shape=batch_shape, dtype=l.dtype)) middle = math_ops.matmul(l, grad, adjoint_a=True) middle = array_ops.matrix_set_diag(middle, 0.5 * array_ops.matrix_diag_part(middle)) middle = array_ops.matrix_band_part(middle, -1, 0) grad_a = math_ops.matmul( math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse) grad_a += _linalg.adjoint(grad_a) return grad_a * 0.5 @ops.RegisterGradient("Qr") def _QrGrad(op, dq, dr): """Gradient for Qr.""" q, r = op.outputs if q.dtype.is_complex: raise NotImplementedError("QrGrad not implemented for dtype: %s" % q.dtype) if (r.shape.ndims is None or r.shape.as_list()[-2] is None or r.shape.as_list()[-1] is None): raise NotImplementedError("QrGrad not implemented with dynamic shapes.") if r.shape[-2].value != r.shape[-1].value: raise NotImplementedError("QrGrad not implemented when ncols > nrows " "or full_matrices is true and ncols != nrows.") qdq = math_ops.matmul(q, dq, adjoint_a=True) qdq_ = qdq - _linalg.adjoint(qdq) rdr = math_ops.matmul(r, dr, adjoint_b=True) rdr_ = rdr - _linalg.adjoint(rdr) tril = array_ops.matrix_band_part(qdq_ + rdr_, -1, 0) def _TriangularSolve(x, r): """Equiv to matmul(x, adjoint(matrix_inverse(r))) if r is upper-tri.""" return _linalg.adjoint( linalg_ops.matrix_triangular_solve( r, _linalg.adjoint(x), lower=False, adjoint=False)) grad_a = math_ops.matmul(q, dr + _TriangularSolve(tril, r)) grad_b = _TriangularSolve(dq - math_ops.matmul(q, qdq), r) return grad_a + grad_b @ops.RegisterGradient("MatrixSolve") def _MatrixSolveGrad(op, grad): """Gradient for MatrixSolve.""" a = op.inputs[0] adjoint_a = op.get_attr("adjoint") c = op.outputs[0] grad_b = linalg_ops.matrix_solve(a, grad, adjoint=not adjoint_a) if adjoint_a: grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True) else: grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True) return (grad_a, grad_b) @ops.RegisterGradient("MatrixSolveLs") def _MatrixSolveLsGrad(op, grad): """Gradients for MatrixSolveLs.""" # TODO(rmlarsen): The implementation could be more efficient: # a) Output the Cholesky factorization from forward op instead of # recomputing it here. # b) Implement a symmetric rank-k update op instead of computing # x*z + transpose(x*z). This pattern occurs other places in TensorFlow. def _Overdetermined(op, grad): """Gradients for the overdetermined case of MatrixSolveLs. This is the backprop for the solution to the normal equations of the first kind: X = F(A, B) = (A^T * A + lambda * I)^{-1} * A^T * B which solve the least squares problem min ||A * X - B||_F^2 + lambda ||X||_F^2. """ a = op.inputs[0] b = op.inputs[1] x = op.outputs[0] l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype) # pylint: disable=protected-access chol = linalg_ops._RegularizedGramianCholesky( a, l2_regularizer=l2_regularizer, first_kind=True) # pylint: enable=protected-access # Temporary z = (A^T * A + lambda * I)^{-1} * grad. z = linalg_ops.cholesky_solve(chol, grad) xzt = math_ops.matmul(x, z, adjoint_b=True) zx_sym = xzt + array_ops.matrix_transpose(xzt) grad_a = -math_ops.matmul(a, zx_sym) + math_ops.matmul(b, z, adjoint_b=True) grad_b = math_ops.matmul(a, z) return (grad_a, grad_b, None) def _Underdetermined(op, grad): """Gradients for the underdetermined case of MatrixSolveLs. This is the backprop for the solution to the normal equations of the second kind: X = F(A, B) = A * (A*A^T + lambda*I)^{-1} * B that (for lambda=0) solve the least squares problem min ||X||_F subject to A*X = B. """ a = op.inputs[0] b = op.inputs[1] l2_regularizer = math_ops.cast(op.inputs[2], a.dtype.base_dtype) # pylint: disable=protected-access chol = linalg_ops._RegularizedGramianCholesky( a, l2_regularizer=l2_regularizer, first_kind=False) # pylint: enable=protected-access grad_b = linalg_ops.cholesky_solve(chol, math_ops.matmul(a, grad)) # Temporary tmp = (A * A^T + lambda * I)^{-1} * B. tmp = linalg_ops.cholesky_solve(chol, b) a1 = math_ops.matmul(tmp, a, adjoint_a=True) a1 = -math_ops.matmul(grad_b, a1) a2 = grad - math_ops.matmul(a, grad_b, adjoint_a=True) a2 = math_ops.matmul(tmp, a2, adjoint_b=True) grad_a = a1 + a2 return (grad_a, grad_b, None) fast = op.get_attr("fast") if fast is False: raise ValueError("Gradient not defined for fast=False") matrix_shape = op.inputs[0].get_shape()[-2:] if matrix_shape.is_fully_defined(): if matrix_shape[-2] >= matrix_shape[-1]: return _Overdetermined(op, grad) else: return _Underdetermined(op, grad) else: # We have to defer determining the shape to runtime and use # conditional execution of the appropriate graph. matrix_shape = array_ops.shape(op.inputs[0])[-2:] return control_flow_ops.cond(matrix_shape[-2] >= matrix_shape[-1], lambda: _Overdetermined(op, grad), lambda: _Underdetermined(op, grad)) @ops.RegisterGradient("MatrixTriangularSolve") def _MatrixTriangularSolveGrad(op, grad): """Gradient for MatrixTriangularSolve.""" a = op.inputs[0] adjoint_a = op.get_attr("adjoint") lower_a = op.get_attr("lower") c = op.outputs[0] grad_b = linalg_ops.matrix_triangular_solve( a, grad, lower=lower_a, adjoint=not adjoint_a) if adjoint_a: grad_a = -math_ops.matmul(c, grad_b, adjoint_b=True) else: grad_a = -math_ops.matmul(grad_b, c, adjoint_b=True) if lower_a: grad_a = array_ops.matrix_band_part(grad_a, -1, 0) else: grad_a = array_ops.matrix_band_part(grad_a, 0, -1) return (grad_a, grad_b) @ops.RegisterGradient("SelfAdjointEigV2") def _SelfAdjointEigV2Grad(op, grad_e, grad_v): """Gradient for SelfAdjointEigV2.""" e = op.outputs[0] compute_v = op.get_attr("compute_v") # a = op.inputs[0], which satisfies # a[...,:,:] * v[...,:,i] = e[...,i] * v[...,i] with ops.control_dependencies([grad_e, grad_v]): if compute_v: v = op.outputs[1] # Construct the matrix f(i,j) = (i != j ? 1 / (e_i - e_j) : 0). # Notice that because of the term involving f, the gradient becomes # infinite (or NaN in practice) when eigenvalues are not unique. # Mathematically this should not be surprising, since for (k-fold) # degenerate eigenvalues, the corresponding eigenvectors are only defined # up to arbitrary rotation in a (k-dimensional) subspace. f = array_ops.matrix_set_diag( math_ops.reciprocal( array_ops.expand_dims(e, -2) - array_ops.expand_dims(e, -1)), array_ops.zeros_like(e)) grad_a = math_ops.matmul( v, math_ops.matmul( array_ops.matrix_diag(grad_e) + f * math_ops.matmul(v, grad_v, adjoint_a=True), v, adjoint_b=True)) else: _, v = linalg_ops.self_adjoint_eig(op.inputs[0]) grad_a = math_ops.matmul(v, math_ops.matmul( array_ops.matrix_diag(grad_e), v, adjoint_b=True)) # The forward op only depends on the lower triangular part of a, so here we # symmetrize and take the lower triangle grad_a = array_ops.matrix_band_part(grad_a + _linalg.adjoint(grad_a), -1, 0) grad_a = array_ops.matrix_set_diag(grad_a, 0.5 * array_ops.matrix_diag_part(grad_a)) return grad_a @ops.RegisterGradient("Svd") def _SvdGrad(op, grad_s, grad_u, grad_v): """Gradient for the singular value decomposition.""" # The derivation for the compute_uv=False case, and most of # the derivation for the full_matrices=True case, are in # Giles' paper (see reference at top of file). A derivation for # the full_matrices=False case is available at # https://j-towns.github.io/papers/svd-derivative.pdf a = op.inputs[0] a_shape = a.get_shape().with_rank_at_least(2) grad_s_mat = array_ops.matrix_diag(grad_s) if not op.get_attr("compute_uv"): s, u, v = linalg_ops.svd(a, compute_uv=True) grad_a = math_ops.matmul(u, math_ops.matmul(grad_s_mat, v, adjoint_b=True)) grad_a.set_shape(a_shape) return grad_a full_matrices = op.get_attr("full_matrices") # TODO(rmlarsen): Make this work with complex types. if a.dtype.is_complex: raise NotImplementedError( "SVD gradient is not implemented for complex types and " "compute_uv=True.") grad_u_shape = grad_u.get_shape().with_rank_at_least(2) grad_v_shape = grad_v.get_shape().with_rank_at_least(2) m = a_shape[-2].merge_with(grad_u_shape[-2]) n = a_shape[-1].merge_with(grad_v_shape[-2]) batch_shape = a_shape[:-2].merge_with(grad_u_shape[:-2]).merge_with( grad_v_shape[:-2]) a_shape = batch_shape.concatenate([m, n]) m = a_shape[-2].value n = a_shape[-1].value # TODO(rmlarsen): Make this work with placeholders. if m is None or n is None: raise NotImplementedError( "SVD gradient has not been implemented for input with unknown " "inner matrix shape.") s = op.outputs[0] u = op.outputs[1] v = op.outputs[2] use_adjoint = False if m > n: # Compute the gradient for A^H = V * S^T * U^H, and (implicitly) take the # Hermitian transpose of the gradient at the end. use_adjoint = True m, n = n, m u, v = v, u grad_u, grad_v = grad_v, grad_u with ops.control_dependencies([grad_s, grad_u, grad_v]): if full_matrices and abs(m - n) > 1: raise NotImplementedError( "svd gradient is not implemented for abs(m - n) > 1 " "when full_matrices is True") s_mat = array_ops.matrix_diag(s) s2 = math_ops.square(s) # NOTICE: Because of the term involving f, the gradient becomes # infinite (or NaN in practice) when singular values are not unique. # Mathematically this should not be surprising, since for (k-fold) # degenerate singular values, the corresponding singular vectors are # only defined up a (k-dimensional) subspace. In practice, this can # lead to numerical instability when singular values are close but not # exactly equal. f = array_ops.matrix_set_diag( math_ops.reciprocal( array_ops.expand_dims(s2, -2) - array_ops.expand_dims(s2, -1)), array_ops.zeros_like(s)) s_inv_mat = array_ops.matrix_diag(math_ops.reciprocal(s)) v1 = v[..., :, :m] grad_v1 = grad_v[..., :, :m] u_gu = math_ops.matmul(u, grad_u, adjoint_a=True) v_gv = math_ops.matmul(v1, grad_v1, adjoint_a=True) f_u = f * u_gu f_v = f * v_gv term1_nouv = ( grad_s_mat + math_ops.matmul(f_u + _linalg.adjoint(f_u), s_mat) + math_ops.matmul(s_mat, f_v + _linalg.adjoint(f_v))) term1 = math_ops.matmul(u, math_ops.matmul(term1_nouv, v1, adjoint_b=True)) if m == n: grad_a_before_transpose = term1 else: gv1t = array_ops.matrix_transpose(grad_v1) gv1t_v1 = math_ops.matmul(gv1t, v1) term2_nous = gv1t - math_ops.matmul(gv1t_v1, v1, adjoint_b=True) if full_matrices: v2 = v[..., :, m:n] grad_v2 = grad_v[..., :, m:n] v1t_gv2 = math_ops.matmul(v1, grad_v2, adjoint_a=True) term2_nous -= math_ops.matmul(v1t_gv2, v2, adjoint_b=True) u_s_inv = math_ops.matmul(u, s_inv_mat) term2 = math_ops.matmul(u_s_inv, term2_nous) grad_a_before_transpose = term1 + term2 if use_adjoint: grad_a = array_ops.matrix_transpose(grad_a_before_transpose) else: grad_a = grad_a_before_transpose grad_a.set_shape(a_shape) return grad_a