aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_fused.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_fused.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops_fused.cc12
1 files changed, 8 insertions, 4 deletions
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc
index 1b40ad81f4..972100ba77 100644
--- a/tensorflow/core/kernels/conv_ops_fused.cc
+++ b/tensorflow/core/kernels/conv_ops_fused.cc
@@ -195,7 +195,7 @@ EIGEN_ALWAYS_INLINE PerCacheLineParameters<T1> CalculatePerCacheLineParameters(
const int64 bottom_y_index =
std::min(static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
// Lerp is used for bilinear filtering when that's needed.
- result.y_lerp = in_y - top_y_index;
+ result.y_lerp = static_cast<T1>(in_y - top_y_index);
// Which rows of the original input image to pull the values from.
result.input_top_row_start =
input_batch_start + (top_y_index * input_width * input_depth);
@@ -245,7 +245,7 @@ CalculatePerCachePixelParameters(int64 cache_x, int64 cache_start_x,
result.right_x_index =
std::min(static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
// This x_lerp is used to blend pixels in bilinear filtering.
- result.x_lerp = in_x - result.left_x_index;
+ result.x_lerp = static_cast<T1>(in_x - result.left_x_index);
return result;
}
@@ -465,8 +465,8 @@ class FusedResizeAndPadConvFunctor {
// for that operation are always present.
// Work out the parameters that remain constant across the
// row we're calculating.
- PerCacheLineParameters<float> line_params(
- CalculatePerCacheLineParameters<float>(
+ PerCacheLineParameters<T1> line_params(
+ CalculatePerCacheLineParameters<T1>(
task_params.cache_height, cache_y,
task_params.resize_cache,
task_params.cache_line_width, task_params.input_width,
@@ -881,7 +881,9 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
BILINEAR>, \
true>);
+TF_CALL_half(REGISTER_FUSED);
TF_CALL_float(REGISTER_FUSED);
+TF_CALL_double(REGISTER_FUSED);
#define REGISTER_PAD_ONLY_FUSED(T) \
REGISTER_KERNEL_BUILDER( \
@@ -892,6 +894,8 @@ TF_CALL_float(REGISTER_FUSED);
NEAREST>, \
false>);
+TF_CALL_half(REGISTER_PAD_ONLY_FUSED);
TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
+TF_CALL_double(REGISTER_PAD_ONLY_FUSED);
} // namespace tensorflow