diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_fused.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_fused.cc | 118 |
1 files changed, 84 insertions, 34 deletions
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc index a041c6e9f8..697ee5d25a 100644 --- a/tensorflow/core/kernels/conv_ops_fused.cc +++ b/tensorflow/core/kernels/conv_ops_fused.cc @@ -46,9 +46,16 @@ namespace { // In this case, we've picked 16 megabytes as a reasonable limit. const size_t kMaxChunkSize = (16 * 1024 * 1024); +// Lookup method used when resizing. +enum SamplingMode { + BILINEAR = 0, + NEAREST = 1, +}; + // Combines bilinear resizing and mirror padding into the im2col transformation -// stage of convolution, -template <class T1, class T2, class T3, class TGemmFunctor> +// stage of convolution. +template <class T1, class T2, class T3, class TGemmFunctor, + SamplingMode SampleMode> class FusedResizeAndPadConvFunctor { public: void operator()(OpKernelContext* context, const Tensor& input, @@ -78,6 +85,9 @@ class FusedResizeAndPadConvFunctor { << output_width << ", " << output_height; return; } + OP_REQUIRES( + context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)), + errors::InvalidArgument("Bad sample mode passed in", SampleMode)); // These calculations define how the patches will be positioned within the // input image. The actual definitions are quite complex, and rely on the @@ -183,18 +193,24 @@ class FusedResizeAndPadConvFunctor { T1 in_value; if ((conv_in_x >= 0) && (conv_in_x < padded_width) && (conv_in_y >= 0) && (conv_in_y < padded_height)) { - const T1 top_left( - input_data(batch, top_y_index, left_x_index, in_channel)); - const T1 top_right(input_data(batch, top_y_index, - right_x_index, in_channel)); - const T1 bottom_left(input_data(batch, bottom_y_index, - left_x_index, in_channel)); - const T1 bottom_right(input_data(batch, bottom_y_index, - right_x_index, in_channel)); - const T1 top = top_left + (top_right - top_left) * x_lerp; - const T1 bottom = - bottom_left + (bottom_right - bottom_left) * x_lerp; - in_value = top + (bottom - top) * y_lerp; + if (SampleMode == NEAREST) { + const T1 top_left(input_data(batch, top_y_index, + left_x_index, in_channel)); + in_value = top_left; + } else if (SampleMode == BILINEAR) { + const T1 top_left(input_data(batch, top_y_index, + left_x_index, in_channel)); + const T1 top_right(input_data(batch, top_y_index, + right_x_index, in_channel)); + const T1 bottom_left(input_data(batch, bottom_y_index, + left_x_index, in_channel)); + const T1 bottom_right(input_data( + batch, bottom_y_index, right_x_index, in_channel)); + const T1 top = top_left + (top_right - top_left) * x_lerp; + const T1 bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + in_value = top + (bottom - top) * y_lerp; + } } else { in_value = T1(0); } @@ -208,8 +224,8 @@ class FusedResizeAndPadConvFunctor { ((batch == (input_batches - 1)) && (out_y == (output_height - 1)) && (out_x == (output_width - 1))); if (is_last_in_chunk || is_last_overall) { - // Now we've assembled a set of image patches into a matrix, apply a - // GEMM matrix multiply of the patches as rows, times the filter + // Now we've assembled a set of image patches into a matrix, apply + // a GEMM matrix multiply of the patches as rows, times the filter // weights in columns, to get partial results in the output matrix. const int how_many_patches = patch_index_within_chunk + 1; const int m = how_many_patches; @@ -236,13 +252,15 @@ class FusedResizeAndPadConvFunctor { // Implements a version of convolution with bilinear resizing and mirror padding // included. -template <class T, class TConvFunctor> +template <class T, class TConvFunctor, bool DoResize> class FusedResizeConv2DUsingGemmOp : public OpKernel { public: explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context) : OpKernel(context) { - OP_REQUIRES_OK(context, - context->GetAttr("resize_align_corners", &align_corners_)); + if (DoResize) { + OP_REQUIRES_OK(context, + context->GetAttr("resize_align_corners", &align_corners_)); + } MirrorPadMode mode; OP_REQUIRES_OK(context, context->GetAttr("mode", &mode)); @@ -280,13 +298,34 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { OP_REQUIRES(context, (input.shape().num_elements() > 0), errors::InvalidArgument("Input tensor can't be empty")); - ImageResizerState st(align_corners_); - st.ValidateAndCalculateOutputSize(context, input); - if (!context->status().ok()) return; - const TensorShape resized_shape( + ImageResizerState st(false); + if (DoResize) { + st = ImageResizerState(align_corners_); + st.ValidateAndCalculateOutputSize(context, input); + if (!context->status().ok()) return; + } else { + // Set up the resize parameters to do no scaling at all. + st.batch_size = input.dim_size(0); + st.out_height = input.dim_size(1); + st.out_width = input.dim_size(2); + st.in_height = input.dim_size(1); + st.in_width = input.dim_size(2); + st.channels = input.dim_size(3); + st.height_scale = 1.0f; + st.width_scale = 1.0f; + } + TensorShape resized_shape( {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)}); - - const Tensor& paddings = context->input(2); + int paddings_index; + int filter_index; + if (DoResize) { + paddings_index = 2; + filter_index = 3; + } else { + paddings_index = 1; + filter_index = 2; + } + const Tensor& paddings = context->input(paddings_index); const int dims = resized_shape.dims(); OP_REQUIRES( @@ -365,7 +404,7 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { // Input filter is of the following dimensions: // [ filter_rows, filter_cols, in_depth, out_depth] - const Tensor& filter = context->input(3); + const Tensor& filter = context->input(filter_index); // For 2D convolution, there should be 4 dimensions. OP_REQUIRES(context, padded_shape.dims() == 4, @@ -473,15 +512,26 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp); }; -#define REGISTER_FUSED(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("FusedResizeAndPadConv2D") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<T>("T"), \ - FusedResizeConv2DUsingGemmOp< \ - T, \ - FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>); +#define REGISTER_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedResizeAndPadConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + FusedResizeConv2DUsingGemmOp< \ + T, FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \ + BILINEAR>, \ + true>); TF_CALL_float(REGISTER_FUSED); +#define REGISTER_PAD_ONLY_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + FusedResizeConv2DUsingGemmOp< \ + T, FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \ + NEAREST>, \ + false>); + +TF_CALL_float(REGISTER_PAD_ONLY_FUSED); + } // namespace tensorflow |