aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/self_adjoint_eig_v2_op.cc')
-rw-r--r--tensorflow/core/kernels/self_adjoint_eig_v2_op.cc10
1 files changed, 9 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
index c647d3aaac..7a1db4e558 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op.cc
@@ -69,7 +69,7 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
errors::InvalidArgument("Self Adjoint Eigen decomposition was not "
"successful. The input might not be valid."));
- outputs->at(0) = eig.eigenvalues();
+ outputs->at(0) = eig.eigenvalues().template cast<Scalar>();
if (compute_v_) {
outputs->at(1) = eig.eigenvectors();
}
@@ -81,7 +81,15 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<double>), double);
+REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
+ complex64);
+REGISTER_LINALG_OP("SelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
+ complex128);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<float>), float);
REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<double>),
double);
+REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex64>),
+ complex64);
+REGISTER_LINALG_OP("BatchSelfAdjointEigV2", (SelfAdjointEigV2Op<complex128>),
+ complex128);
} // namespace tensorflow