path: root/tensorflow/contrib/lite/kernels/conv.cc
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-03-09 09:39:21 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-03-09 09:43:42 -0800
commit     96a7b1443f6b652c04957ac8c53d6597be434697 (patch)
tree       2f65bc9a7602af847ad0aabf37a75f50d36457f1 /tensorflow/contrib/lite/kernels/conv.cc
parent     7fbfa59b1d970eb5e3a27b12ef38315ab556faef (diff)
Use the multithreaded conv only when threads are available.
PiperOrigin-RevId: 188495357
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/conv.cc | 24 +++++++++++++++++++++---
1 file changed, 21 insertions(+), 3 deletions(-)
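In short, the change records at Init() time whether more than one thread is available (using the gemmlowp context's max_num_threads() as a proxy) and, at Eval() time, falls back to the single-threaded kGenericOptimized float kernel when only one thread is configured. Below is a minimal standalone sketch of that dispatch pattern; OpData, MaxNumThreads(), and the two Eval* helpers are illustrative stand-ins, not the actual TF Lite symbols.

// Standalone sketch (not the actual TF Lite code): a per-op flag set at
// Init() picks between a multithreaded and a single-threaded kernel at Eval().
// MaxNumThreads() stands in for
// gemm_support::GetFromContext(context)->max_num_threads().
#include <cstdio>

struct OpData {
  bool run_multithreaded_kernel = false;
};

int MaxNumThreads() { return 4; }  // Stand-in for the gemmlowp thread count.

void Init(OpData* data) {
  // Use the multithreaded kernel only when more than one thread is available.
  data->run_multithreaded_kernel = (MaxNumThreads() != 1);
}

void EvalFloatMultithreaded() { std::printf("Eigen-based multithreaded conv\n"); }
void EvalFloatGenericOptimized() { std::printf("single-threaded generic conv\n"); }

void Eval(const OpData& data) {
  // Mirrors the kMultithreadOptimized vs. kGenericOptimized split in the diff.
  if (data.run_multithreaded_kernel) {
    EvalFloatMultithreaded();
  } else {
    EvalFloatGenericOptimized();
  }
}

int main() {
  OpData data;
  Init(&data);
  Eval(data);
  return 0;
}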
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index b93a416351..6821a22226 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -43,6 +43,8 @@ namespace conv {
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
+  // kMultithreadOptimized is a mixture of an Eigen-based kernel when threads
+  // are available and kGenericOptimized when we must use only one thread.
  kMultithreadOptimized,
  // The kernel uses the CBLAS interface for matrix multiplication.
  // It's fast when an optimized CBLAS implementation is available (e.g. Apple
@@ -75,6 +77,8 @@ struct OpData {
  bool need_hwcn_weights;
  bool have_weights_been_transposed;
  bool need_im2col;
+
+  bool run_multithreaded_kernel;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -83,6 +87,14 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // to carry information from Prepare() to Eval().
  auto* data = new OpData;
  gemm_support::IncrementUsageCounter(context);
+
+  // TODO(ahentz): This is the gemmlowp context, which really only applies to
+  // quantized kernels. However, Interpreter::SetNumThreads() should also be
+  // setting the number of threads on Eigen, so this works OK as a proxy for
+  // now.
+  int num_threads = gemm_support::GetFromContext(context)->max_num_threads();
+  data->run_multithreaded_kernel = num_threads != 1;
+
  return data;
}
@@ -137,7 +149,8 @@ static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context,
  // buffer to store the results.
  // This path is only used for float processing, so only create the buffer if
  // we're running with that data type.
-  data->need_hwcn_weights = (input->type == kTfLiteFloat32);
+  data->need_hwcn_weights =
+      (input->type == kTfLiteFloat32 && data->run_multithreaded_kernel);

  int temporaries_count = 0;
  if (data->need_im2col) {
@@ -449,8 +462,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  // separate ops to avoid dispatch overhead here.
  switch (input->type) {  // Already know in/out types are same.
    case kTfLiteFloat32:
-      EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
-                             im2col, hwcn_weights, output);
+      if (data->run_multithreaded_kernel) {
+        EvalFloat<kernel_type>(context, node, params, data, input, filter, bias,
+                               im2col, hwcn_weights, output);
+      } else {
+        EvalFloat<kGenericOptimized>(context, node, params, data, input, filter,
+                                     bias, im2col, hwcn_weights, output);
+      }
      break;
    case kTfLiteUInt8:
      EvalQuantized<kernel_type>(context, node, params, data, input, filter,
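For context on the TODO(ahentz) comment above: the thread count that Init() reads through the gemmlowp context is the one configured on the interpreter. A hedged usage sketch follows, assuming the contrib-era TF Lite C++ API; "model.tflite" is a placeholder model path.

// Hedged usage sketch (assumed contrib-era API, placeholder model path):
// setting the interpreter to one thread makes conv.cc's Init() see
// max_num_threads() == 1 and fall back to the single-threaded float kernel.
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);

  interpreter->SetNumThreads(1);  // Forces the kGenericOptimized float path.
  interpreter->AllocateTensors();
  interpreter->Invoke();
  return 0;
}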