author    | 2016-12-16 11:52:55 -0800
committer | 2016-12-16 12:05:52 -0800
commit    | 05585d42b6929a48235110f31e5e14d5bfa088df (patch)
tree      | 71ce8c4dc4c0e9c08fdd138f5e3e5fe8f295364b
parent    | 7b6111893a988132e627ec56d0885475261c101d (diff)
Speed up im2col for QuantizedConv2D by breaking work into chunks
Change: 142282720
-rw-r--r-- | tensorflow/core/kernels/quantized_conv_ops.cc | 331
1 file changed, 184 insertions(+), 147 deletions(-)
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index 13fdc1e1f0..2217c2acca 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -23,6 +23,7 @@ limitations under the License.
 #include "public/gemmlowp.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/conv_ops.h"
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/kernels/quantization_utils.h"
@@ -49,7 +50,7 @@ namespace tensorflow {
 template <class T1, class T2, class T3>
 class ReferenceConvFunctor {
  public:
-  void operator()(OpKernelContext* op_context, const T1* input_data,
+  void operator()(OpKernelContext* context, const T1* input_data,
                   int input_batches, int input_height, int input_width,
                   int input_depth, int input_offset, const T2* filter_data,
                   int filter_height, int filter_width, int filter_count,
@@ -181,6 +182,18 @@ class ReferenceConvFunctor {
   }
 };
 
+// We don't want to allocate a buffer to hold all the patches if the size is
+// going to be extremely large, so break it into chunks if it's bigger than
+// a limit. Each chunk will be processed serially, so we can refill the
+// buffer for the next chunk and reuse it, keeping maximum memory size down.
+// In this case, we've picked 16 megabytes as a reasonable limit for Android and
+// other platforms using Eigen, and 1MB for Apple devices, from experimentation.
+#if defined(__APPLE__) && defined(IS_MOBILE_PLATFORM)
+const size_t kMaxChunkSize = (1 * 1024 * 1024);
+#else
+const size_t kMaxChunkSize = (16 * 1024 * 1024);
+#endif
+
 // Implements convolution as a two stage process, first packing the patches of
 // the input image into columns (im2col) and then running GEMM to produce the
 // final result.
@@ -189,7 +202,7 @@ class ReferenceConvFunctor {
 template <class T1, class T2, class T3>
 class Im2ColConvFunctor {
  public:
-  void operator()(OpKernelContext* op_context, const T1* input_data,
+  void operator()(OpKernelContext* context, const T1* input_data,
                   int input_batches, int input_height, int input_width,
                   int input_depth, int input_offset, const T2* filter_data,
                   int filter_height, int filter_width, int filter_count,
@@ -202,8 +215,8 @@ class Im2ColConvFunctor {
       if (warning_count < 10) {
         ++warning_count;
         LOG(WARNING)
-            << "For kernel '" << op_context->op_kernel().name()
-            << "' from input '" << op_context->op_kernel().def().input(0)
+            << "For kernel '" << context->op_kernel().name() << "' from input '"
+            << context->op_kernel().def().input(0)
             << "': Zero is not representable in the quantized range used by the"
             << " input. This means QuantizedConv2d has to fall back to a slow"
             << " implementation, since the border of zero values can't be"
@@ -211,7 +224,7 @@ class Im2ColConvFunctor {
             << " avoid this situation.";
       }
       ReferenceConvFunctor<T1, T2, T3> conv_functor;
-      conv_functor(op_context, input_data, input_batches, input_height,
+      conv_functor(context, input_data, input_batches, input_height,
                    input_width, input_depth, input_offset, filter_data,
                    filter_height, filter_width, filter_count, filter_offset,
                    stride, padding, output_data, output_height, output_width,
@@ -247,155 +260,179 @@ class Im2ColConvFunctor {
     // by the width, then the height. This is the standard memory order in the
     // image world if it helps to visualize it.
     const int filter_value_count = filter_width * filter_height * input_depth;
-    const int patch_count = input_batches * output_width * output_height;
-    const int im2col_size = patch_count * filter_value_count;
+    const int64 patches_per_chunk =
+        kMaxChunkSize / (filter_value_count * sizeof(T1));
+    const int64 chunk_value_count =
+        (kMaxChunkSize + (sizeof(T1) - 1)) / sizeof(T1);
     // TODO(petewarden) - Memory allocation can be very slow on Android. Can we
     // optimize this by keeping the scratch buffer around?
-    std::unique_ptr<T1[]> im2col_buffer(new T1[im2col_size]);
-
-    for (int batch = 0; batch < input_batches; ++batch) {
-      const T1* input_batch_start =
-          input_data + (batch * input_height * input_width * input_depth);
-      for (int out_y = 0; out_y < output_height; ++out_y) {
+    // Because memory allocation is very expensive on mobile platforms, try to
+    // allocate a persistent buffer that will be kept around between calls. We
+    // use TensorFlow's resource management to ensure that the memory will be
+    // released when the session is over.
+    Im2ColBufferResource<T1, chunk_value_count>* im2col_buffer_resource;
+    std::function<Status(Im2ColBufferResource<T1, chunk_value_count>**)>
+        creator = [](Im2ColBufferResource<T1, chunk_value_count>** resource) {
+          *resource = new Im2ColBufferResource<T1, chunk_value_count>();
+          return Status::OK();
+        };
+    OP_REQUIRES_OK(
+        context,
+        context->resource_manager()->LookupOrCreate(
+            "Conv2d", "im2col_buffer", &im2col_buffer_resource, creator));
+    // This means that multiple ops can't be run simultaneously on different
+    // threads, because we have a single shared resource. The platforms this is
+    // aimed at have intra-op parallelism as their focus though, so it shouldn't
+    // be an issue.
+    mutex_lock lock_buffer(im2col_buffer_resource->mu);
+    core::ScopedUnref unref_buffer(im2col_buffer_resource);
+    T1* im2col_buffer = im2col_buffer_resource->data;
+
+    const int64 patch_count = (input_batches * output_height * output_width);
+    const int64 chunk_count =
+        (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;
+    for (int64 chunk_index = 0; chunk_index < chunk_count; ++chunk_index) {
+      const int64 patch_index_start = chunk_index * patches_per_chunk;
+      const int64 patch_index_end =
+          std::min(patch_index_start + patches_per_chunk, patch_count);
+      for (int64 patch_index = patch_index_start; patch_index < patch_index_end;
+           ++patch_index) {
+        const int64 batch = patch_index / (output_height * output_width);
+        const int64 out_y = (patch_index / output_width) % output_height;
+        const int64 out_x = patch_index % output_width;
+        const T1* input_batch_start =
+            input_data + (batch * input_height * input_width * input_depth);
         const int in_y_origin = (out_y * stride) - filter_top_offset;
-        for (int out_x = 0; out_x < output_width; ++out_x) {
-          const int in_x_origin = (out_x * stride) - filter_left_offset;
-          const int patch_index = (batch * output_width * output_height) +
-                                  (out_y * output_width) + out_x;
-          T1* im2col_patch_start =
-              im2col_buffer.get() + (patch_index * filter_value_count);
-          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
-            const int in_y = in_y_origin + filter_y;
-            T1* im2col_row_start =
-                im2col_patch_start + (filter_y * filter_width * input_depth);
-            // If we're off the top or the bottom of the input, fill the whole
-            // row with zeroes.
-            if ((in_y < 0) || (in_y >= input_height)) {
-              T1* im2col_row_end =
-                  im2col_row_start + (filter_width * input_depth);
-              // We'll be subtracting this offset during the calculations
-              // so to get an actual zero after that bias we need to set
-              // it to input_offset here.
-              std::fill(im2col_row_start, im2col_row_end, input_offset);
-            } else {
-              // What we're doing here is trying to copy and fill the im2col
-              // buffer as efficiently as possible, using functions to set or
-              // duplicate values en masse. We know we don't have to worry about
-              // vertical edges because we dealt with that case above, so we
-              // just need to handle filters that overlap the left or right
-              // edges. Here's what that looks like:
-              //
-              // < left_zero_count > < center_copy_count > < right_zero_count >
-              // +------------------+---------------------+--------------------+
-              // |     (filter)     |       (image)       |      (filter)      |
-              // +------------------+---------------------+--------------------+
-              // in_x_origin        0                 input_width       in_x_end
-              //
-              // In reality it's unlikely that a filter patch will be wider
-              // than an input, but this shows all the edge cases.
-              // We use std::fill() to set the left and right sections to zeroes
-              // and std::copy() to copy over the input data for the center.
-              const int in_x_end = in_x_origin + filter_width;
-              const int left_zero_count = std::max(0, 0 - in_x_origin);
-              const int right_zero_count = std::max(0, in_x_end - input_width);
-              const int center_copy_count =
-                  filter_width - (left_zero_count + right_zero_count);
-              if (left_zero_count > 0) {
-                T1* im2col_left_start = im2col_row_start;
-                T1* im2col_left_end =
-                    im2col_left_start + (left_zero_count * input_depth);
-                std::fill(im2col_left_start, im2col_left_end, input_offset);
-              }
-              if (center_copy_count > 0) {
-                const T1* input_row_start =
-                    input_batch_start + (in_y * input_width * input_depth) +
-                    (std::max(0, in_x_origin) * input_depth);
-                const T1* input_row_end =
-                    input_row_start + (center_copy_count * input_depth);
-                T1* im2col_center_start =
-                    im2col_row_start + (left_zero_count * input_depth);
-                std::copy(input_row_start, input_row_end, im2col_center_start);
-              }
-              if (right_zero_count > 0) {
-                T1* im2col_right_start =
-                    im2col_row_start +
-                    ((left_zero_count + center_copy_count) * input_depth);
-                T1* im2col_right_end =
-                    im2col_right_start + (right_zero_count * input_depth);
-                std::fill(im2col_right_start, im2col_right_end, input_offset);
+        const int in_x_origin = (out_x * stride) - filter_left_offset;
+        const int patch_index_within_chunk = patch_index % patches_per_chunk;
+        T1* im2col_patch_start =
+            im2col_buffer + (patch_index_within_chunk * filter_value_count);
+        for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+          const int in_y = in_y_origin + filter_y;
+          T1* im2col_row_start =
+              im2col_patch_start + (filter_y * filter_width * input_depth);
+          // If we're off the top or the bottom of the input, fill the
+          // whole row with zeroes.
+          if ((in_y < 0) || (in_y >= input_height)) {
+            T1* im2col_row_end =
+                im2col_row_start + (filter_width * input_depth);
+            std::fill(im2col_row_start, im2col_row_end, input_offset);
+          } else {
+            // What we're doing here is trying to copy and fill the im2col
+            // buffer as efficiently as possible, using functions to set or
+            // duplicate values en masse. We know we don't have to worry about
+            // vertical edges because we dealt with that case above, so we
+            // just need to handle filters that overlap the left or right
+            // edges. Here's what that looks like:
+            //
+            // < left_zero_count > < center_copy_count > < right_zero_count >
+            // +------------------+---------------------+--------------------+
+            // |     (filter)     |       (image)       |      (filter)      |
+            // +------------------+---------------------+--------------------+
+            // in_x_origin        0                 input_width       in_x_end
+            //
+            // In reality it's unlikely that a filter patch will be wider
+            // than an input, but this shows all the edge cases.
+            // We use std::fill() to set the left and right sections to zeroes
+            // and std::copy() to copy over the input data for the center.
+            const int in_x_end = in_x_origin + filter_width;
+            const int left_zero_count = std::max(0, 0 - in_x_origin);
+            const int right_zero_count = std::max(0, in_x_end - input_width);
+            const int center_copy_count =
+                filter_width - (left_zero_count + right_zero_count);
+            if (left_zero_count > 0) {
+              T1* im2col_left_start = im2col_row_start;
+              T1* im2col_left_end =
+                  im2col_left_start + (left_zero_count * input_depth);
+              std::fill(im2col_left_start, im2col_left_end, input_offset);
+            }
+            if (center_copy_count > 0) {
+              const T1* input_row_start =
+                  input_batch_start + (in_y * input_width * input_depth) +
+                  (std::max(0, in_x_origin) * input_depth);
+              const T1* input_row_end =
+                  input_row_start + (center_copy_count * input_depth);
+              T1* im2col_center_start =
+                  im2col_row_start + (left_zero_count * input_depth);
+              std::copy(input_row_start, input_row_end, im2col_center_start);
+            }
+            if (right_zero_count > 0) {
+              T1* im2col_right_start =
+                  im2col_row_start +
+                  ((left_zero_count + center_copy_count) * input_depth);
+              T1* im2col_right_end =
+                  im2col_right_start + (right_zero_count * input_depth);
+              std::fill(im2col_right_start, im2col_right_end, input_offset);
            }
          }
        }
      }
-    }
-
-    CHECK_GT(patch_count, 0);
-    CHECK_GT(filter_count, 0);
-    CHECK_GT(filter_value_count, 0);
-
-    const bool transpose_a = false;
-    const bool transpose_b = false;
-    const bool transpose_c = false;
-    const int m = patch_count;
-    const int n = filter_count;
-    const int k = filter_value_count;
-    const int lda = filter_value_count;
-    const int ldb = filter_count;
-    const int ldc = filter_count;
-
-    if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
-        std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
-        (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
-        (transpose_c == false)) {
-      meta::QuantizedGemm(op_context, transpose_a, transpose_b,
-                          im2col_buffer.get(), filter_data, output_data, m, n,
-                          k, -input_offset, -filter_offset, lda, ldb, ldc);
-    } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
-               std::is_same<T3, qint32>() && (output_offset == 0) &&
-               (output_mult == 1) && (output_shift == 0)) {
-      // The gemmlowp optimized library only works for a particular set of data
-      // types, so check if we meet those requirements and
-      // fall back to a slower reference implementation if not.
-      const uint8* im2col_data_as_uint8 = &(im2col_buffer.get()->value);
-      const uint8* filter_data_as_uint8 = &(filter_data->value);
-      int32* output_data_as_int32 = &(output_data->value);
-      // All of the transpose_* variables are currently compile-time consts, so
-      // we could just hard-code these values too, but that would break if
-      // anybody changed those values in the future (e.g. to match the ability
-      // of MatMul to specify them as attributes). We're using a verbose
-      // approach of deriving the order values from the transpose variables to
-      // be able to catch any changes like that.
-      static const gemmlowp::MapOrder ResultOrder =
-          !transpose_c ? gemmlowp::MapOrder::RowMajor
-                       : gemmlowp::MapOrder::ColMajor;
-      static const gemmlowp::MapOrder LhsOrder =
-          !transpose_a ? gemmlowp::MapOrder::RowMajor
-                       : gemmlowp::MapOrder::ColMajor;
-      static const gemmlowp::MapOrder RhsOrder =
-          !transpose_b ? gemmlowp::MapOrder::RowMajor
-                       : gemmlowp::MapOrder::ColMajor;
-      gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
-          im2col_data_as_uint8, m, k, lda);
-      gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
-          filter_data_as_uint8, k, n, ldb);
-      gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
-          output_data_as_int32, m, n, ldc);
-      const std::tuple<> empty_pipeline = {};
-
-      auto& worker_threads =
-          *(op_context->device()->tensorflow_cpu_worker_threads());
-      TensorflowGemmContext context(worker_threads.num_threads,
-                                    worker_threads.workers);
-      gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
-                                       gemmlowp::DefaultL8R8BitDepthParams>(
-          &context, lhs, rhs, &result, -input_offset, -filter_offset,
-          empty_pipeline);
-    } else {
-      ReferenceGemm<T1, T2, T3>(transpose_a, transpose_b, transpose_c, m, n, k,
-                                im2col_buffer.get(), input_offset, lda,
-                                filter_data, filter_offset, ldb, output_data,
-                                output_shift, output_offset, output_mult, ldc);
+      // Now we've assembled a set of image patches into a matrix, apply a
+      // GEMM matrix multiply of the patches as rows, times the filter
+      // weights in columns, to get partial results in the output matrix.
+      const int how_many_patches = patch_index_end - patch_index_start;
+      const bool transpose_a = false;
+      const bool transpose_b = false;
+      const bool transpose_c = false;
+      const int m = how_many_patches;
+      const int n = filter_count;
+      const int k = filter_value_count;
+      const int lda = filter_value_count;
+      const int ldb = filter_count;
+      const int ldc = filter_count;
+      if (meta::IsSupportedAndEnabled() && std::is_same<T1, quint8>() &&
+          std::is_same<T2, quint8>() && std::is_same<T3, qint32>() &&
+          (output_offset == 0) && (output_mult == 1) && (output_shift == 0) &&
+          (transpose_c == false)) {
+        meta::QuantizedGemm(context, transpose_a, transpose_b, im2col_buffer,
+                            filter_data, output_data, m, n, k, -input_offset,
+                            -filter_offset, lda, ldb, ldc);
+      } else if (std::is_same<T1, quint8>() && std::is_same<T2, quint8>() &&
+                 std::is_same<T3, qint32>() && (output_offset == 0) &&
+                 (output_mult == 1) && (output_shift == 0)) {
+        // The gemmlowp optimized library only works for a particular set of
+        // data types, so check if we meet those requirements and fall back to a
+        // slower reference implementation if not.
+        const uint8* im2col_data_as_uint8 = &(im2col_buffer->value);
+        const uint8* filter_data_as_uint8 = &(filter_data->value);
+        int32* output_data_as_int32 = &(output_data->value);
+        // All of the transpose_* variables are currently compile-time consts,
+        // so we could just hard-code these values too, but that would break if
+        // anybody changed those values in the future (e.g. to match the ability
+        // of MatMul to specify them as attributes). We're using a verbose
+        // approach of deriving the order values from the transpose variables to
+        // be able to catch any changes like that.
+        static const gemmlowp::MapOrder ResultOrder =
+            !transpose_c ? gemmlowp::MapOrder::RowMajor
+                         : gemmlowp::MapOrder::ColMajor;
+        static const gemmlowp::MapOrder LhsOrder =
+            !transpose_a ? gemmlowp::MapOrder::RowMajor
+                         : gemmlowp::MapOrder::ColMajor;
+        static const gemmlowp::MapOrder RhsOrder =
+            !transpose_b ? gemmlowp::MapOrder::RowMajor
+                         : gemmlowp::MapOrder::ColMajor;
+        gemmlowp::MatrixMap<const std::uint8_t, LhsOrder> lhs(
+            im2col_data_as_uint8, m, k, lda);
+        gemmlowp::MatrixMap<const std::uint8_t, RhsOrder> rhs(
+            filter_data_as_uint8, k, n, ldb);
+        gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(
+            output_data_as_int32, m, n, ldc);
+        const std::tuple<> empty_pipeline = {};
+
+        auto& worker_threads =
+            *(context->device()->tensorflow_cpu_worker_threads());
+        TensorflowGemmContext context(worker_threads.num_threads,
+                                      worker_threads.workers);
+        gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
+                                         gemmlowp::DefaultL8R8BitDepthParams>(
+            &context, lhs, rhs, &result, -input_offset, -filter_offset,
+            empty_pipeline);
+      } else {
+        ReferenceGemm<T1, T2, T3>(
+            transpose_a, transpose_b, transpose_c, m, n, k, im2col_buffer,
+            input_offset, lda, filter_data, filter_offset, ldb, output_data,
+            output_shift, output_offset, output_mult, ldc);
+      }
    }
  }
};
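
A note on the chunking arithmetic at the top of the new operator(): each im2col row holds filter_value_count values of type T1, patches_per_chunk is how many whole rows fit under the kMaxChunkSize byte budget, and chunk_count is a ceiling division so the final, partially-filled chunk is still processed. A standalone sketch of that arithmetic (not from the commit; the filter and output sizes are made-up illustration values):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Standalone sketch of the chunk arithmetic (values are illustrative only:
// a 3x3 filter over 32 channels, one 128x128 output image, uint8 data).
int main() {
  const size_t kMaxChunkSize = 16 * 1024 * 1024;  // 16MB limit from the patch.
  const int64_t filter_value_count = 3 * 3 * 32;  // filter_height * filter_width * input_depth
  const int64_t patch_count = 1 * 128 * 128;      // input_batches * output_height * output_width

  // How many whole im2col rows (patches) fit in one chunk of the buffer.
  const int64_t patches_per_chunk =
      kMaxChunkSize / (filter_value_count * sizeof(uint8_t));
  // Ceiling division, so a final partially-filled chunk still gets processed.
  const int64_t chunk_count =
      (patch_count + (patches_per_chunk - 1)) / patches_per_chunk;

  printf("patches_per_chunk=%lld chunk_count=%lld\n",
         static_cast<long long>(patches_per_chunk),
         static_cast<long long>(chunk_count));
  // Peak scratch memory is now bounded by kMaxChunkSize instead of growing
  // with patch_count * filter_value_count, which scales with the image size.
  return 0;
}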
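The scratch buffer itself is obtained through TensorFlow's resource manager (LookupOrCreate), so it is allocated once, reused across calls, and released with the session; the mutex_lock plus ScopedUnref pair then serializes the ops that share it. Stripped of the ResourceMgr machinery, the pattern is roughly the following plain-C++ sketch (hypothetical helper name, not the TensorFlow API):

#include <cstddef>
#include <mutex>
#include <vector>

// Sketch of the persistent-scratch-buffer pattern: the buffer is allocated on
// first use, reused across calls, and a lock serializes the callers that
// share it -- the same trade-off the comment in the patch describes. In the
// real kernel the buffer lives in the ResourceMgr so it is released when the
// session is torn down.
template <typename T, size_t kValueCount>
T* AcquireSharedIm2ColBuffer(std::unique_lock<std::mutex>* lock) {
  static std::mutex mu;                       // Guards the single shared buffer.
  static std::vector<T> buffer(kValueCount);  // Allocated on first use only.
  *lock = std::unique_lock<std::mutex>(mu);   // Caller holds this while packing.
  return buffer.data();
}

// Usage inside a kernel-like function:
//   std::unique_lock<std::mutex> lock;
//   auto* im2col = AcquireSharedIm2ColBuffer<unsigned char, 16 * 1024 * 1024>(&lock);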
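Chunking also flattens the old batch/out_y/out_x triple loop into a single patch_index loop, relying on patch_index = (batch * output_height + out_y) * output_width + out_x. A small self-contained check (arbitrary test sizes) that the div/mod decomposition used in the patch recovers the same triples as the nested loops:

#include <cassert>
#include <cstdint>

// Verifies that the flat patch_index used in the chunked loop visits exactly
// the same (batch, out_y, out_x) triples as the old nested loops.
int main() {
  const int64_t input_batches = 2, output_height = 5, output_width = 7;
  int64_t patch_index = 0;
  for (int64_t batch = 0; batch < input_batches; ++batch) {
    for (int64_t out_y = 0; out_y < output_height; ++out_y) {
      for (int64_t out_x = 0; out_x < output_width; ++out_x) {
        // Same decomposition as in the patch.
        assert(patch_index / (output_height * output_width) == batch);
        assert((patch_index / output_width) % output_height == out_y);
        assert(patch_index % output_width == out_x);
        ++patch_index;
      }
    }
  }
  return 0;
}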
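Within each row of a patch, the left-zero/center-copy/right-zero split handles filters that overhang the horizontal edges, and padded positions are written as input_offset rather than literal zero: the GEMM later subtracts input_offset from every left-hand value, so input_offset is what becomes zero after that bias is removed. A minimal sketch of one row in isolation (made-up sizes, outside TensorFlow):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Minimal sketch of the per-row fill/copy logic: a width-3, depth-1 filter
// hanging one pixel off the left edge of a width-4 input row.
int main() {
  const int input_width = 4, input_depth = 1, filter_width = 3;
  const uint8_t input_offset = 128;  // Quantized zero point.
  const std::vector<uint8_t> input_row_data = {10, 20, 30, 40};
  std::vector<uint8_t> im2col_row(filter_width * input_depth);

  const int in_x_origin = -1;  // Filter overhangs the left edge by one.
  const int in_x_end = in_x_origin + filter_width;
  const int left_zero_count = std::max(0, 0 - in_x_origin);
  const int right_zero_count = std::max(0, in_x_end - input_width);
  const int center_copy_count =
      filter_width - (left_zero_count + right_zero_count);

  uint8_t* row = im2col_row.data();
  if (left_zero_count > 0) {
    std::fill(row, row + (left_zero_count * input_depth), input_offset);
  }
  if (center_copy_count > 0) {
    const uint8_t* src =
        input_row_data.data() + (std::max(0, in_x_origin) * input_depth);
    std::copy(src, src + (center_copy_count * input_depth),
              row + (left_zero_count * input_depth));
  }
  if (right_zero_count > 0) {
    std::fill(row + ((left_zero_count + center_copy_count) * input_depth),
              row + (filter_width * input_depth), input_offset);
  }
  for (uint8_t v : im2col_row) printf("%d ", v);  // Prints: 128 10 20
  printf("\n");
  return 0;
}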
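Finally, each chunk ends with a GEMM in which m = how_many_patches, k = filter_value_count, and n = filter_count, with the quantization offsets entering as per-value subtractions on the two inputs. For reference, a deliberately naive version of that computation (an illustration in the spirit of, but not identical to, the file's ReferenceGemm fallback):

#include <cstdint>

// Naive row-major quantized GEMM: result[m][n] accumulates
// (lhs - lhs_offset) * (rhs - rhs_offset) in 32-bit integers. This is what
// gemmlowp / meta::QuantizedGemm compute for a chunk, where the im2col rows
// are the LHS and the filter weights are the RHS; leading dimensions match
// the lda == k, ldb == n, ldc == n layout used in the patch.
void QuantizedGemmReference(const uint8_t* lhs, const uint8_t* rhs,
                            int32_t* result, int m, int n, int k,
                            int32_t lhs_offset, int32_t rhs_offset) {
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      int32_t total = 0;
      for (int i = 0; i < k; ++i) {
        const int32_t lhs_value = lhs[(row * k) + i] - lhs_offset;
        const int32_t rhs_value = rhs[(i * n) + col] - rhs_offset;
        total += lhs_value * rhs_value;
      }
      result[(row * n) + col] = total;
    }
  }
}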