diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_ops_3d.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_grad_ops_3d.cc | 42 |
1 files changed, 26 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 21f5cb1716..f819fccbfb 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -236,6 +236,7 @@ class Conv3DBackpropInputOp : public OpKernel { REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ Conv3DBackpropInputOp<CPUDevice, T>); +TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL @@ -383,6 +384,7 @@ class Conv3DBackpropFilterOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint<T>("T"), \ Conv3DBackpropFilterOp<CPUDevice, T>); +TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL @@ -409,6 +411,7 @@ namespace functor { const std::array<int, 3>& padding_right, \ typename TTypes<T, 5, int>::Tensor out, TensorFormat format); +DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC } // namespace functor @@ -1098,22 +1101,29 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { bool cudnn_use_autotune_; }; -REGISTER_KERNEL_BUILDER( - Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<float>("T"), - Conv3DBackpropInputOp<GPUDevice, float>); -REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") - .Device(DEVICE_GPU) - .TypeConstraint<float>("T") - .HostMemory("input_sizes"), - Conv3DBackpropInputOp<GPUDevice, float>); -REGISTER_KERNEL_BUILDER( - Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<float>("T"), - Conv3DBackpropFilterOp<GPUDevice, float>); -REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") - .Device(DEVICE_GPU) - .TypeConstraint<float>("T") - .HostMemory("filter_sizes"), - Conv3DBackpropFilterOp<GPUDevice, float>); + + +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("input_sizes"), \ + Conv3DBackpropInputOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<GPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<T>("T") \ + .HostMemory("filter_sizes"), \ + Conv3DBackpropFilterOp<GPUDevice, T>); +TF_CALL_half(REGISTER_GPU_KERNEL); +TF_CALL_float(REGISTER_GPU_KERNEL); +#undef REGISTER_GPU_KERNEL + #endif // GOOGLE_CUDA } // namespace tensorflow |