diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-09 09:39:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-09 09:43:42 -0800 |
commit | 96a7b1443f6b652c04957ac8c53d6597be434697 (patch) | |
tree | 2f65bc9a7602af847ad0aabf37a75f50d36457f1 /tensorflow/contrib/lite/kernels/conv.cc | |
parent | 7fbfa59b1d970eb5e3a27b12ef38315ab556faef (diff) |
Use the multithreaded conv only when threads are available.
PiperOrigin-RevId: 188495357
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 24 |
1 files changed, 21 insertions, 3 deletions
```diff
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index b93a416351..6821a22226 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -43,6 +43,8 @@ namespace conv {
 enum KernelType {
   kReference,
   kGenericOptimized,  // Neon-free
+  // kMultithreadOptimized is a mixture of an Eigen-based kernel when threads
+  // are available and kGenericOptimized when we must use only one thread.
   kMultithreadOptimized,
   // The kernel uses use CBLAS interface for matrix multiplication.
   // It's fast when an optimized CBLAS implementation is available (e.g. Apple
@@ -75,6 +77,8 @@ struct OpData {
   bool need_hwcn_weights;
   bool have_weights_been_transposed;
   bool need_im2col;
+
+  bool run_multithreaded_kernel;
 };
 
 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -83,6 +87,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
   // to carry information from Prepare() to Eval().
   auto* data = new OpData;
   gemm_support::IncrementUsageCounter(context);
+
+  // TODO(ahentz): This is the gemmlowp context, which really only applies to
+  // quantized kernels. However, Interpreter::SetNumThreads() should also be
+  // setting the number of kernel on Eigen, so this works OK as a proxy for
+  // now.
+  int num_threads = gemm_support::GetFromContext(context)->max_num_threads();
+  data->run_multithreaded_kernel = num_threads != 1;
+
   return data;
 }
 
@@ -137,7 +149,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
   // buffer to store the results.
   // This path is only used for float processing, so only create the buffer if
   // we're running with that data type.
-  data->need_hwcn_weights = (input->type == kTfLiteFloat32);
+  data->need_hwcn_weights =
+      (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel);
 
   int temporaries_count = 0;
   if (data->need_im2col) {
@@ -449,8 +462,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   // separate ops to avoid dispatch overhead here.
   switch (input->type) {  // Already know in/outtypes are same.
     case kTfLiteFloat32:
-      EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
-                             im2col, hwcn_weights, output);
+      if (data->run_multithreaded_kernel) {
+        EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+                               im2col, hwcn_weights, output);
+      } else {
+        EvalFloat<kGenericOptimized>(context, node, params, data, input, filter,
+                                     bias, im2col, hwcn_weights, output);
+      }
       break;
     case kTfLiteUInt8:
       EvalQuantized<kernel_type>(context, node, params, data, input, filter,
```