about summary refs log tree commit diff homepage
path: root/tensorflow/core/kernels/depthwise_conv_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/depthwise_conv_op.cc')
-rw-r--r-- tensorflow/core/kernels/depthwise_conv_op.cc 10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index bbeeaf7895..2759ecb2f1 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -94,7 +94,7 @@ struct DepthwiseConv2DKernel {
for (int i = 0; i < output_vectorized_size; i += kPacketSize) {
// Reset accumulator.
- auto vaccum = Eigen::internal::pset1<Packet>(0);
+ auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
for (int j = 0; j < filter_spatial_size; ++j) {
// Calculate index.
const int64 index = i + j * padded_filter_inner_dim_size;
@@ -115,7 +115,7 @@ struct DepthwiseConv2DKernel {
}
if (output_scalar_size > 0) {
- auto vaccum = Eigen::internal::pset1<Packet>(0);
+ auto vaccum = Eigen::internal::pset1<Packet>(static_cast<T>(0));
for (int j = 0; j < filter_spatial_size; ++j) {
const int64 index =
output_vectorized_size + j * padded_filter_inner_dim_size;
@@ -246,6 +246,7 @@ extern template class LaunchConv2DOp<CPUDevice, float>;
#if GOOGLE_CUDA
// Extern template instantiated in depthwise_conv_op_gpu.cc.
+extern template struct LaunchDepthwiseConvOp<GPUDevice, Eigen::half>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, float>;
extern template struct LaunchDepthwiseConvOp<GPUDevice, double>;
@@ -419,6 +420,7 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
Name("DepthwiseConv2dNative").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
DepthwiseConv2dNativeOp<CPUDevice, T>);
+TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
#if !defined(PLATFORM_WINDOWS) || !defined(_DEBUG)
TF_CALL_double(REGISTER_CPU_KERNEL);
@@ -426,6 +428,10 @@ TF_CALL_double(REGISTER_CPU_KERNEL);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
+ Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+ DepthwiseConv2dNativeOp<GPUDevice, Eigen::half>);
+
+REGISTER_KERNEL_BUILDER(
Name("DepthwiseConv2dNative").Device(DEVICE_GPU).TypeConstraint<float>("T"),
DepthwiseConv2dNativeOp<GPUDevice, float>);