aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_fused.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_fused.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops_fused.cc118
1 files changed, 84 insertions, 34 deletions
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc
index a041c6e9f8..697ee5d25a 100644
--- a/tensorflow/core/kernels/conv_ops_fused.cc
+++ b/tensorflow/core/kernels/conv_ops_fused.cc
@@ -46,9 +46,16 @@ namespace {
// In this case, we've picked 16 megabytes as a reasonable limit.
const size_t kMaxChunkSize = (16 * 1024 * 1024);
+// Lookup method used when resizing.
+enum SamplingMode {
+ BILINEAR = 0,
+ NEAREST = 1,
+};
+
// Combines bilinear resizing and mirror padding into the im2col transformation
-// stage of convolution,
-template <class T1, class T2, class T3, class TGemmFunctor>
+// stage of convolution.
+template <class T1, class T2, class T3, class TGemmFunctor,
+ SamplingMode SampleMode>
class FusedResizeAndPadConvFunctor {
public:
void operator()(OpKernelContext* context, const Tensor& input,
@@ -78,6 +85,9 @@ class FusedResizeAndPadConvFunctor {
<< output_width << ", " << output_height;
return;
}
+ OP_REQUIRES(
+ context, ((SampleMode == NEAREST) || (SampleMode == BILINEAR)),
+ errors::InvalidArgument("Bad sample mode passed in", SampleMode));
// These calculations define how the patches will be positioned within the
// input image. The actual definitions are quite complex, and rely on the
@@ -183,18 +193,24 @@ class FusedResizeAndPadConvFunctor {
T1 in_value;
if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
(conv_in_y >= 0) && (conv_in_y < padded_height)) {
- const T1 top_left(
- input_data(batch, top_y_index, left_x_index, in_channel));
- const T1 top_right(input_data(batch, top_y_index,
- right_x_index, in_channel));
- const T1 bottom_left(input_data(batch, bottom_y_index,
- left_x_index, in_channel));
- const T1 bottom_right(input_data(batch, bottom_y_index,
- right_x_index, in_channel));
- const T1 top = top_left + (top_right - top_left) * x_lerp;
- const T1 bottom =
- bottom_left + (bottom_right - bottom_left) * x_lerp;
- in_value = top + (bottom - top) * y_lerp;
+ if (SampleMode == NEAREST) {
+ const T1 top_left(input_data(batch, top_y_index,
+ left_x_index, in_channel));
+ in_value = top_left;
+ } else if (SampleMode == BILINEAR) {
+ const T1 top_left(input_data(batch, top_y_index,
+ left_x_index, in_channel));
+ const T1 top_right(input_data(batch, top_y_index,
+ right_x_index, in_channel));
+ const T1 bottom_left(input_data(batch, bottom_y_index,
+ left_x_index, in_channel));
+ const T1 bottom_right(input_data(
+ batch, bottom_y_index, right_x_index, in_channel));
+ const T1 top = top_left + (top_right - top_left) * x_lerp;
+ const T1 bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ in_value = top + (bottom - top) * y_lerp;
+ }
} else {
in_value = T1(0);
}
@@ -208,8 +224,8 @@ class FusedResizeAndPadConvFunctor {
((batch == (input_batches - 1)) &&
(out_y == (output_height - 1)) && (out_x == (output_width - 1)));
if (is_last_in_chunk || is_last_overall) {
- // Now we've assembled a set of image patches into a matrix, apply a
- // GEMM matrix multiply of the patches as rows, times the filter
+ // Now we've assembled a set of image patches into a matrix, apply
+ // a GEMM matrix multiply of the patches as rows, times the filter
// weights in columns, to get partial results in the output matrix.
const int how_many_patches = patch_index_within_chunk + 1;
const int m = how_many_patches;
@@ -236,13 +252,15 @@ class FusedResizeAndPadConvFunctor {
// Implements a version of convolution with bilinear resizing and mirror padding
// included.
-template <class T, class TConvFunctor>
+template <class T, class TConvFunctor, bool DoResize>
class FusedResizeConv2DUsingGemmOp : public OpKernel {
public:
explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
: OpKernel(context) {
- OP_REQUIRES_OK(context,
- context->GetAttr("resize_align_corners", &align_corners_));
+ if (DoResize) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("resize_align_corners", &align_corners_));
+ }
MirrorPadMode mode;
OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
@@ -280,13 +298,34 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
OP_REQUIRES(context, (input.shape().num_elements() > 0),
errors::InvalidArgument("Input tensor can't be empty"));
- ImageResizerState st(align_corners_);
- st.ValidateAndCalculateOutputSize(context, input);
- if (!context->status().ok()) return;
- const TensorShape resized_shape(
+ ImageResizerState st(false);
+ if (DoResize) {
+ st = ImageResizerState(align_corners_);
+ st.ValidateAndCalculateOutputSize(context, input);
+ if (!context->status().ok()) return;
+ } else {
+ // Set up the resize parameters to do no scaling at all.
+ st.batch_size = input.dim_size(0);
+ st.out_height = input.dim_size(1);
+ st.out_width = input.dim_size(2);
+ st.in_height = input.dim_size(1);
+ st.in_width = input.dim_size(2);
+ st.channels = input.dim_size(3);
+ st.height_scale = 1.0f;
+ st.width_scale = 1.0f;
+ }
+ TensorShape resized_shape(
{input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
-
- const Tensor& paddings = context->input(2);
+ int paddings_index;
+ int filter_index;
+ if (DoResize) {
+ paddings_index = 2;
+ filter_index = 3;
+ } else {
+ paddings_index = 1;
+ filter_index = 2;
+ }
+ const Tensor& paddings = context->input(paddings_index);
const int dims = resized_shape.dims();
OP_REQUIRES(
@@ -365,7 +404,7 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
- const Tensor& filter = context->input(3);
+ const Tensor& filter = context->input(filter_index);
// For 2D convolution, there should be 4 dimensions.
OP_REQUIRES(context, padded_shape.dims() == 4,
@@ -473,15 +512,26 @@ class FusedResizeConv2DUsingGemmOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
};
-#define REGISTER_FUSED(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("FusedResizeAndPadConv2D") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T"), \
- FusedResizeConv2DUsingGemmOp< \
- T, \
- FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
+#define REGISTER_FUSED(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FusedResizeAndPadConv2D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ FusedResizeConv2DUsingGemmOp< \
+ T, FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
+ BILINEAR>, \
+ true>);
TF_CALL_float(REGISTER_FUSED);
+#define REGISTER_PAD_ONLY_FUSED(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FusedPadConv2D").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ FusedResizeConv2DUsingGemmOp< \
+ T, FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>, \
+ NEAREST>, \
+ false>);
+
+TF_CALL_float(REGISTER_PAD_ONLY_FUSED);
+
} // namespace tensorflow