diff options
Diffstat (limited to 'tensorflow/core/kernels/self_adjoint_eig_v2_op.cc')
-rw-r--r-- | tensorflow/core/kernels/self_adjoint_eig_v2_op.cc | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc index c647d3aaac..7a1db4e558 100644 --- a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc +++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc @@ -69,7 +69,7 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> { errors::InvalidArgument("Self Adjoint Eigen decomposition was not " "successful. The input might not be valid.")); - outputs->at(0) = eig.eigenvalues(); + outputs->at(0) = eig.eigenvalues().template cast<Scalar>(); if (compute_v_) { outputs->at(1) = eig.eigenvectors(); } @@ -81,7 +81,15 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> { REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float); REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<double>), double); +REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>), + complex64); +REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>), + complex128); REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float); REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<double>), double); +REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>), + complex64); +REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>), + complex128); } // namespace tensorflow |