about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/conv_grad_ops_3d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_grad_ops_3d.cc')
-rw-r--r-- tensorflow/core/kernels/conv_grad_ops_3d.cc | 42
1 file changed, 26 insertions, 16 deletions
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 21f5cb1716..f819fccbfb 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -236,6 +236,7 @@ class Conv3DBackpropInputOp : public OpKernel {
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
Conv3DBackpropInputOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
@@ -383,6 +384,7 @@ class Conv3DBackpropFilterOp : public OpKernel {
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
@@ -409,6 +411,7 @@ namespace functor {
const std::array<int, 3>& padding_right, \
typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
+DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -1098,22 +1101,29 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
bool cudnn_use_autotune_;
};
-REGISTER_KERNEL_BUILDER(
- Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- Conv3DBackpropInputOp<GPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("input_sizes"),
- Conv3DBackpropInputOp<GPUDevice, float>);
-REGISTER_KERNEL_BUILDER(
- Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<float>("T"),
- Conv3DBackpropFilterOp<GPUDevice, float>);
-REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T")
- .HostMemory("filter_sizes"),
- Conv3DBackpropFilterOp<GPUDevice, float>);
+// Registers the GPU Conv3DBackpropInput[V2] and Conv3DBackpropFilter[V2] kernels
+// for type T; the V2 builders mark their "*_sizes" input as HostMemory.
+#define REGISTER_GPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropInput").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("input_sizes"), \
+ Conv3DBackpropInputOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<GPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<T>("T") \
+ .HostMemory("filter_sizes"), \
+ Conv3DBackpropFilterOp<GPUDevice, T>);
+TF_CALL_half(REGISTER_GPU_KERNEL);
+TF_CALL_float(REGISTER_GPU_KERNEL);
+#undef REGISTER_GPU_KERNEL
+
#endif // GOOGLE_CUDA
} // namespace tensorflow