diff options
author | 2017-09-22 11:34:01 -0700 | |
---|---|---|
committer | 2017-09-22 11:38:04 -0700 | |
commit | e3413de529c3f762885efd62932f76445ed22653 (patch) | |
tree | d8e9dab8736bb7420c6a161e9795f8299acf0f7e /tensorflow/core/kernels/qr_op_impl.h | |
parent | f927a72031ce563d65cf8864fe142ceb173444f5 (diff) |
Add GPU support for self_adjoint_eig a.k.a. tf.linalg.eigh.
Clean up macros and template specializations in cuda_solvers.cc a bit.
PiperOrigin-RevId: 169715681
Diffstat (limited to 'tensorflow/core/kernels/qr_op_impl.h')
-rw-r--r-- | tensorflow/core/kernels/qr_op_impl.h | 8 |
1 file changed, 4 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h index 431b083eef..b9843428a5 100644 --- a/tensorflow/core/kernels/qr_op_impl.h +++ b/tensorflow/core/kernels/qr_op_impl.h @@ -248,12 +248,12 @@ class QrOpGpu : public AsyncOpKernel { auto q_reshaped = q->flat_inner_dims<Scalar, 3>(); eye(device, q_reshaped); for (int batch = 0; batch < batch_size; ++batch) { - // Notice: It appears that Ormqr does not write a zero into *info upon + // Notice: It appears that Unmqr does not write a zero into *info upon // success (probably a bug), so we simply re-use the info array already // zeroed by Geqrf above. OP_REQUIRES_OK_ASYNC( context, - solver.Ormqr(CUBLAS_SIDE_LEFT, CublasAdjointOp<Scalar>(), m, m, + solver.Unmqr(CUBLAS_SIDE_LEFT, CublasAdjointOp<Scalar>(), m, m, min_size, &input_transposed_reshaped(batch, 0, 0), m, &tau_matrix(batch, 0), &q_reshaped(batch, 0, 0), m, dev_info.back().mutable_data() + batch), @@ -266,12 +266,12 @@ class QrOpGpu : public AsyncOpKernel { } } else { // Generate m x n matrix Q. In this case we can use the more efficient - // algorithm in Orgqr to generate Q in place. + // algorithm in Ungqr to generate Q in place. dev_info.emplace_back(context, batch_size, "orgqr"); for (int batch = 0; batch < batch_size; ++batch) { OP_REQUIRES_OK_ASYNC( context, - solver.Orgqr( + solver.Ungqr( m, n, min_size, &input_transposed_reshaped(batch, 0, 0), m, &tau_matrix(batch, 0), dev_info.back().mutable_data() + batch), done); |