diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_3d.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_3d.cc | 5 |
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index 8a89d564de..37cb67bc51 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -145,6 +145,7 @@ class Conv3DOp : public BinaryOp<T> { REGISTER_KERNEL_BUILDER( \ Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ Conv3DOp<CPUDevice, T>); +TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL @@ -482,6 +483,7 @@ namespace functor { const std::array<int, 3>& padding_right, \ typename TTypes<T, 5, int>::Tensor out, TensorFormat format); +DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC @@ -489,6 +491,9 @@ DECLARE_GPU_SPEC(float); // Registration of the GPU implementations. REGISTER_KERNEL_BUILDER( + Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), + Conv3DOp<GPUDevice, Eigen::half>); +REGISTER_KERNEL_BUILDER( Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"), Conv3DOp<GPUDevice, float>); #endif // GOOGLE_CUDA |