author     Pete Warden <petewarden@google.com>  2016-12-16 11:52:55 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2016-12-16 12:05:52 -0800
commit     05585d42b6929a48235110f31e5e14d5bfa088df (patch)
tree       71ce8c4dc4c0e9c08fdd138f5e3e5fe8f295364b
parent     7b6111893a988132e627ec56d0885475261c101d (diff)
Speed up im2col for QuantizedConv2D by breaking work into chunks
Change: 142282720
-rw-r--r--  tensorflow/core/kernels/quantized_conv_ops.cc | 331
1 file changed, 184 insertions(+), 147 deletions(-)
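The diff below caps the im2col scratch buffer at a fixed chunk size and walks the output patches chunk by chunk, decoding each flat patch index back into (batch, out_y, out_x). As a rough standalone illustration of that bookkeeping (an editor's sketch, not part of the commit; the layer shape is made up purely to make the numbers concrete, and kMaxChunkSize mirrors the non-Apple constant added by the change):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the non-Apple chunk limit added by the diff.
constexpr std::size_t kMaxChunkSize = 16 * 1024 * 1024;

int main() {
  // Hypothetical layer shape, chosen only to make the numbers concrete.
  const std::int64_t input_batches = 1, output_height = 112, output_width = 112;
  const std::int64_t filter_height = 3, filter_width = 3, input_depth = 64;

  // One im2col row holds every input value that a single output pixel reads.
  const std::int64_t filter_value_count =
      filter_height * filter_width * input_depth;
  // sizeof(T1) is 1 for quint8, so bytes per patch equals values per patch.
  const std::int64_t patches_per_chunk = kMaxChunkSize / filter_value_count;

  const std::int64_t patch_count = input_batches * output_height * output_width;
  const std::int64_t chunk_count =
      (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
  std::printf("%lld patches, %lld per chunk, %lld chunk(s)\n",
              static_cast<long long>(patch_count),
              static_cast<long long>(patches_per_chunk),
              static_cast<long long>(chunk_count));

  // The kernel recovers (batch, out_y, out_x) from a flat patch index:
  const std::int64_t patch_index = 12345;
  const std::int64_t batch = patch_index / (output_height * output_width);
  const std::int64_t out_y = (patch_index / output_width) % output_height;
  const std::int64_t out_x = patch_index % output_width;
  std::printf("patch %lld -> batch %lld, y %lld, x %lld\n",
              static_cast<long long>(patch_index),
              static_cast<long long>(batch),
              static_cast<long long>(out_y),
              static_cast<long long>(out_x));
  return 0;
}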
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index 13fdc1e1f0..2217c2acca 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "public/gemmlowp.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/quantization_utils.h"
@@ -49,7 +50,7 @@ namespace tensorflow {
template <class T1, class T2, class T3>
class ReferenceConvFunctor {
public:
- void operator()(OpKernelContext* op_context, const T1* input_data,
+ void operator()(OpKernelContext* context, const T1* input_data,
int input_batches, int input_height, int input_width,
int input_depth, int input_offset, const T2* filter_data,
int filter_height, int filter_width, int filter_count,
@@ -181,6 +182,18 @@ class ReferenceConvFunctor {
}
};
+// We don't want to allocate a buffer to hold all the patches if the size is
+// going to be extremely large, so break the work into chunks if it's bigger
+// than a limit. Each chunk is processed serially, so we can refill the
+// buffer for the next chunk and reuse it, keeping peak memory usage down.
+// From experimentation, 16 megabytes is a reasonable limit for Android and
+// other platforms using Eigen, and 1MB for Apple mobile devices.
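+// For example, a 3x3 filter over 64 quint8 input channels needs
+// 3 * 3 * 64 = 576 bytes per patch, so a 16MB chunk holds roughly 29,000
+// patches at a time.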
+#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
+const size_t kMaxChunkSize = (1 * 1024 * 1024);
+#else
+const size_t kMaxChunkSize = (16 * 1024 * 1024);
+#endif
+
// Implements convolution as a two stage process, first packing the patches of
// the input image into columns (im2col) and then running GEMM to produce the
// final result.
@@ -189,7 +202,7 @@ class ReferenceConvFunctor {
template <class T1, class T2, class T3>
class Im2ColConvFunctor {
public:
- void operator()(OpKernelContext* op_context, const T1* input_data,
+ void operator()(OpKernelContext* context, const T1* input_data,
int input_batches, int input_height, int input_width,
int input_depth, int input_offset, const T2* filter_data,
int filter_height, int filter_width, int filter_count,
@@ -202,8 +215,8 @@ class Im2ColConvFunctor {
if (warning_count < 10) {
++warning_count;
LOG(WARNING)
- << "For kernel '" << op_context->op_kernel().name()
- << "' from input '" << op_context->op_kernel().def().input(0)
+ << "For kernel '" << context->op_kernel().name() << "' from input '"
+ << context->op_kernel().def().input(0)
<< "': Zero is not representable in the quantized range used by the"
<< " input. This means QuantizedConv2d has to fall back to a slow"
<< " implementation, since the border of zero values can't be"
@@ -211,7 +224,7 @@ class Im2ColConvFunctor {
<< " avoid this situation.";
}
ReferenceConvFunctor<T1, T2, T3> conv_functor;
- conv_functor(op_context, input_data, input_batches, input_height,
+ conv_functor(context, input_data, input_batches, input_height,
input_width, input_depth, input_offset, filter_data,
filter_height, filter_width, filter_count, filter_offset,
stride, padding, output_data, output_height, output_width,
@@ -247,155 +260,179 @@ class Im2ColConvFunctor {
// by the width, then the height. This is the standard memory order in the
// image world if it helps to visualize it.
const int filter_value_count = filter_width * filter_height * input_depth;
- const int patch_count = input_batches * output_width * output_height;
- const int im2col_size = patch_count * filter_value_count;
+ const int64 patches_per_chunk =
+ kMaxChunkSize / (filter_value_count * sizeof(T1));
+ const int64 chunk_value_count =
+ (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
// TODO(petewarden) - Memory allocation can be very slow on Android. Can we
// optimize this by keeping the scratch buffer around?
- std::unique_ptr<T1[]> im2col_buffer(new T1[im2col_size]);
-
- for (int batch = 0; batch < input_batches; ++batch) {
- const T1* input_batch_start =
- input_data + (batch * input_height * input_width * input_depth);
- for (int out_y = 0; out_y < output_height; ++out_y) {
+ // Because memory allocation is very expensive on mobile platforms, try to
+ // allocate a persistent buffer that will be kept around between calls. We
+ // use TensorFlow's resource management to ensure that the memory will be
+ // released when the session is over.
+ Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
+ std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
+ creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
+ *resource = new Im2ColBufferResource<T1, chunk_value_count>();
+ return Status::OK();
+ };
+ OP_REQUIRES_OK(
+ context,
+ context->resource_manager()->LookupOrCreate(
+ "Conv2d", "im2col_buffer", &im2col_buffer_resource, creator));
+ // This means that multiple ops can't be run simultaneously on different
+ // threads, because we have a single shared resource. The platforms this is
+ // aimed at have intra-op parallelism as their focus though, so it shouldn't
+ // be an issue.
+ mutex_lock lock_buffer(im2col_buffer_resource->mu);
+ core::ScopedUnref unref_buffer(im2col_buffer_resource);
+ T1* im2col_buffer = im2col_buffer_resource->data;
+
+ const int64 patch_count = (input_batches * output_height * output_width);
+ const int64 chunk_count =
+ (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
+ for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
+ const int64 patch_index_start = chunk_index * patches_per_chunk;
+ const int64 patch_index_end =
+ std::min(patch_index_start + patches_per_chunk, patch_count);
+ for (int64 patch_index = patch_index_start; patch_index < patch_index_end;
+ ++patch_index) {
+ const int64 batch = patch_index / (output_height * output_width);
+ const int64 out_y = (patch_index / output_width) % output_height;
+ const int64 out_x = patch_index % output_width;
+ const T1* input_batch_start =
+ input_data + (batch * input_height * input_width * input_depth);
const int in_y_origin = (out_y * stride) - filter_top_offset;
- for (int out_x = 0; out_x < output_width; ++out_x) {
- const int in_x_origin = (out_x * stride) - filter_left_offset;
- const int patch_index = (batch * output_width * output_height) +
- (out_y * output_width) + out_x;
- T1* im2col_patch_start =
- im2col_buffer.get() + (patch_index * filter_value_count);
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
- const int in_y = in_y_origin + filter_y;
- T1* im2col_row_start =
- im2col_patch_start + (filter_y * filter_width * input_depth);
- // If we're off the top or the bottom of the input, fill the whole
- // row with zeroes.
- if ((in_y < 0) || (in_y >= input_height)) {
- T1* im2col_row_end =
- im2col_row_start + (filter_width * input_depth);
- // We'll be subtracting this offset during the calculations
- // so to get an actual zero after that bias we need to set
- // it to input_offset here.
- std::fill(im2col_row_start, im2col_row_end, input_offset);
- } else {
- // What we're doing here is trying to copy and fill the im2col
- // buffer as efficiently as possible, using functions to set or
- // duplicate values en masse. We know we don't have to worry about
- // vertical edges because we dealt with that case above, so we
- // just need to handle filters that overlap the left or right
- // edges. Here's what that looks like:
- //
- // < left_zero_count > < center_copy_count > < right_zero_count >
- // +------------------+---------------------+--------------------+
- // | (filter) | (image) | (filter) |
- // +------------------+---------------------+--------------------+
- // in_x_origin 0 input_width in_x_end
- //
- // In reality it's unlikely that a filter patch will be wider
- // than an input, but this shows all the edge cases.
- // We use std::fill() to set the left and right sections to zeroes
- // and std::copy() to copy over the input data for the center.
- const int in_x_end = in_x_origin + filter_width;
- const int left_zero_count = std::max(0, 0 - in_x_origin);
- const int right_zero_count = std::max(0, in_x_end - input_width);
- const int center_copy_count =
- filter_width - (left_zero_count + right_zero_count);
- if (left_zero_count > 0) {
- T1* im2col_left_start = im2col_row_start;
- T1* im2col_left_end =
- im2col_left_start + (left_zero_count * input_depth);
- std::fill(im2col_left_start, im2col_left_end, input_offset);
- }
- if (center_copy_count > 0) {
- const T1* input_row_start =
- input_batch_start + (in_y * input_width * input_depth) +
- (std::max(0, in_x_origin) * input_depth);
- const T1* input_row_end =
- input_row_start + (center_copy_count * input_depth);
- T1* im2col_center_start =
- im2col_row_start + (left_zero_count * input_depth);
- std::copy(input_row_start, input_row_end, im2col_center_start);
- }
- if (right_zero_count > 0) {
- T1* im2col_right_start =
- im2col_row_start +
- ((left_zero_count + center_copy_count) * input_depth);
- T1* im2col_right_end =
- im2col_right_start + (right_zero_count * input_depth);
- std::fill(im2col_right_start, im2col_right_end, input_offset);
- }
+ const int in_x_origin = (out_x * stride) - filter_left_offset;
+ const int patch_index_within_chunk = patch_index % patches_per_chunk;
+ T1* im2col_patch_start =
+ im2col_buffer + (patch_index_within_chunk * filter_value_count);
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ const int in_y = in_y_origin + filter_y;
+ T1* im2col_row_start =
+ im2col_patch_start + (filter_y * filter_width * input_depth);
+ // If we're off the top or the bottom of the input, fill the
+ // whole row with zeroes.
+ if ((in_y < 0) || (in_y >= input_height)) {
+ T1* im2col_row_end =
+ im2col_row_start + (filter_width * input_depth);
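+ // We fill with input_offset rather than a literal zero because the
+ // offset gets subtracted later in the calculation, so this is what
+ // produces an effective zero after that bias is applied.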
+ std::fill(im2col_row_start, im2col_row_end, input_offset);
+ } else {
+ // What we're doing here is trying to copy and fill the im2col
+ // buffer as efficiently as possible, using functions to set or
+ // duplicate values en masse. We know we don't have to worry about
+ // vertical edges because we dealt with that case above, so we
+ // just need to handle filters that overlap the left or right
+ // edges. Here's what that looks like:
+ //
+ // < left_zero_count > < center_copy_count > < right_zero_count >
+ // +------------------+---------------------+--------------------+
+ // | (filter) | (image) | (filter) |
+ // +------------------+---------------------+--------------------+
+ // in_x_origin 0 input_width in_x_end
+ //
+ // In reality it's unlikely that a filter patch will be wider
+ // than an input, but this shows all the edge cases.
+ // We use std::fill() to set the left and right sections to zeroes
+ // and std::copy() to copy over the input data for the center.
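+ // As a concrete example: with filter_width = 3, input_width = 5 and
+ // in_x_origin = -1, this gives left_zero_count = 1, right_zero_count = 0
+ // and center_copy_count = 2.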
+ const int in_x_end = in_x_origin + filter_width;
+ const int left_zero_count = std::max(0, 0 - in_x_origin);
+ const int right_zero_count = std::max(0, in_x_end - input_width);
+ const int center_copy_count =
+ filter_width - (left_zero_count + right_zero_count);
+ if (left_zero_count > 0) {
+ T1* im2col_left_start = im2col_row_start;
+ T1* im2col_left_end =
+ im2col_left_start + (left_zero_count * input_depth);
+ std::fill(im2col_left_start, im2col_left_end, input_offset);
+ }
+ if (center_copy_count > 0) {
+ const T1* input_row_start =
+ input_batch_start + (in_y * input_width * input_depth) +
+ (std::max(0, in_x_origin) * input_depth);
+ const T1* input_row_end =
+ input_row_start + (center_copy_count * input_depth);
+ T1* im2col_center_start =
+ im2col_row_start + (left_zero_count * input_depth);
+ std::copy(input_row_start, input_row_end, im2col_center_start);
+ }
+ if (right_zero_count > 0) {
+ T1* im2col_right_start =
+ im2col_row_start +
+ ((left_zero_count + center_copy_count) * input_depth);
+ T1* im2col_right_end =
+ im2col_right_start + (right_zero_count * input_depth);
+ std::fill(im2col_right_start, im2col_right_end, input_offset);
}
}
}
}
- }
-
- CHECK_GT(patch_count, 0);
- CHECK_GT(filter_count, 0);
- CHECK_GT(filter_value_count, 0);
-
- const bool transpose_a = false;
- const bool transpose_b = false;
- const bool transpose_c = false;
- const int m = patch_count;
- const int n = filter_count;
- const int k = filter_value_count;
- const int lda = filter_value_count;
- const int ldb = filter_count;
- const int ldc = filter_count;
-
- if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
- std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
- (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
- (transpose_c == false)) {
- meta::QuantizedGemm(op_context, transpose_a, transpose_b,
- im2col_buffer.get(), filter_data, output_data, m, n,
- k, -input_offset, -filter_offset, lda, ldb, ldc);
- } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
- std::is_same<T3, qint32>() && (output_offset == 0) &&
- (output_mult == 1) && (output_shift == 0)) {
- // The gemmlowp optimized library only works for a particular set of data
- // types, so check if we meet those requirements and
- // fall back to a slower reference implementation if not.
- const uint8* im2col_data_as_uint8 = &(im2col_buffer.get()->value);
- const uint8* filter_data_as_uint8 = &(filter_data->value);
- int32* output_data_as_int32 = &(output_data->value);
- // All of the transpose_* variables are currently compile-time consts, so
- // we could just hard-code these values too, but that would break if
- // anybody changed those values in the future (e.g. to match the ability
- // of MatMul to specify them as attributes). We're using a verbose
- // approach of deriving the order values from the transpose variables to
- // be able to catch any changes like that.
- static const gemmlowp::MapOrder ResultOrder =
- !transpose_c ? gemmlowp::MapOrder::RowMajor
- : gemmlowp::MapOrder::ColMajor;
- static const gemmlowp::MapOrder LhsOrder =
- !transpose_a ? gemmlowp::MapOrder::RowMajor
- : gemmlowp::MapOrder::ColMajor;
- static const gemmlowp::MapOrder RhsOrder =
- !transpose_b ? gemmlowp::MapOrder::RowMajor
- : gemmlowp::MapOrder::ColMajor;
- gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
- im2col_data_as_uint8, m, k, lda);
- gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
- filter_data_as_uint8, k, n, ldb);
- gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
- output_data_as_int32, m, n, ldc);
- const std::tuple<> empty_pipeline = {};
-
- auto& worker_threads =
- *(op_context->device()->tensorflow_cpu_worker_threads());
- TensorflowGemmContext context(worker_threads.num_threads,
- worker_threads.workers);
- gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
- gemmlowp::DefaultL8R8BitDepthParams>(
- &context, lhs, rhs, &result, -input_offset, -filter_offset,
- empty_pipeline);
- } else {
- ReferenceGemm<T1, T2, T3>(transpose_a, transpose_b, transpose_c, m, n, k,
- im2col_buffer.get(), input_offset, lda,
- filter_data, filter_offset, ldb, output_data,
- output_shift, output_offset, output_mult, ldc);
+ // Now that we've assembled a set of image patches into a matrix, apply a
+ // GEMM matrix multiply of the patches as rows times the filter weights
+ // in columns to get partial results in the output matrix.
+ const int how_many_patches = patch_index_end - patch_index_start;
+ const bool transpose_a = false;
+ const bool transpose_b = false;
+ const bool transpose_c = false;
+ const int m = how_many_patches;
+ const int n = filter_count;
+ const int k = filter_value_count;
+ const int lda = filter_value_count;
+ const int ldb = filter_count;
+ const int ldc = filter_count;
+ // Point each chunk's GEMM at the matching rows of the overall output so
+ // that later chunks don't overwrite the results of earlier ones.
+ T3* chunk_output_data = output_data + (patch_index_start * filter_count);
+ if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
+ std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
+ (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
+ (transpose_c == false)) {
+ meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
+ filter_data, chunk_output_data, m, n, k, -input_offset,
+ -filter_offset, lda, ldb, ldc);
+ } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
+ std::is_same<T3, qint32>() && (output_offset == 0) &&
+ (output_mult == 1) && (output_shift == 0)) {
+ // The gemmlowp optimized library only works for a particular set of
+ // data types, so check if we meet those requirements and fall back to a
+ // slower reference implementation if not.
+ const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
+ const uint8* filter_data_as_uint8 = &(filter_data->value);
+ int32* output_data_as_int32 = &(chunk_output_data->value);
+ // All of the transpose_* variables are currently compile-time consts,
+ // so we could just hard-code these values too, but that would break if
+ // anybody changed those values in the future (e.g. to match the ability
+ // of MatMul to specify them as attributes). We're using a verbose
+ // approach of deriving the order values from the transpose variables to
+ // be able to catch any changes like that.
+ static const gemmlowp::MapOrder ResultOrder =
+ !transpose_c ? gemmlowp::MapOrder::RowMajor
+ : gemmlowp::MapOrder::ColMajor;
+ static const gemmlowp::MapOrder LhsOrder =
+ !transpose_a ? gemmlowp::MapOrder::RowMajor
+ : gemmlowp::MapOrder::ColMajor;
+ static const gemmlowp::MapOrder RhsOrder =
+ !transpose_b ? gemmlowp::MapOrder::RowMajor
+ : gemmlowp::MapOrder::ColMajor;
+ gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
+ im2col_data_as_uint8, m, k, lda);
+ gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
+ filter_data_as_uint8, k, n, ldb);
+ gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
+ output_data_as_int32, m, n, ldc);
+ const std::tuple<> empty_pipeline = {};
+
+ auto& worker_threads =
+ *(context->device()->tensorflow_cpu_worker_threads());
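+ // Note that this gemmlowp thread context shadows the OpKernelContext*
+ // parameter named 'context' for the remainder of this scope.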
+ TensorflowGemmContext context(worker_threads.num_threads,
+ worker_threads.workers);
+ gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
+ gemmlowp::DefaultL8R8BitDepthParams>(
+ &context, lhs, rhs, &result, -input_offset, -filter_offset,
+ empty_pipeline);
+ } else {
+ ReferenceGemm<T1, T2, T3>(
+ transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
+ input_offset, lda, filter_data, filter_offset, ldb,
+ chunk_output_data, output_shift, output_offset, output_mult, ldc);
+ }
}
}
};
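For reference, the persistent-buffer plumbing above (LookupOrCreate on the device's ResourceMgr, a mutex to serialize users, and ScopedUnref to release the reference) follows a general TensorFlow pattern. Below is a minimal sketch of that pattern in isolation, with a hypothetical ScratchBufferResource and UseScratchBuffer standing in for the Im2ColBufferResource declared in conv_ops.h and the kernel code above; treat it as an approximation under those assumptions, not the kernel's actual implementation.

#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {

// Hypothetical resource type standing in for Im2ColBufferResource: a
// ref-counted object owning a reusable heap buffer and a mutex guarding it.
class ScratchBufferResource : public ResourceBase {
 public:
  explicit ScratchBufferResource(size_t size_bytes)
      : data(new char[size_bytes]) {}
  string DebugString() { return "ScratchBufferResource"; }
  mutex mu;
  std::unique_ptr<char[]> data;
};

// Called from an op's Compute(); reuses one allocation across invocations.
void UseScratchBuffer(OpKernelContext* context) {
  ScratchBufferResource* scratch = nullptr;
  OP_REQUIRES_OK(
      context,
      context->resource_manager()->LookupOrCreate<ScratchBufferResource>(
          "Conv2d", "scratch_buffer", &scratch,
          [](ScratchBufferResource** resource) {
            *resource = new ScratchBufferResource(16 * 1024 * 1024);
            return Status::OK();
          }));
  // Only one op instance may use the shared buffer at a time; the platforms
  // this targets rely on intra-op parallelism, so serializing ops is fine.
  mutex_lock lock(scratch->mu);
  // Drop our reference on scope exit; the ResourceMgr keeps the buffer alive
  // until the session is torn down, so the allocation cost is paid once.
  core::ScopedUnref unref(scratch);
  char* buffer = scratch->data.get();
  (void)buffer;  // ...fill and consume the scratch buffer here...
}

}  // namespace tensorflow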