diff options
Diffstat (limited to 'tensorflow/core/kernels/matrix_diag_op.cc')
-rw-r--r-- | tensorflow/core/kernels/matrix_diag_op.cc | 6 |
1 files changed, 4 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc index bc193357ad..75c49baaa8 100644 --- a/tensorflow/core/kernels/matrix_diag_op.cc +++ b/tensorflow/core/kernels/matrix_diag_op.cc @@ -123,7 +123,7 @@ class MatrixDiagOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("MatrixDiagPart").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ MatrixDiagPartOp<CPUDevice, type>); -TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG); +TF_CALL_POD_TYPES(REGISTER_MATRIX_DIAG); #undef REGISTER_MATRIX_DIAG // Registration of the deprecated kernel. @@ -136,7 +136,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_MATRIX_DIAG); .Device(DEVICE_CPU) \ .TypeConstraint<type>("T"), \ MatrixDiagPartOp<CPUDevice, type>); -TF_CALL_NUMBER_TYPES(REGISTER_BATCH_MATRIX_DIAG); +TF_CALL_POD_TYPES(REGISTER_BATCH_MATRIX_DIAG); #undef REGISTER_BATCH_MATRIX_DIAG // Implementation of the functor specialization for CPU. @@ -187,6 +187,7 @@ namespace functor { extern template struct MatrixDiagPart<GPUDevice, T>; TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); +TF_CALL_bool(DECLARE_GPU_SPEC); TF_CALL_complex64(DECLARE_GPU_SPEC); TF_CALL_complex128(DECLARE_GPU_SPEC); @@ -201,6 +202,7 @@ TF_CALL_complex128(DECLARE_GPU_SPEC); Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ MatrixDiagPartOp<GPUDevice, type>); TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU); +TF_CALL_bool(REGISTER_MATRIX_DIAG_GPU); TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU); TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU); #undef REGISTER_MATRIX_DIAG_GPU |