aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_3d.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_3d.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc5
1 files changed, 5 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 8a89d564de..37cb67bc51 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -145,6 +145,7 @@ class Conv3DOp : public BinaryOp<T> {
REGISTER_KERNEL_BUILDER( \
Name("Conv3D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
Conv3DOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
@@ -482,6 +483,7 @@ namespace functor {
const std::array<int, 3>& padding_right, \
typename TTypes<T, 5, int>::Tensor out, TensorFormat format);
+DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
@@ -489,6 +491,9 @@ DECLARE_GPU_SPEC(float);
// Registration of the GPU implementations.
REGISTER_KERNEL_BUILDER(
+ Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+ Conv3DOp<GPUDevice, Eigen::half>);
+REGISTER_KERNEL_BUILDER(
Name("Conv3D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
Conv3DOp<GPUDevice, float>);
#endif // GOOGLE_CUDA