diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_fused.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_fused.cc | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc index 1b40ad81f4..972100ba77 100644 --- a/tensorflow/core/kernels/conv_ops_fused.cc +++ b/tensorflow/core/kernels/conv_ops_fused.cc @@ -195,7 +195,7 @@ EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters( const int64 bottom_y_index = std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1)); // Lerp is used for bilinear filtering when that's needed. - result.y_lerp = in_y - top_y_index; + result.y_lerp = static_cast<T1>(in_y - top_y_index); // Which rows of the original input image to pull the values from. result.input_top_row_start = input_batch_start + (top_y_index * input_width * input_depth); @@ -245,7 +245,7 @@ CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x, result.right_x_index = std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1)); // This x_lerp is used to blend pixels in bilinear filtering. - result.x_lerp = in_x - result.left_x_index; + result.x_lerp = static_cast<T1>(in_x - result.left_x_index); return result; } @@ -465,8 +465,8 @@ class FusedResizeAndPadConvFunctor { // for that operation are always present. // Work out the parameters that remain constant across the // row we're calculating. 
- PerCacheLineParameters<float> line_params( - CalculatePerCacheLineParameters<float>( + PerCacheLineParameters<T1> line_params( + CalculatePerCacheLineParameters<T1>( task_params.cache_height, cache_y, task_params.resize_cache, task_params.cache_line_width, task_params.input_width, @@ -881,7 +881,9 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { BILINEAR>, \ true>); +TF_CALL_half(REGISTER_FUSED); TF_CALL_float(REGISTER_FUSED); +TF_CALL_double(REGISTER_FUSED); #define REGISTER_PAD_ONLY_FUSED(T) \ REGISTER_KERNEL_BUILDER( \ @@ -892,6 +894,8 @@ TF_CALL_float(REGISTER_FUSED); NEAREST>, \ false>); +TF_CALL_half(REGISTER_PAD_ONLY_FUSED); TF_CALL_float(REGISTER_PAD_ONLY_FUSED); +TF_CALL_double(REGISTER_PAD_ONLY_FUSED); } // namespace tensorflow |