diff options
Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op.cc')
-rw-r--r-- | tensorflow/core/kernels/depthwise_conv_op.cc | 10 |
1 file changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index bbeeaf7895..2759ecb2f1 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -94,7 +94,7 @@ struct DepthwiseConv2DKernel {
     for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
       // Reset accumulator.
-      auto vaccum = Eigen::internal::pset1<Packet>(0);
+      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
       for (int j = 0; j < filter_spatial_size; ++j) {
         // Calculate index.
         const int64 index = i + j * padded_filter_inner_dim_size;
@@ -115,7 +115,7 @@ struct DepthwiseConv2DKernel {
     }
 
     if (output_scalar_size > 0) {
-      auto vaccum = Eigen::internal::pset1<Packet>(0);
+      auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
       for (int j = 0; j < filter_spatial_size; ++j) {
         const int64 index =
             output_vectorized_size + j * padded_filter_inner_dim_size;
@@ -246,6 +246,7 @@ extern template class LaunchConv2DOp<CPUDevice, float>;
 #if GOOGLE_CUDA
 
 // Extern template instantiated in depthwise_conv_op_gpu.cc.
+extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
 extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
 
@@ -419,6 +420,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
       Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
       DepthwiseConv2dNativeOp<CPUDevice, T>);
 
+TF_CALL_half(REGISTER_CPU_KERNEL);
 TF_CALL_float(REGISTER_CPU_KERNEL);
 #if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
 TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -426,6 +428,10 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
 
 #if GOOGLE_CUDA
 REGISTER_KERNEL_BUILDER(
+    Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+    DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
+
+REGISTER_KERNEL_BUILDER(
     Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
     DepthwiseConv2dNativeOp<GPUDevice, float>);
 