aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/matrix_diag_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/matrix_diag_op.cc')
-rw-r--r--tensorflow/core/kernels/matrix_diag_op.cc6
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc
index bc193357ad..75c49baaa8 100644
--- a/tensorflow/core/kernels/matrix_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_diag_op.cc
@@ -123,7 +123,7 @@ class MatrixDiagOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>);
-TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG);
+TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG);
#undef REGISTER_MATRIX_DIAG
// Registration of the deprecated kernel.
@@ -136,7 +136,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG);
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T"), \
MatrixDiagPartOp<CPUDevice, type>);
-TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG);
+TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG);
#undef REGISTER_BATCH_MATRIX_DIAG
// Implementation of the functor specialization for CPU.
@@ -187,6 +187,7 @@ namespace functor {
extern template struct MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_bool(DECLARE_GPU_SPEC);
TF_CALL_complex64(DECLARE_GPU_SPEC);
TF_CALL_complex128(DECLARE_GPU_SPEC);
@@ -201,6 +202,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC);
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
+TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU