 tensorflow/contrib/makefile/tf_op_files.txt            |   1
 tensorflow/core/kernels/BUILD                          |  23
 tensorflow/core/kernels/conv_ops.h                     |  11
 tensorflow/core/kernels/conv_ops_fused.cc              | 486
 tensorflow/core/kernels/conv_ops_test.cc               | 240
 tensorflow/core/kernels/conv_ops_using_gemm.cc         |  91
 tensorflow/core/kernels/gemm_functors.h                | 105
 tensorflow/core/kernels/image_resizer_state.h          |  17
 tensorflow/core/kernels/ops_testutil.h                 |   1
 tensorflow/core/ops/nn_ops.cc                          |  41
 tensorflow/python/ops/nn_ops.py                        |  65
 tensorflow/python/tools/optimize_for_inference.py      |   2
 tensorflow/python/tools/optimize_for_inference_lib.py  | 142
 tensorflow/python/tools/optimize_for_inference_test.py |  69
14 files changed, 1190 insertions, 104 deletions
diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index bc9838ec74..0478adab2a 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -101,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc
tensorflow/core/kernels/cwise_op_add.cc
tensorflow/core/kernels/ctc_decoder_ops.cc
tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/conv_ops_fused.cc
tensorflow/core/kernels/conv_ops.cc
tensorflow/core/kernels/conv_grad_ops.cc
tensorflow/core/kernels/control_flow_ops.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index adbe554507..6a1967eaf5 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -492,6 +492,27 @@ tf_cc_test(
)
tf_cc_test(
+ name = "conv_ops_test",
+ size = "small",
+ deps = [
+ ":conv_ops",
+ ":image",
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+tf_cc_test(
name = "example_parsing_ops_test",
size = "large",
deps = [
@@ -1325,6 +1346,7 @@ tf_kernel_library(
hdrs = [
"conv_grad_ops.h",
"deep_conv2d.h",
+ "gemm_functors.h",
"winograd_transform.h",
],
prefix = "conv_ops",
@@ -1332,6 +1354,7 @@ tf_kernel_library(
":bounds_check",
":conv_2d",
":conv_3d",
+ ":image_resizer_state",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index d09db3dc15..858be520b0 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_CONV_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/util/tensor_format.h"
#if GOOGLE_CUDA
@@ -38,6 +39,16 @@ class LaunchConv2DOp {
TensorFormat data_format);
};
+// Used to keep track of persistent memory buffers used within the op.
+template <class T, size_t size>
+struct Im2ColBufferResource : public ResourceBase {
+ // This mutex ensures that only a single operation at a time is able to use
+ // the buffer memory held by this resource.
+ mutex mu;
+ T data[size];
+ string DebugString() { return "Im2ColBufferResource"; }
+};
+
#ifdef GOOGLE_CUDA
template <typename T>
class LaunchConv2DOp<Eigen::GpuDevice, T> {
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc
new file mode 100644
index 0000000000..865021405a
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_fused.cc
@@ -0,0 +1,486 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements convolution operations with other kernels baked into the
+// processing, to optimize latency and memory usage.
+
+#include <string.h>
+#include <map>
+#include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_ops.h"
+#include "tensorflow/core/kernels/gemm_functors.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Combines bilinear resizing and mirror padding into the im2col transformation
+// stage of convolution.
+template <class T1, class T2, class T3, class TGemmFunctor>
+class FusedResizeAndPadConvFunctor {
+ public:
+ void operator()(OpKernelContext* context, const Tensor& input,
+ int input_batches, int resized_height, int resized_width,
+ int padded_height, int padded_width, int input_depth,
+ const T2* filter_data, int filter_height, int filter_width,
+ int filter_count, int stride_rows, int stride_cols,
+ Padding padding, T3* output_data, int output_height,
+ int output_width, const ImageResizerState& st,
+ int top_padding, int bottom_padding, int left_padding,
+ int right_padding, int pad_offset) {
+ if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
+ (input_depth <= 0)) {
+ LOG(WARNING) << "Conv2D was called with bad input dimensions: "
+ << input_batches << ", " << padded_height << ", "
+ << padded_width << ", " << input_depth;
+ return;
+ }
+ if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
+ LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
+ << filter_width << ", " << filter_height << ", "
+ << filter_count;
+ return;
+ }
+ if ((output_width <= 0) || (output_height <= 0)) {
+ LOG(WARNING) << "Conv2D was called with bad output width or height: "
+ << output_width << ", " << output_height;
+ return;
+ }
+
+ // These calculations define how the patches will be positioned within the
+ // input image. The actual definitions are quite complex, and rely on the
+ // previously-calculated output size.
+ int filter_left_offset;
+ int filter_top_offset;
+ if (padding == VALID) {
+ filter_left_offset =
+ ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
+ 2;
+ filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
+ padded_height + 1) /
+ 2;
+ } else {
+ filter_left_offset =
+ ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
+ filter_top_offset =
+ ((output_height - 1) * stride_rows + filter_height - padded_height) /
+ 2;
+ }
+
+ // The im2col buffer has # of patches rows, and # of filters cols.
+ // It's laid out like this, in row major order in memory:
+ // < filter value count >
+ // ^ +---------------------+
+ // patch | |
+ // count | |
+ // v +---------------------+
+ // Each patch row contains a filter_width x filter_height patch of the
+ // input, with the depth channel as the most contiguous in memory, followed
+ // by the width, then the height. This is the standard memory order in the
+ // image world if it helps to visualize it.
+ const int filter_value_count = filter_width * filter_height * input_depth;
+
+ // We don't want to allocate a buffer to hold all the patches if the size is
+ // going to be extremely large, so break it into chunks if it's bigger than
+ // a limit. Each chunk will be processed serially, so we can refill the
+ // buffer for the next chunk and reuse it, keeping maximum memory size down.
+ // In this case, we've picked 16 megabytes as a reasonable limit.
+ const size_t max_chunk_size = (16 * 1024 * 1024);
+ OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size,
+ errors::InvalidArgument("Im2Col patch too large for buffer"));
+ const size_t patches_per_chunk =
+ max_chunk_size / (filter_value_count * sizeof(T1));
+ // Because memory allocation is very expensive on mobile platforms, try to
+ // allocate a persistent buffer that will be kept around between calls. We
+ // use TensorFlow's resource management to ensure that the memory will be
+ // released when the session is over.
+ Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource;
+ std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator =
+ [](Im2ColBufferResource<T1, max_chunk_size>** resource) {
+ *resource = new Im2ColBufferResource<T1, max_chunk_size>();
+ return Status::OK();
+ };
+ OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
+ "Conv2d", "im2col_buffer",
+ &im2col_buffer_resource, creator));
+ // This means that multiple ops can't be run simultaneously on different
+ // threads, because we have a single shared resource. The platforms this is
+ // aimed at have intra-op parallelism as their focus though, so it shouldn't
+ // be an issue.
+ mutex_lock lock_buffer(im2col_buffer_resource->mu);
+ core::ScopedUnref unref_buffer(im2col_buffer_resource);
+ T1* im2col_buffer = im2col_buffer_resource->data;
+
+ typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>();
+
+ for (int batch = 0; batch < input_batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
+ const int patch_index = (batch * output_width * output_height) +
+ (out_y * output_width) + out_x;
+ const int patch_index_within_chunk = patch_index % patches_per_chunk;
+ T1* im2col_patch_start =
+ im2col_buffer + (patch_index_within_chunk * filter_value_count);
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int conv_in_y = in_y_origin + filter_y;
+ float in_y = (conv_in_y - top_padding);
+ if (in_y < 0) {
+ in_y = -(in_y + 1.0f - pad_offset);
+ } else if (in_y >= resized_height) {
+ in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
+ }
+ in_y *= st.height_scale;
+ const int64 top_y_index = static_cast<int64>(std::floor(in_y));
+ const int64 bottom_y_index = std::min(
+ static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
+ const T1 y_lerp = in_y - top_y_index;
+ T1* im2col_row_start =
+ im2col_patch_start + (filter_y * filter_width * input_depth);
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int conv_in_x = in_x_origin + filter_x;
+ float in_x = (conv_in_x - left_padding);
+ if (in_x < 0) {
+ in_x = -(in_x + 1.0f - pad_offset);
+ } else if (in_x >= resized_width) {
+ in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
+ }
+ in_x *= st.width_scale;
+ const int64 left_x_index = static_cast<int64>(std::floor(in_x));
+ const int64 right_x_index = std::min(
+ static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
+ const T1 x_lerp = in_x - left_x_index;
+ T1* im2col_row_pixel =
+ im2col_row_start + (filter_x * input_depth);
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ T1 in_value;
+ if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
+ (conv_in_y >= 0) && (conv_in_y < padded_height)) {
+ const T1 top_left(
+ input_data(batch, top_y_index, left_x_index, in_channel));
+ const T1 top_right(input_data(batch, top_y_index,
+ right_x_index, in_channel));
+ const T1 bottom_left(input_data(batch, bottom_y_index,
+ left_x_index, in_channel));
+ const T1 bottom_right(input_data(batch, bottom_y_index,
+ right_x_index, in_channel));
+ const T1 top = top_left + (top_right - top_left) * x_lerp;
+ const T1 bottom =
+ bottom_left + (bottom_right - bottom_left) * x_lerp;
+ in_value = top + (bottom - top) * y_lerp;
+ } else {
+ in_value = T1(0);
+ }
+ im2col_row_pixel[in_channel] = in_value;
+ }
+ }
+ }
+ const bool is_last_in_chunk =
+ (patch_index_within_chunk == (patches_per_chunk - 1));
+ const bool is_last_overall =
+ ((batch == (input_batches - 1)) &&
+ (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
+ if (is_last_in_chunk || is_last_overall) {
+ // Now we've assembled a set of image patches into a matrix, apply a
+ // GEMM matrix multiply of the patches as rows, times the filter
+ // weights in columns, to get partial results in the output matrix.
+ const int how_many_patches = patch_index_within_chunk + 1;
+ const int m = how_many_patches;
+ const int n = filter_count;
+ const int k = filter_value_count;
+ const int lda = filter_value_count;
+ const int ldb = filter_count;
+ const int ldc = filter_count;
+ const size_t start_patch_index =
+ patch_index - (how_many_patches - 1);
+ T3* chunk_output_data =
+ output_data + (start_patch_index * filter_count);
+ TGemmFunctor gemm_functor;
+ gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb,
+ chunk_output_data, ldc);
+ }
+ }
+ }
+ }
+ }
+};
+
+} // namespace
+
+// Implements a version of convolution with bilinear resizing and mirror padding
+// included.
+template <class T, class TConvFunctor>
+class FusedResizeConv2DUsingGemmOp : public OpKernel {
+ public:
+ explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr("resize_align_corners", &align_corners_));
+ MirrorPadMode mode;
+ OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
+
+ switch (mode) {
+ case MirrorPadMode::SYMMETRIC: {
+ offset_ = 0;
+ break;
+ }
+ case MirrorPadMode::REFLECT: {
+ offset_ = 1;
+ break;
+ }
+ default:
+ OP_REQUIRES(context, false,
+ errors::InvalidArgument(
+ "mode must be either REFLECT or SYMMETRIC."));
+ }
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+ OP_REQUIRES(context, strides_.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
+ const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
+ OP_REQUIRES(
+ context, stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_rows, in_cols, in_depth ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, (input.shape().num_elements() > 0),
+ errors::InvalidArgument("Input tensor can't be empty"));
+
+ ImageResizerState st(align_corners_);
+ st.ValidateAndCalculateOutputSize(context, input);
+ if (!context->status().ok()) return;
+ const TensorShape resized_shape(
+ {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
+
+ const Tensor& paddings = context->input(2);
+
+ const int dims = resized_shape.dims();
+ OP_REQUIRES(
+ context, TensorShapeUtils::IsMatrix(paddings.shape()) &&
+ paddings.dim_size(1) == 2,
+ errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
+ paddings.shape().DebugString()));
+ const int fixed_dims =
+ (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
+ ? 1
+ : dims;
+ OP_REQUIRES(
+ context, fixed_dims == paddings.dim_size(0),
+ errors::InvalidArgument(
+ "The first dimension of paddings must be the rank of inputs: ",
+ fixed_dims, " ", paddings.shape().DebugString(), " ",
+ resized_shape.DebugString()));
+ OP_REQUIRES(
+ context, dims == paddings.dim_size(0),
+ errors::InvalidArgument(
+ "The first dimension of paddings must be the rank of inputs: ",
+ dims, " ", paddings.shape().DebugString(), " ",
+ resized_shape.DebugString()));
+
+ OP_REQUIRES(
+ context, dims == 4,
+ errors::InvalidArgument(
+ "Fused mirror padding only supports four-dimensional inputs, but ",
+ dims, " requested"));
+
+ // Compute the shape of the output tensor, and allocate it.
+ TensorShape padded_shape;
+ TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
+ for (int d = 0; d < dims; ++d) {
+ const int32 before =
+ paddings_matrix(d, 0); // Pad before existing elements.
+ const int32 after =
+ paddings_matrix(d, 1); // Pad after existing elements.
+ OP_REQUIRES(context, before >= 0 && after >= 0,
+ errors::InvalidArgument("paddings must be non-negative: ",
+ before, " ", after));
+ if (offset_ == 0) { // SYMMETRIC mode.
+ OP_REQUIRES(
+ context, before <= resized_shape.dim_size(d) &&
+ after <= resized_shape.dim_size(d),
+ errors::InvalidArgument("paddings must be no greater "
+ "than the dimension size: ",
+ before, ", ", after, " greater than ",
+ resized_shape.dim_size(d)));
+ } else if (offset_ == 1) { // REFLECT mode.
+ OP_REQUIRES(
+ context, before < resized_shape.dim_size(d) &&
+ after < resized_shape.dim_size(d),
+ errors::InvalidArgument("paddings must be less than"
+ " the dimension size: ",
+ before, ", ", after, " not less than ",
+ resized_shape.dim_size(d)));
+ }
+ padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
+ }
+
+ OP_REQUIRES(
+ context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
+ errors::InvalidArgument(
+ "Fused mirror padding only supports spatial padding, not batches: ",
+ paddings.DebugString()));
+ OP_REQUIRES(
+ context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
+ errors::InvalidArgument(
+ "Fused mirror padding only supports spatial padding, not channels: ",
+ paddings.DebugString()));
+ const int32 top_padding = paddings_matrix(1, 0);
+ const int32 bottom_padding = paddings_matrix(1, 1);
+ const int32 left_padding = paddings_matrix(2, 0);
+ const int32 right_padding = paddings_matrix(2, 1);
+
+ // Input filter is of the following dimensions:
+ // [ filter_rows, filter_cols, in_depth, out_depth]
+ const Tensor& filter = context->input(3);
+
+ // For 2D convolution, there should be 4 dimensions.
+ OP_REQUIRES(context, padded_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ padded_shape.DebugString()));
+ OP_REQUIRES(context, filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+
+ // We only check the first three dims, since the depth is accessed as an
+ // int64 below.
+ for (int i = 0; i < 3; i++) {
+ OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+ // The last dimension for input is in_depth. It must be the same as the
+ // filter's in_depth.
+ const int64 in_depth = padded_shape.dim_size(3);
+ OP_REQUIRES(
+ context, in_depth == filter.dim_size(2),
+ errors::InvalidArgument("input and filter must have the same depth: ",
+ in_depth, " vs ", filter.dim_size(2)));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 padded_rows_raw = padded_shape.dim_size(1);
+ OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int padded_rows = static_cast<int>(padded_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+ const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 padded_cols_raw = padded_shape.dim_size(2);
+ OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw,
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int padded_cols = static_cast<int>(padded_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+ const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = padded_shape.dim_size(0);
+ OP_REQUIRES(context,
+ FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // For now we take the stride from the second and third dimensions only (we
+ // do not support striding on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
+ const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
+
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+ TensorShape out_shape =
+ ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
+ OP_REQUIRES(context, (out_shape.num_elements() > 0),
+ errors::InvalidArgument("Output tensor can't be empty"));
+
+ // Output tensor is of the following dimensions:
+ // [ in_batch, out_rows, out_cols, out_depth ]
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ VLOG(2) << "Conv2D: in_depth = " << in_depth
+ << ", padded_cols = " << padded_cols
+ << ", filter_cols = " << filter_cols
+ << ", padded_rows = " << padded_rows
+ << ", filter_rows = " << filter_rows
+ << ", stride_rows = " << stride_rows
+ << ", stride_cols = " << stride_cols
+ << ", out_depth = " << out_depth;
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+ TConvFunctor conv_functor;
+ conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
+ padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
+ filter_cols, out_depth, stride_rows, stride_cols, padding_,
+ output->flat<T>().data(), out_rows, out_cols, st, top_padding,
+ bottom_padding, left_padding, right_padding, offset_);
+ }
+
+ private:
+ std::vector<int32> strides_;
+ Padding padding_;
+ bool align_corners_;
+ int offset_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
+};
+
+#define REGISTER_FUSED(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("FusedResizeAndPadConv2D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ FusedResizeConv2DUsingGemmOp< \
+ T, \
+ FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
+
+TF_CALL_float(REGISTER_FUSED);
+
+} // namespace tensorflow
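
For reference, the im2col-plus-GEMM scheme described in the comments above (one buffer row of filter_height x filter_width x input_depth values per output position, multiplied against a filter matrix with filter_count columns) can be sketched in NumPy. This is a simplified illustration only, assuming stride 1, VALID padding, a single image, and no resize or mirror padding; the helper name is invented for the example.

import numpy as np

def im2col_conv(image, filters):
    """Minimal im2col + GEMM convolution.

    image: [height, width, depth], filters: [fh, fw, depth, filter_count],
    stride 1, VALID padding.
    """
    fh, fw, depth, filter_count = filters.shape
    out_h = image.shape[0] - fh + 1
    out_w = image.shape[1] - fw + 1
    filter_value_count = fh * fw * depth
    # One row per output position, laid out depth-innermost, then width,
    # then height, matching the buffer layout described above.
    patches = np.empty((out_h * out_w, filter_value_count), dtype=image.dtype)
    for y in range(out_h):
        for x in range(out_w):
            patches[y * out_w + x] = image[y:y + fh, x:x + fw, :].reshape(-1)
    # The GEMM step: [num_patches, k] x [k, filter_count].
    output = patches @ filters.reshape(filter_value_count, filter_count)
    return output.reshape(out_h, out_w, filter_count)

image = np.arange(1, 13, dtype=np.float32).reshape(3, 4, 1)
filters = np.arange(1, 10, dtype=np.float32).reshape(3, 3, 1, 1)
print(im2col_conv(image, filters))  # Two VALID output positions: 348 and 393.
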
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
new file mode 100644
index 0000000000..228f2d5def
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -0,0 +1,240 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+class FusedResizePadConvOpTest : public OpsTestBase {
+ protected:
+ void HandwrittenConv() {
+ const int stride = 1;
+ TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
+ .Input(FakeInput(DT_FLOAT))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(DT_FLOAT))
+ .Attr("T", DT_FLOAT)
+ .Attr("resize_align_corners", false)
+ .Attr("mode", "REFLECT")
+ .Attr("strides", {1, stride, stride, 1})
+ .Attr("padding", "SAME")
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ const int depth = 1;
+ const int image_width = 4;
+ const int image_height = 3;
+ const int image_batch_count = 1;
+ // The image matrix is:
+ // | 1 | 2 | 3 | 4 |
+ // | 5 | 6 | 7 | 8 |
+ // | 9 | 10 | 11 | 12 |
+ Tensor image(DT_FLOAT,
+ {image_batch_count, image_height, image_width, depth});
+ test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+ // The filter matrix is:
+ // | 1 | 4 | 7 |
+ // | 2 | 5 | 8 |
+ // | 3 | 6 | 9 |
+ const int filter_size = 3;
+ const int filter_count = 1;
+ Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
+ test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
+
+ const int resized_width = image_width;
+ const int resized_height = image_height;
+
+ const int top_padding = 0;
+ const int bottom_padding = 0;
+ const int left_padding = 0;
+ const int right_padding = 0;
+
+ AddInputFromArray<float>(image.shape(), image.flat<float>());
+ AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
+ AddInputFromArray<int32>(
+ TensorShape({4, 2}),
+ {0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
+ AddInputFromArray<float>(filter.shape(), filter.flat<float>());
+ TF_ASSERT_OK(RunOpKernel());
+
+ // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+ // the input set to zero because we're using the 'SAME' padding mode.
+ // The calculations behind the expected output are:
+ // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
+ // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
+ // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
+ // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
+ // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
+ // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+ // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+ // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
+ // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
+ // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
+ // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
+ // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
+ // This means we should end up with this matrix:
+ // | 105 | 150 | 183 | 95 |
+ // | 235 | 312 | 357 | 178 |
+ // | 187 | 234 | 261 | 121 |
+ const int expected_width = image_width;
+ const int expected_height = image_height * filter_count;
+ Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
+ expected_width, filter_count}));
+ test::FillValues<float>(
+ &expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
+ const Tensor& output = *GetOutput(0);
+ test::ExpectTensorNear<float>(expected, output, 1e-5);
+ }
+
+ void CompareFusedAndSeparate(int input_width, int input_height,
+ int input_depth, int resize_width,
+ int resize_height, int y_padding, int x_padding,
+ int filter_size, int filter_count,
+ bool resize_align_corners, string pad_mode,
+ int stride, string padding) {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ const size_t input_data_size = input_height * input_width * input_depth;
+ Tensor input_data(DT_FLOAT,
+ TensorShape({1, input_height, input_width, input_depth}));
+ for (int i = 0; i < input_data_size; ++i) {
+ input_data.flat<float>()(i) = i + 1.0f;
+ }
+ Output input =
+ Const(root.WithOpName("input"), Input::Initializer(input_data));
+
+ const size_t filter_data_size =
+ filter_size * filter_size * filter_count * input_depth;
+ Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
+ input_depth, filter_count}));
+ for (int i = 0; i < filter_data_size; ++i) {
+ filter_data.flat<float>()(i) = i + 1.0f;
+ }
+ Output filter =
+ Const(root.WithOpName("filter"), Input::Initializer(filter_data));
+
+ Output resize_size =
+ Const(root.WithOpName("resize_size"), {resize_height, resize_width});
+ Output resize =
+ ResizeBilinear(root.WithOpName("resize"), input, resize_size,
+ ResizeBilinear::AlignCorners(resize_align_corners));
+ Output paddings =
+ Const(root.WithOpName("paddings"),
+ {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
+ Output mirror_pad =
+ MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
+ Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
+ {1, stride, stride, 1}, padding);
+
+ Output fused_conv = FusedResizeAndPadConv2D(
+ root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
+ pad_mode, {1, stride, stride, 1}, padding,
+ FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
+
+ tensorflow::GraphDef graph;
+ TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+ std::unique_ptr<tensorflow::Session> session(
+ tensorflow::NewSession(tensorflow::SessionOptions()));
+ TF_ASSERT_OK(session->Create(graph));
+
+ std::vector<Tensor> unfused_tensors;
+ TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
+
+ std::vector<Tensor> fused_tensors;
+ TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
+
+ test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ }
+};
+
+TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }
+
+TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
+ CompareFusedAndSeparate(10, 10, 1, 10, 10, 0, 0, 1, 1, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
+ CompareFusedAndSeparate(10, 10, 3, 10, 10, 0, 0, 4, 4, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
+ CompareFusedAndSeparate(10, 10, 1, 20, 20, 0, 0, 1, 1, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
+ CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
+ CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
+ CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
+ CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+ "VALID");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
+ CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
+ CompareFusedAndSeparate(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
+ CompareFusedAndSeparate(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
+ CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC", 1,
+ "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
+ CompareFusedAndSeparate(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC", 1,
+ "SAME");
+}
+
+} // namespace tensorflow
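
The REFLECT and SYMMETRIC cases exercised by these tests correspond to the pad_offset handling in FusedResizeAndPadConvFunctor above, where coordinates that land in the padded border are reflected back into the resized image before the bilinear lookup. Below is a minimal sketch of that per-axis mapping, assuming (as in the kernel) an offset of 1 for REFLECT and 0 for SYMMETRIC; the function name is invented for the example.

def mirror_coordinate(conv_in_y, top_padding, resized_height, pad_offset):
    # Reflects a row coordinate that lands in the top or bottom padding back
    # into [0, resized_height), mirroring the per-axis logic of the kernel.
    in_y = float(conv_in_y - top_padding)
    if in_y < 0:
        in_y = -(in_y + 1.0 - pad_offset)
    elif in_y >= resized_height:
        in_y = (resized_height * 2.0) - (in_y + 1.0 + pad_offset)
    return in_y

# With a resized height of 4 and one row of top padding, padded row 0 sits at
# resized row -1:
print(mirror_coordinate(0, 1, 4, pad_offset=1))  # REFLECT: 1.0 (skips the edge row)
print(mirror_coordinate(0, 1, 4, pad_offset=0))  # SYMMETRIC: 0.0 (repeats the edge row)
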
diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc
index c39510a11a..6da6da846b 100644
--- a/tensorflow/core/kernels/conv_ops_using_gemm.cc
+++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc
@@ -56,14 +56,13 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_ops.h"
+#include "tensorflow/core/kernels/gemm_functors.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
-#if defined(__APPLE__)
-#include <Accelerate/Accelerate.h>
-#define USE_ACCELERATE_GEMM
-#endif // __APPLE__
-
namespace tensorflow {
namespace {
@@ -189,87 +188,6 @@ class ReferenceConvFunctor {
}
};
-// A readable but slow implementation of matrix multiplication, useful for
-// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
-// the Im2ColConvFunctor template definition inside the op registration to
-// enable. Assumes row-major ordering of the values in memory.
-template <class T1, class T2, class T3>
-class ReferenceGemmFunctor {
- public:
- void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
- const T2* b, size_t ldb, T3* c, size_t ldc) {
- const size_t a_i_stride = lda;
- const size_t a_l_stride = 1;
- const size_t b_j_stride = 1;
- const size_t b_l_stride = ldb;
- const size_t c_i_stride = ldc;
- const size_t c_j_stride = 1;
- size_t i, j, l;
- for (j = 0; j < n; j++) {
- for (i = 0; i < m; i++) {
- T3 total(0);
- for (l = 0; l < k; l++) {
- const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
- const T1 a_value = a[a_index];
- const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
- const T2 b_value = b[b_index];
- total += (a_value * b_value);
- }
- const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
- c[c_index] = total;
- }
- }
- }
-};
-
-// Uses the optimized Eigen library to implement the matrix multiplication
-// required by the Im2ColConvFunctor class. We supply the two input and one
-// output types so that the accumulator can potentially be higher-precision than
-// the inputs, even though we don't currently take advantage of this.
-template <class T1, class T2, class T3>
-class FastGemmFunctor {
- public:
- // Convenience wrappers for the Eigen matrix types we'll be using.
- typedef Eigen::Map<
- const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
- ConstMatrixT1;
- typedef Eigen::Map<
- const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
- ConstMatrixT2;
- typedef Eigen::Map<
- Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
- MatrixT3;
- void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
- const T2* b, size_t ldb, T3* c, size_t ldc) {
- ConstMatrixT1 a_matrix(a, m, k);
- ConstMatrixT2 b_matrix(b, k, n);
- MatrixT3 c_matrix(c, m, n);
- c_matrix.noalias() = a_matrix * b_matrix;
- }
-};
-
-// If we have Apple's Accelerate framework, use their implementation of GEMM to
-// get a performance boost for float.
-#if defined(USE_ACCELERATE_GEMM)
-template <>
-class FastGemmFunctor<float, float, float> {
- public:
- void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
- const float* b, size_t ldb, float* c, size_t ldc) {
- cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
- lda, b, ldb, 0.0f, c, ldc);
- }
-};
-#endif // USE_ACCELERATE_GEMM
-
-// Used to keep track of persistent memory buffers used within the op.
-template <class T, size_t size>
-struct Im2ColBufferResource : public ResourceBase {
- mutex mu;
- T data[size];
- string DebugString() { return "Im2ColBufferResource"; }
-};
-
// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
@@ -344,7 +262,6 @@ class Im2ColConvFunctor {
errors::InvalidArgument("Im2Col patch too large for buffer"));
const size_t patches_per_chunk =
max_chunk_size / (filter_value_count * sizeof(T1));
-
// Because memory allocation is very expensive on mobile platforms, try to
// allocate a persistent buffer that will be kept around between calls. We
// use TensorFlow's resource management to ensure that the memory will be
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
new file mode 100644
index 0000000000..d37008d5cf
--- /dev/null
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is a set of different implementations for the basic matrix by matrix
+// multiply function, commonly known as GEMM after the BLAS library's naming.
+// Having a standard interface enables us to swap out implementations on
+// different platforms, to make sure we're using the optimal version. They are
+// implemented as C++ template functors, so they're easy to swap into all of the
+// different kernels that use them.
+
+#include <string.h>
+#include <map>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+
+#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
+#include <Accelerate/Accelerate.h>
+#define USE_ACCELERATE_GEMM
+#endif // __APPLE__
+
+// A readable but slow implementation of matrix multiplication, useful for
+// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
+// the Im2ColConvFunctor template definition inside the op registration to
+// enable. Assumes row-major ordering of the values in memory.
+template <class T1, class T2, class T3>
+class ReferenceGemmFunctor {
+ public:
+ void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
+ const T2* b, size_t ldb, T3* c, size_t ldc) {
+ const size_t a_i_stride = lda;
+ const size_t a_l_stride = 1;
+ const size_t b_j_stride = 1;
+ const size_t b_l_stride = ldb;
+ const size_t c_i_stride = ldc;
+ const size_t c_j_stride = 1;
+ size_t i, j, l;
+ for (j = 0; j < n; j++) {
+ for (i = 0; i < m; i++) {
+ T3 total(0);
+ for (l = 0; l < k; l++) {
+ const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
+ const T1 a_value = a[a_index];
+ const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
+ const T2 b_value = b[b_index];
+ total += (a_value * b_value);
+ }
+ const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
+ c[c_index] = total;
+ }
+ }
+ }
+};
+
+// Uses the optimized Eigen library to implement the matrix multiplication
+// required by the Im2ColConvFunctor class. We supply the two input and one
+// output types so that the accumulator can potentially be higher-precision than
+// the inputs, even though we don't currently take advantage of this.
+template <class T1, class T2, class T3>
+class FastGemmFunctor {
+ public:
+ // Convenience wrappers for the Eigen matrix types we'll be using.
+ typedef Eigen::Map<
+ const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ ConstMatrixT1;
+ typedef Eigen::Map<
+ const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ ConstMatrixT2;
+ typedef Eigen::Map<
+ Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixT3;
+ void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
+ const T2* b, size_t ldb, T3* c, size_t ldc) {
+ ConstMatrixT1 a_matrix(a, m, k);
+ ConstMatrixT2 b_matrix(b, k, n);
+ MatrixT3 c_matrix(c, m, n);
+ c_matrix.noalias() = a_matrix * b_matrix;
+ }
+};
+
+// If we have Apple's Accelerate framework, use their implementation of GEMM to
+// get a performance boost for float.
+#if defined(USE_ACCELERATE_GEMM)
+template <>
+class FastGemmFunctor<float, float, float> {
+ public:
+ void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
+ const float* b, size_t ldb, float* c, size_t ldc) {
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
+ lda, b, ldb, 0.0f, c, ldc);
+ }
+};
+#endif // USE_ACCELERATE_GEMM
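
Both functors implement the same contract: a plain row-major matrix product C = A * B, where A is m x k with leading dimension lda, B is k x n with leading dimension ldb, and C is m x n with leading dimension ldc, with no transposes and no accumulation into C. A rough NumPy sketch of that contract, assuming the leading dimensions equal the logical matrix widths (as they do in the conv kernels above); the helper name is invented for the example.

import numpy as np

def reference_gemm(m, n, k, a, lda, b, ldb):
    # Row-major GEMM with no transposes: c[i, j] = sum over l of a[i, l] * b[l, j].
    # The C++ functors additionally write the result into a caller-provided
    # buffer with row stride ldc, which is omitted here.
    a_matrix = np.asarray(a, dtype=np.float32).reshape(m, lda)[:, :k]
    b_matrix = np.asarray(b, dtype=np.float32).reshape(k, ldb)[:, :n]
    return a_matrix @ b_matrix

a = [1, 2, 3,
     4, 5, 6]      # 2 x 3, row major
b = [7, 8,
     9, 10,
     11, 12]       # 3 x 2, row major
print(reference_gemm(2, 2, 3, a, 3, b, 2))  # [[58, 64], [139, 154]]
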
diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h
index a7acb5e649..8870937422 100644
--- a/tensorflow/core/kernels/image_resizer_state.h
+++ b/tensorflow/core/kernels/image_resizer_state.h
@@ -49,12 +49,13 @@ struct ImageResizerState {
explicit ImageResizerState(bool align_corners)
: align_corners_(align_corners) {}
- // ValidateAndCreateOutput checks the bounds on the input tensors
+ // ValidateAndCalculateOutputSize checks the bounds on the input tensors
// and requested size, sets up some of the resizing state such as the
- // height_scale and width_scale, and allocates the output.
+ // height_scale and width_scale, and calculates the output size.
// If any of these operations fails, it sets an error status in
// the context, which the caller must check.
- void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
+ void ValidateAndCalculateOutputSize(OpKernelContext* context,
+ const Tensor& input) {
OP_REQUIRES(context, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
input.shape().DebugString()));
@@ -87,12 +88,18 @@ struct ImageResizerState {
OP_REQUIRES(
context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
errors::InvalidArgument("input image must be of non-zero size"));
+ height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
+ width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
+ }
+
+ // Calculates all the required variables, and allocates the output.
+ void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
+ ValidateAndCalculateOutputSize(context, input);
+ if (!context->status().ok()) return;
OP_REQUIRES_OK(context, context->allocate_output(
0, TensorShape({input.dim_size(0), out_height,
out_width, input.dim_size(3)}),
&output));
- height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
- width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
}
int64 batch_size;
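
For context, height_scale and width_scale map an output coordinate back to a (possibly fractional) input coordinate for the bilinear lookup in the resize and fused-conv kernels. The sketch below states the scale calculation as an assumption about what CalculateResizeScale computes, rather than a quote of it; align_corners stretches the grid so that the corner pixels of the input and output coincide.

def calculate_resize_scale(in_size, out_size, align_corners):
    # in_coord = out_coord * scale, later split into floor/ceil indices and a
    # lerp weight, as in the fused kernel above.
    if align_corners and out_size > 1:
        return (in_size - 1) / float(out_size - 1)
    return in_size / float(out_size)

print(calculate_resize_scale(3, 12, False))  # 0.25: each output row advances a quarter input row
print(calculate_resize_scale(3, 12, True))   # 2/11: input and output corner rows line up
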
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index eae5187896..3baae914cb 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test {
test::SetOutputAttrs(params_.get(), &attrs);
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
+ params_.get()->resource_manager = device_.get()->resource_manager();
context_.reset(new OpKernelContext(params_.get()));
device_->Compute(kernel_.get(), context_.get());
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index cc374278e7..5daaf83133 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -425,6 +426,46 @@ data_format: Specify the data format of the input and output data. With the
[batch, in_channels, in_height, in_width].
)doc");
+REGISTER_OP("FusedResizeAndPadConv2D")
+ .Input("input: T")
+ .Input("size: int32")
+ .Input("paddings: int32")
+ .Input("filter: T")
+ .Output("output: T")
+ .Attr("T: {half, float, double}")
+ .Attr("resize_align_corners: bool = false")
+ .Attr(GetMirrorPadModeAttrString())
+ .Attr("strides: list(int)")
+ .Attr(GetPaddingAttrString())
+ .Doc(R"doc(
+Performs a resize and padding as a preprocess during a convolution.
+
+It's often possible to do spatial transformations more efficiently as part of
+the packing stage of a convolution, so this op allows for an optimized
+implementation where these stages are fused together. This prevents the need to
+write out the intermediate results as whole tensors, reducing memory pressure,
+and we can get some latency gains by merging the transformation calculations.
+The data_format attribute for Conv2D isn't supported by this op, and defaults to
+'NHWC' order.
+Internally this op uses a single per-graph scratch buffer, which means that it
+will block if multiple versions are being run in parallel. This is because this
+operator is primarily an optimization to minimize memory usage.
+
+input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+ new size for the images.
+paddings: A two-column matrix specifying the padding sizes. The number of
+ rows must be the same as the rank of `input`.
+filter: 4-D with shape
+ `[filter_height, filter_width, in_channels, out_channels]`.
+resize_align_corners: If true, rescale input by (new_height - 1) / (height - 1),
+ which exactly aligns the 4 corners of images and resized images. If false, rescale
+ by new_height / height. The width dimension is treated the same way.
+strides: 1-D of length 4. The stride of the sliding window for each dimension
+ of `input`. Must be in the same order as the dimension specified with format.
+padding: The type of padding algorithm to use.
+ )doc");
+
// --------------------------------------------------------------------------
REGISTER_OP("DepthwiseConv2dNative")
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index e14ca1a559..7bc3ffe25b 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -844,6 +844,71 @@ ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
+@ops.RegisterShape("FusedResizeAndPadConv2D")
+def _FusedResizeAndPadConv2DShape(op):
+ """Shape function for FusedResizeAndPadConv2D op."""
+ # The bilinear resize shape calculation.
+ input_shape = op.inputs[0].get_shape().with_rank(4)
+ unused_size_shape = op.inputs[1].get_shape().merge_with([2])
+ size = tensor_util.constant_value(op.inputs[1])
+ if size is not None:
+ height = size[0]
+ width = size[1]
+ else:
+ height = None
+ width = None
+ resized_shape = tensor_shape.TensorShape(
+ [input_shape[0], height, width, input_shape[3]])
+
+ # Calculates the effect of the padding.
+ paddings_shape = op.inputs[2].get_shape().with_rank(2)
+ resized_shape = resized_shape.with_rank(paddings_shape[0].value)
+ paddings_shape = paddings_shape.merge_with(
+ tensor_shape.matrix(resized_shape.ndims, 2))
+ paddings = tensor_util.constant_value(op.inputs[2])
+ if paddings is None:
+ padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
+ else:
+ output_dims = []
+ for i, dim in enumerate(resized_shape.dims):
+ if paddings[i, 0] < 0 or paddings[i, 1] < 0:
+ raise ValueError("paddings must be non-negative")
+ output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
+ padded_shape = tensor_shape.TensorShape(output_dims)
+
+ # Finally work out the convolution's effect.
+ filter_shape = op.inputs[3].get_shape().with_rank(4)
+
+ batch_size = padded_shape[0]
+ in_rows = padded_shape[1]
+ in_cols = padded_shape[2]
+
+ filter_rows = filter_shape[0]
+ filter_cols = filter_shape[1]
+ depth_out = filter_shape[3]
+ # Check that the input depths are compatible.
+ padded_shape[3].assert_is_compatible_with(filter_shape[2])
+
+ stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+ if stride_b != 1 or stride_d != 1:
+ raise ValueError("Current implementation does not yet support "
+ "strides in the batch and depth dimensions.")
+ # TODO(mrry,shlens): Raise an error if the stride would cause
+ # information in the input to be ignored. This will require a change
+ # in the kernel implementation.
+ padding = op.get_attr("padding")
+ out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
+ filter_rows,
+ filter_cols,
+ stride_r,
+ stride_c,
+ padding)
+
+ output_shape = [batch_size, out_rows, out_cols, depth_out]
+ return [tensor_shape.TensorShape(output_shape)]
+
+
@ops.RegisterShape("MaxPoolWithArgmax")
def _MaxPoolWithArgMaxShape(op):
"""Shape function for MaxPoolWithArgmax op."""
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index a330ff7c50..9c115f53be 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -27,6 +27,8 @@ the network is used only for inference. These include:
- Folding batch normalization ops into the pre-calculated weights.
+ - Fusing common operations into unified versions.
+
This script takes a frozen GraphDef file (where the weight variables have been
converted into constants by the freeze_graph script) and outputs a new GraphDef
with the optimizations applied.
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index 4eb138d97d..1cb5ba1625 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -27,6 +27,8 @@ the network is used only for inference. These include:
- Folding batch normalization ops into the pre-calculated weights.
+ - Fusing common operations into unified versions.
+
This script takes a frozen GraphDef file (where the weight variables have been
converted into constants by the freeze_graph script) and outputs a new GraphDef
with the optimizations applied.
@@ -37,8 +39,8 @@ bazel build tensorflow/python/tools:optimize_for_inference && \
bazel-bin/tensorflow/python/tools/optimize_for_inference \
--input_graph=some_graph_def.pb \
--output_graph=/tmp/optimized_graph.pb \
---input_node_names=Mul
---output_node_names=softmax
+--input_names=Mul \
+--output_names=softmax
"""
@@ -74,13 +76,42 @@ def optimize_for_inference(input_graph_def, input_node_names,
Returns:
An optimized version of the input graph.
"""
- stripped_graph_def = strip_unused_lib.strip_unused(input_graph_def,
- input_node_names,
- output_node_names,
- placeholder_type_enum)
- detrained_graph_def = graph_util.remove_training_nodes(stripped_graph_def)
- folded_graph_def = fold_batch_norms(detrained_graph_def)
- return folded_graph_def
+ ensure_graph_is_valid(input_graph_def)
+ optimized_graph_def = input_graph_def
+ optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
+ input_node_names,
+ output_node_names,
+ placeholder_type_enum)
+ optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def)
+ optimized_graph_def = fold_batch_norms(optimized_graph_def)
+ optimized_graph_def = fuse_resize_and_conv(optimized_graph_def)
+ ensure_graph_is_valid(optimized_graph_def)
+ return optimized_graph_def
+
+
+def ensure_graph_is_valid(graph_def):
+ """Makes sure that the graph is internally consistent.
+
+ Checks basic properties of the graph def and raises an exception if there are
+ input references to missing nodes, duplicated names, or other logic errors.
+
+ Args:
+ graph_def: Definition of a graph to be checked.
+
+ Raises:
+ ValueError: If the graph is incorrectly constructed.
+ """
+ node_map = {}
+ for node in graph_def.node:
+ if node.name not in node_map.keys():
+ node_map[node.name] = node
+ else:
+ raise ValueError("Duplicate node names detected for ", node.name)
+ for node in graph_def.node:
+ for input_name in node.input:
+ input_node_name = node_name_from_input(input_name)
+ if input_node_name not in node_map.keys():
+ raise ValueError("Input for ", node.name, " not found: ", input_name)
def node_name_from_input(node_name):
@@ -161,7 +192,7 @@ def fold_batch_norms(input_graph_def):
if node.name not in input_node_map.keys():
input_node_map[node.name] = node
else:
- raise ValueError("Duplicate node names detected.")
+ raise ValueError("Duplicate node names detected for ", node.name)
nodes_to_skip = {}
new_ops = []
@@ -303,3 +334,94 @@ def fold_batch_norms(input_graph_def):
result_graph_def.node.extend(new_ops)
return result_graph_def
+
+
+def fuse_resize_and_conv(input_graph_def):
+ """Merges preceding resize and mirror pad ops into a specialized convolution.
+
+ There's a common pattern of enlarging the input to a convolution using a
+ resize operation, and also using MirrorPad to extend the boundaries so that
+ zero edge pixels don't bleed inwards when convolving. This routine looks for
+ that pattern of operations, and fuses them together into a single FusedResizeAndPadConv2D op.
+
+ Args:
+ input_graph_def: A GraphDef containing a model.
+
+ Returns:
+ Modified graph with resize and pad ops merged.
+
+ Raises:
+ ValueError: If the graph is badly formed with duplicate node names.
+ """
+
+ input_node_map = {}
+ for node in input_graph_def.node:
+ if node.name not in input_node_map.keys():
+ input_node_map[node.name] = node
+ else:
+ raise ValueError("Duplicate node names detected for ", node.name)
+
+ nodes_to_skip = {}
+ new_ops = []
+ for node in input_graph_def.node:
+
+ if node.op != "Conv2D":
+ continue
+ conv_op = node
+
+ input_op = node_from_map(input_node_map, conv_op.input[0])
+ if input_op.op == "MirrorPad":
+ mirror_pad_op = input_op
+ resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
+ else:
+ mirror_pad_op = None
+ resize_op = input_op
+
+ if resize_op.op != "ResizeBilinear":
+ continue
+
+ nodes_to_skip[conv_op.name] = True
+ if mirror_pad_op:
+ nodes_to_skip[mirror_pad_op.name] = True
+ nodes_to_skip[resize_op.name] = True
+
+ fused_conv_op = tf.NodeDef()
+ fused_conv_op.op = "FusedResizeAndPadConv2D"
+ fused_conv_op.name = conv_op.name
+ if mirror_pad_op:
+ mirror_paddings_name = mirror_pad_op.input[1]
+ mirror_paddings_mode = mirror_pad_op.attr["mode"]
+ else:
+ # If there was no MirrorPad op, then create settings that make the padding
+ # stage of the fused operation a no-op.
+ paddings_op = tf.NodeDef()
+ paddings_op.op = "Const"
+ paddings_op.name = conv_op.name + "_dummy_paddings"
+ paddings_op.attr["dtype"].CopyFrom(tf.AttrValue(
+ type=tf.int32.as_datatype_enum))
+ paddings_op.attr["value"].CopyFrom(tf.AttrValue(
+ tensor=tensor_util.make_tensor_proto(
+ [0, 0, 0, 0, 0, 0, 0, 0], tf.int32, [4, 2])))
+ new_ops.extend([paddings_op])
+ mirror_paddings_name = paddings_op.name
+ mirror_paddings_mode = tf.AttrValue(s=b"REFLECT")
+ fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1],
+ mirror_paddings_name, conv_op.input[1]])
+ fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
+ fused_conv_op.attr["resize_align_corners"].CopyFrom(
+ resize_op.attr["align_corners"])
+ fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
+ fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
+ fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
+ new_ops.extend([fused_conv_op])
+
+ result_graph_def = tf.GraphDef()
+ for node in input_graph_def.node:
+ if node.name in nodes_to_skip:
+ continue
+ new_node = tf.NodeDef()
+ new_node.CopyFrom(node)
+ result_graph_def.node.extend([new_node])
+
+ result_graph_def.node.extend(new_ops)
+ return result_graph_def
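
A short usage sketch of the updated pipeline, assuming a frozen GraphDef has already been written to disk; the file path is a placeholder, and the node names are the ones from the usage text above.

import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

graph_def = tf.GraphDef()
with open("/tmp/frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# Strips unused nodes, removes training-only ops, folds batch norms, and fuses
# resize/pad/conv patterns, validating the graph before and after.
optimized_graph_def = optimize_for_inference_lib.optimize_for_inference(
    graph_def, ["Mul"], ["softmax"], tf.float32.as_datatype_enum)

# The new fusion pass can also be run on its own:
fused_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(graph_def)
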
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index 61644fe9c9..d92d7ab8c7 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
shape=shape)))
def testOptimizeForInference(self):
+ unused_constant_name = "unused_constant"
unconnected_add_name = "unconnected_add"
a_constant_name = "a_constant"
b_constant_name = "b_constant"
@@ -64,9 +65,14 @@ class OptimizeForInferenceTest(tf.test.TestCase):
add_name = "add"
unused_output_add_name = "unused_output_add"
graph_def = tf.GraphDef()
+ unused_constant = self.create_constant_node_def(unused_constant_name,
+ value=0,
+ dtype=tf.float32,
+ shape=[])
+ graph_def.node.extend([unused_constant])
unconnected_add_node = self.create_node_def("Add", unconnected_add_name,
- ["no_such_node",
- "no_such_node"])
+ [unused_constant_name,
+ unused_constant_name])
self.set_attr_dtype(unconnected_add_node, "T", tf.float32)
graph_def.node.extend([unconnected_add_node])
a_constant = self.create_constant_node_def(a_constant_name,
@@ -160,6 +166,65 @@ class OptimizeForInferenceTest(tf.test.TestCase):
for node in optimized_graph_def.node:
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
+ def testFuseResizePadAndConv(self):
+ with self.test_session() as sess:
+ inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+ input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
+ dtype=tf.float32)
+ resize_op = tf.image.resize_bilinear(input_op, [12, 4],
+ align_corners=False)
+ pad_op = tf.pad(resize_op, [[0, 0], [1, 1], [2, 2], [0, 0]],
+ mode="REFLECT")
+ weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+ weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
+ dtype=tf.float32)
+ tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1],
+ padding="VALID", name="output")
+ original_graph_def = sess.graph_def
+ original_result = sess.run(["output:0"])
+ optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
+ original_graph_def)
+
+ with self.test_session() as sess:
+ _ = tf.import_graph_def(optimized_graph_def, input_map={},
+ name="optimized")
+ optimized_result = sess.run(["optimized/output:0"])
+
+ self.assertAllClose(original_result, optimized_result)
+
+ for node in optimized_graph_def.node:
+ self.assertNotEqual("Conv2D", node.op)
+ self.assertNotEqual("MirrorPad", node.op)
+ self.assertNotEqual("ResizeBilinear", node.op)
+
+ def testFuseResizeAndConv(self):
+ with self.test_session() as sess:
+ inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+ input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
+ dtype=tf.float32)
+ resize_op = tf.image.resize_bilinear(input_op, [12, 4],
+ align_corners=False)
+ weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+ weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
+ dtype=tf.float32)
+ tf.nn.conv2d(resize_op, weights_op, [1, 1, 1, 1],
+ padding="VALID", name="output")
+ original_graph_def = sess.graph_def
+ original_result = sess.run(["output:0"])
+ optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
+ original_graph_def)
+
+ with self.test_session() as sess:
+ _ = tf.import_graph_def(optimized_graph_def, input_map={},
+ name="optimized")
+ optimized_result = sess.run(["optimized/output:0"])
+
+ self.assertAllClose(original_result, optimized_result)
+
+ for node in optimized_graph_def.node:
+ self.assertNotEqual("Conv2D", node.op)
+ self.assertNotEqual("ResizeBilinear", node.op)
+
if __name__ == "__main__":
tf.test.main()