aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cholesky_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-31 04:23:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-31 04:27:09 -0700
commit473a590c9cd26cdde1e77117778e3fd50a36d7df (patch)
tree61215c75a17e27998fd3d88611096ef93164fa1d /tensorflow/core/kernels/cholesky_op.cc
parent2d1860859a812437d5c20fa3bf75e6e989fbbb87 (diff)
Allow complex valued input for Cholesky decomposition.
PiperOrigin-RevId: 157572536
Diffstat (limited to 'tensorflow/core/kernels/cholesky_op.cc')
-rw-r--r--tensorflow/core/kernels/cholesky_op.cc13
1 files changed, 9 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 10595faf4b..5c7102f6f6 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -14,8 +14,7 @@ limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
-// TODO(konstantinos): Enable complex inputs. This will require additional tests
-// and OP_REQUIRES.
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
@@ -85,8 +84,10 @@ namespace functor {
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
-TF_CALL_float(DECLARE_GPU_SPEC);
-TF_CALL_double(DECLARE_GPU_SPEC);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_complex64(DECLARE_GPU_SPEC);
+TF_CALL_complex128(DECLARE_GPU_SPEC);
+
} // namespace functor
template <class Scalar>
@@ -171,11 +172,15 @@ class CholeskyOpGpu : public AsyncOpKernel {
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
+REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64);
+REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128);
#endif // GOOGLE_CUDA
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64);
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);