diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-05-31 04:23:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-05-31 04:27:09 -0700 |
commit | 473a590c9cd26cdde1e77117778e3fd50a36d7df (patch) | |
tree | 61215c75a17e27998fd3d88611096ef93164fa1d /tensorflow/core/kernels/cholesky_op.cc | |
parent | 2d1860859a812437d5c20fa3bf75e6e989fbbb87 (diff) |
Allow complex valued input for Cholesky decomposition.
PiperOrigin-RevId: 157572536
Diffstat (limited to 'tensorflow/core/kernels/cholesky_op.cc')
-rw-r--r-- | tensorflow/core/kernels/cholesky_op.cc | 13 |
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc index 10595faf4b..5c7102f6f6 100644 --- a/tensorflow/core/kernels/cholesky_op.cc +++ b/tensorflow/core/kernels/cholesky_op.cc @@ -14,8 +14,7 @@ limitations under the License. ==============================================================================*/ // See docs in ../ops/linalg_ops.cc. -// TODO(konstantinos): Enable complex inputs. This will require additional tests -// and OP_REQUIRES. + #if GOOGLE_CUDA #define EIGEN_USE_GPU #endif // GOOGLE_CUDA @@ -85,8 +84,10 @@ namespace functor { typename TTypes<T, 3>::Tensor output); \ extern template struct MatrixBandPart<GPUDevice, T>; -TF_CALL_float(DECLARE_GPU_SPEC); -TF_CALL_double(DECLARE_GPU_SPEC); +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +TF_CALL_complex64(DECLARE_GPU_SPEC); +TF_CALL_complex128(DECLARE_GPU_SPEC); + } // namespace functor template <class Scalar> @@ -171,11 +172,15 @@ class CholeskyOpGpu : public AsyncOpKernel { REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float); REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double); +REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64); +REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128); #endif // GOOGLE_CUDA REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float); REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double); +REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64); +REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128); REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float); REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double); |