From 9f2fa2ec4a68bb9e88ee20146927f84e4f9fe199 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Sat, 16 Jul 2016 06:51:50 -0800
Subject: In contrib/quantization, use eigen threadpool when calling gemmlowp
 to avoid creating a new one each time.

Change: 127624630
---
 .../quantization/kernels/quantization_utils.h   | 56 ++++++++++++++++++++++
 .../quantization/kernels/quantized_conv_ops.cc  | 54 +++++++++++----------
 .../quantization/kernels/quantized_matmul_op.cc | 35 +++++++-------
 3 files changed, 104 insertions(+), 41 deletions(-)

diff --git a/tensorflow/contrib/quantization/kernels/quantization_utils.h b/tensorflow/contrib/quantization/kernels/quantization_utils.h
index 5c3716d24e..c9a3c77797 100644
--- a/tensorflow/contrib/quantization/kernels/quantization_utils.h
+++ b/tensorflow/contrib/quantization/kernels/quantization_utils.h
@@ -25,7 +25,9 @@ limitations under the License.
 // to avoid a dependency on floating-point hardware.
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
+#include "external/gemmlowp/public/gemmlowp.h"
 #include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/threadpool.h"
 
 namespace tensorflow {
@@ -487,6 +489,60 @@ void QuantizedAdd(const Eigen::ThreadPoolDevice& device, const Tensor& input,
   }
 }
 
+// See gemmlowp/internal/multi_thread_gemm.h for definitions of
+// Prepare, Wait, StartWorker, and CreateWorkers.
+class TensorflowGemmlowpWorkersPool {
+ public:
+  TensorflowGemmlowpWorkersPool(thread::ThreadPool* workers)
+      : workers_(workers) {}
+
+  void Prepare(int workers_count) {
+    counter_to_decrement_when_ready_.Reset(workers_count);
+  }
+
+  void Wait() { counter_to_decrement_when_ready_.Wait(); }
+
+  void StartWorker(int index, gemmlowp::Task* task) {
+    CHECK(workers_ != nullptr);
+    // <index> is ignored - the tensorflow threadpool does not support assigning
+    // to a specific thread.
+    workers_->Schedule([this, task]() {
+      // TODO(cwhipkey): get a local_allocator from a thread local.
+      gemmlowp::Allocator local_allocator;
+      CHECK(task != nullptr);
+      task->local_allocator = &local_allocator;
+      task->Run();
+      delete task;
+      counter_to_decrement_when_ready_.DecrementCount();
+    });
+  }
+
+  void CreateWorkers(std::size_t workers_count) {}
+
+ private:
+  thread::ThreadPool* const workers_;
+
+  // The BlockingCounter used to wait for the workers.
+  gemmlowp::BlockingCounter counter_to_decrement_when_ready_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmlowpWorkersPool);
+};
+
+class TensorflowGemmContext : public gemmlowp::MultiThreadGemmContextBase {
+ public:
+  TensorflowGemmContext(int num_threads, thread::ThreadPool* workers)
+      : workers_pool_(workers) {
+    set_max_num_threads(num_threads);
+  }
+
+  TensorflowGemmlowpWorkersPool* workers_pool() { return &workers_pool_; }
+
+ private:
+  TensorflowGemmlowpWorkersPool workers_pool_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TensorflowGemmContext);
+};
+
 }  // namespace tensorflow
 
 #endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_QUANTIZATION_KERNELS_QUANTIZATION_UTILS_H_
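Note on intended usage: the two adapters above are constructed per GEMM call from the kernel's OpKernelContext, so gemmlowp schedules its work on the threads the TensorFlow device already owns instead of spinning up a fresh pool each time. The sketch below mirrors the call sites changed later in this patch; the helper name RunQuantizedGemm, the row-major layouts, and the leading dimensions are illustrative assumptions, not part of the change.

// Sketch only (not part of this change): wiring the TensorFlow-owned thread
// pool into gemmlowp through the adapters above. RunQuantizedGemm is a
// hypothetical helper.
#include "external/gemmlowp/public/gemmlowp.h"
#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {

void RunQuantizedGemm(OpKernelContext* op_context, const uint8* lhs_data,
                      const uint8* rhs_data, int32* out_data, int m, int n,
                      int k, int lhs_offset, int rhs_offset) {
  // Reuse the device's existing worker threads instead of letting gemmlowp
  // create a new pool on every call.
  auto& worker_threads =
      *(op_context->device()->tensorflow_cpu_worker_threads());
  TensorflowGemmContext context(worker_threads.num_threads,
                                worker_threads.workers);

  // Map the raw quantized buffers; row-major with tight leading dimensions is
  // assumed here purely for illustration.
  gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::RowMajor> lhs(
      lhs_data, m, k, k);
  gemmlowp::MatrixMap<const std::uint8_t, gemmlowp::MapOrder::RowMajor> rhs(
      rhs_data, k, n, n);
  gemmlowp::MatrixMap<std::int32_t, gemmlowp::MapOrder::RowMajor> result(
      out_data, m, n, n);
  const std::tuple<> empty_pipeline = {};
  gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                   gemmlowp::DefaultL8R8BitDepthParams>(
      &context, lhs, rhs, &result, -lhs_offset, -rhs_offset, empty_pipeline);
}

}  // namespace tensorflow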
diff --git a/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc b/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
index 46ec0a337a..647e68ea12 100644
--- a/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
+++ b/tensorflow/contrib/quantization/kernels/quantized_conv_ops.cc
@@ -46,13 +46,13 @@ namespace tensorflow {
 template <class T1, class T2, class T3>
 class ReferenceConvFunctor {
  public:
-  void operator()(const T1* input_data, int input_batches, int input_height,
-                  int input_width, int input_depth, int input_offset,
-                  const T2* filter_data, int filter_height, int filter_width,
-                  int filter_count, int filter_offset, int stride,
-                  Padding padding, T3* output_data, int output_height,
-                  int output_width, int output_shift, int output_offset,
-                  int output_mult) {
+  void operator()(OpKernelContext* op_context, const T1* input_data,
+                  int input_batches, int input_height, int input_width,
+                  int input_depth, int input_offset, const T2* filter_data,
+                  int filter_height, int filter_width, int filter_count,
+                  int filter_offset, int stride, Padding padding,
+                  T3* output_data, int output_height, int output_width,
+                  int output_shift, int output_offset, int output_mult) {
     // Set up some constants we need for the output down-shifting and
     // saturation.
     const int32 highest = static_cast<int32>(Eigen::NumTraits<T3>::highest());
@@ -186,13 +186,13 @@ class ReferenceConvFunctor {
 template <class T1, class T2, class T3>
 class Im2ColConvFunctor {
  public:
-  void operator()(const T1* input_data, int input_batches, int input_height,
-                  int input_width, int input_depth, int input_offset,
-                  const T2* filter_data, int filter_height, int filter_width,
-                  int filter_count, int filter_offset, int stride,
-                  Padding padding, T3* output_data, int output_height,
-                  int output_width, int output_shift, int output_offset,
-                  int output_mult) {
+  void operator()(OpKernelContext* op_context, const T1* input_data,
+                  int input_batches, int input_height, int input_width,
+                  int input_depth, int input_offset, const T2* filter_data,
+                  int filter_height, int filter_width, int filter_count,
+                  int filter_offset, int stride, Padding padding,
+                  T3* output_data, int output_height, int output_width,
+                  int output_shift, int output_offset, int output_mult) {
     if (input_offset < 0) {
       // Only log the first few occurrences of this warning.
       static int warning_count = 0;
@@ -206,11 +206,11 @@ class Im2ColConvFunctor {
                    << " avoid this situation.";
       }
       ReferenceConvFunctor<T1, T2, T3> conv_functor;
-      conv_functor(input_data, input_batches, input_height, input_width,
-                   input_depth, input_offset, filter_data, filter_height,
-                   filter_width, filter_count, filter_offset, stride, padding,
-                   output_data, output_height, output_width, output_shift,
-                   output_offset, output_mult);
+      conv_functor(op_context, input_data, input_batches, input_height,
+                   input_width, input_depth, input_offset, filter_data,
+                   filter_height, filter_width, filter_count, filter_offset,
+                   stride, padding, output_data, output_height, output_width,
+                   output_shift, output_offset, output_mult);
       return;
     }
 
@@ -369,7 +369,11 @@ class Im2ColConvFunctor {
     gemmlowp::MatrixMap<std::int32_t, gemmlowp::MapOrder::RowMajor> result(
         output_data_as_int32, m, n, ldc);
     const std::tuple<> empty_pipeline = {};
-    gemmlowp::GemmContext context;
+
+    auto& worker_threads =
+        *(op_context->device()->tensorflow_cpu_worker_threads());
+    TensorflowGemmContext context(worker_threads.num_threads,
+                                  worker_threads.workers);
     gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                      gemmlowp::DefaultL8R8BitDepthParams>(
         &context, lhs, rhs, &result, -input_offset, -filter_offset,
@@ -483,11 +487,11 @@ class QuantizedConv2DOp : public OpKernel {
     // This will call different implementations (e.g. reference or optimized)
     // depending on the template parameter.
    ConvFunctor conv_functor;
-    conv_functor(input.flat<T1>().data(), batch, input_rows, input_cols,
-                 in_depth, offset_input, filter.flat<T2>().data(), filter_rows,
-                 filter_cols, out_depth, offset_filter, stride, padding_,
-                 output->flat<T3>().data(), out_rows, out_cols, shift_output,
-                 offset_output, mult_output);
+    conv_functor(context, input.flat<T1>().data(), batch, input_rows,
+                 input_cols, in_depth, offset_input, filter.flat<T2>().data(),
+                 filter_rows, filter_cols, out_depth, offset_filter, stride,
+                 padding_, output->flat<T3>().data(), out_rows, out_cols,
+                 shift_output, offset_output, mult_output);
 
     float min_output_value;
     float max_output_value;
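Before the matching matmul change below, a note on the mechanism: the conv kernel above (and the matmul kernel that follows) only hand gemmlowp a TensorflowGemmContext; what makes the shared pool work is the Prepare/StartWorker/Wait contract that TensorflowGemmlowpWorkersPool implements on top of the TensorFlow thread pool. The loop below is a simplified, illustrative sketch of that contract, based on the interface named in the header comment, not gemmlowp's actual scheduling code.

// Simplified sketch of the caller-side contract satisfied by
// TensorflowGemmlowpWorkersPool; illustrative pseudo-usage only.
#include <vector>

#include "tensorflow/contrib/quantization/kernels/quantization_utils.h"

namespace tensorflow {

void RunTasksOnPool(TensorflowGemmlowpWorkersPool* pool,
                    const std::vector<gemmlowp::Task*>& tasks) {
  // Arm the blocking counter with the number of completions to expect.
  pool->Prepare(static_cast<int>(tasks.size()));
  // Each task is scheduled on the TensorFlow thread pool; the worker runs it,
  // deletes it, and decrements the counter.
  for (int i = 0; i < static_cast<int>(tasks.size()); ++i) {
    pool->StartWorker(i, tasks[i]);
  }
  // Block until every scheduled task has finished.
  pool->Wait();
}

}  // namespace tensorflow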
diff --git a/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc b/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
index 856ee8c2f4..21abce932a 100644
--- a/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
+++ b/tensorflow/contrib/quantization/kernels/quantized_matmul_op.cc
@@ -28,9 +28,9 @@ namespace tensorflow {
 // combinations of transpose attributes we need to support, and they have to be
 // compile-time constants to work with the templates used internally.
 template <bool TransposeA, bool TransposeB, bool TransposeC>
-void GemmlowpMultiply(const quint8* a_data, const quint8* b_data,
-                      qint32* c_data, int m, int n, int k, int offset_a,
-                      int offset_b, int lda, int ldb, int ldc) {
+void GemmlowpMultiply(OpKernelContext* op_context, const quint8* a_data,
+                      const quint8* b_data, qint32* c_data, int m, int n, int k,
+                      int offset_a, int offset_b, int lda, int ldb, int ldc) {
   const uint8* a_data_as_uint8 = &(a_data->value);
   const uint8* b_data_as_uint8 = &(b_data->value);
   int32* c_data_as_int32 = &(c_data->value);
@@ -47,7 +47,10 @@ void GemmlowpMultiply(const quint8* a_data, const quint8* b_data,
   gemmlowp::MatrixMap<std::int32_t, ResultOrder> result(c_data_as_int32, m, n,
                                                         ldc);
   const std::tuple<> empty_pipeline = {};
-  gemmlowp::GemmContext context;
+  auto& worker_threads =
+      *(op_context->device()->tensorflow_cpu_worker_threads());
+  TensorflowGemmContext context(worker_threads.num_threads,
+                                worker_threads.workers);
   gemmlowp::GemmWithOutputPipeline<std::uint8_t, std::int32_t,
                                    gemmlowp::DefaultL8R8BitDepthParams>(
       &context, lhs, rhs, &result, -offset_a, -offset_b, empty_pipeline);
@@ -130,23 +133,23 @@ class QuantizedMatMulOp : public OpKernel {
         (shift_c == 0) && (transpose_c == false)) {
       if (transpose_a_) {
         if (transpose_b_) {
-          GemmlowpMultiply<true, true, false>(a_data, b_data, c_data, m, n, k,
-                                              offset_a, offset_b, lda, ldb,
-                                              ldc);
+          GemmlowpMultiply<true, true, false>(context, a_data, b_data, c_data,
+                                              m, n, k, offset_a, offset_b, lda,
+                                              ldb, ldc);
         } else {
-          GemmlowpMultiply<true, false, false>(a_data, b_data, c_data, m, n, k,
-                                               offset_a, offset_b, lda, ldb,
-                                               ldc);
+          GemmlowpMultiply<true, false, false>(context, a_data, b_data, c_data,
+                                               m, n, k, offset_a, offset_b, lda,
+                                               ldb, ldc);
         }
       } else {
         if (transpose_b_) {
-          GemmlowpMultiply<false, true, false>(a_data, b_data, c_data, m, n, k,
-                                               offset_a, offset_b, lda, ldb,
-                                               ldc);
+          GemmlowpMultiply<false, true, false>(context, a_data, b_data, c_data,
+                                               m, n, k, offset_a, offset_b, lda,
+                                               ldb, ldc);
         } else {
-          GemmlowpMultiply<false, false, false>(a_data, b_data, c_data, m, n, k,
-                                                offset_a, offset_b, lda, ldb,
-                                                ldc);
+          GemmlowpMultiply<false, false, false>(context, a_data, b_data, c_data,
+                                                m, n, k, offset_a, offset_b,
+                                                lda, ldb, ldc);
         }
       }
     } else {
--
cgit v1.2.3