diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h')
-rw-r--r-- | tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h | 63 |
1 files changed, 17 insertions, 46 deletions
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index 27d9224512..4a3545d47a 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -35,35 +35,6 @@ limitations under the License. namespace tflite { namespace multithreaded_ops { -class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { - public: - explicit EigenThreadPoolWrapper(Eigen::ThreadPool* pool) : pool_(pool) {} - ~EigenThreadPoolWrapper() override {} - - void Schedule(std::function<void()> fn) override { - pool_->Schedule(std::move(fn)); - } - int NumThreads() const override { return pool_->NumThreads(); } - int CurrentThreadId() const override { return pool_->CurrentThreadId(); } - - private: - Eigen::ThreadPool* pool_ = nullptr; -}; - -// We have a single global threadpool for all convolution operations. This means -// that inferences started from different threads may block each other, but -// since the underlying resource of CPU cores should be consumed by the -// operations anyway, it shouldn't affect overall performance. -const Eigen::ThreadPoolDevice& GetThreadPoolDevice() { - const int thread_count = 4; - static Eigen::ThreadPool* tp = new Eigen::ThreadPool(thread_count); - static EigenThreadPoolWrapper* thread_pool_wrapper = - new EigenThreadPoolWrapper(tp); - static Eigen::ThreadPoolDevice* device = - new Eigen::ThreadPoolDevice(thread_pool_wrapper, thread_count); - return *device; -} - // Shorthands for the types we need when interfacing with the EigenTensor // library. typedef Eigen::TensorMap< @@ -113,14 +84,13 @@ class EigenTensorConvFunctor { } public: - void operator()(const T* input_data, T* im2col_buffer, int input_batches, - int input_height, int input_width, int input_depth, - const T* filter_data, int filter_height, int filter_width, - int filter_count, int stride_rows, int stride_cols, - int pad_width, int pad_height, TfLitePadding padding, - T* output_data, int output_height, int output_width) { - const Eigen::ThreadPoolDevice& device = GetThreadPoolDevice(); - + void operator()(const Eigen::ThreadPoolDevice& device, const T* input_data, + T* im2col_buffer, int input_batches, int input_height, + int input_width, int input_depth, const T* filter_data, + int filter_height, int filter_width, int filter_count, + int stride_rows, int stride_cols, int pad_width, + int pad_height, TfLitePadding padding, T* output_data, + int output_height, int output_width) { const bool is_1x1_kernel = (filter_height == 1 && filter_width == 1 && stride_rows == 1 && stride_cols == 1); if (is_1x1_kernel) { @@ -162,11 +132,11 @@ class EigenTensorConvFunctor { } }; -inline void Conv(const float* input_data, const Dims<4>& input_dims, - const float* filter_data, const Dims<4>& filter_dims, - const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, TfLitePadding padding, +inline void Conv(const Eigen::ThreadPoolDevice& device, const float* input_data, + const Dims<4>& input_dims, const float* filter_data, + const Dims<4>& filter_dims, const float* bias_data, + const Dims<4>& bias_dims, int stride_width, int stride_height, + int pad_width, int pad_height, TfLitePadding padding, float output_activation_min, float output_activation_max, float* output_data, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { @@ -180,10 +150,11 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, const int output_height = ArraySize(output_dims, 2); const int output_width = ArraySize(output_dims, 1); EigenTensorConvFunctor<float> conv_functor; - conv_functor(input_data, im2col_data, batches, input_height, input_width, - input_depth, filter_data, filter_height, filter_width, - output_depth, stride_height, stride_width, pad_height, pad_width, - padding, output_data, output_height, output_width); + conv_functor(device, input_data, im2col_data, batches, input_height, + input_width, input_depth, filter_data, filter_height, + filter_width, output_depth, stride_height, stride_width, + pad_height, pad_width, padding, output_data, output_height, + output_width); optimized_ops::AddBiasAndEvalActivationFunction( bias_data, bias_dims, output_data, output_dims, output_activation_min, |