diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-03-09 10:39:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-09 10:44:15 -0800 |
commit | 0ebfee36ed65f3540c216f10b8ec326b7b52db3a (patch) | |
tree | 097a26bbeec3370fefc6f0964f1db98f6d92a366 /tensorflow/contrib/lite/kernels/conv.cc | |
parent | 87dab2d8289750c9d34f26d7d5fb18475dff985b (diff) |
Make SetNumThreads apply to the eigen threads. (This creates a dependency on eigen!)
PiperOrigin-RevId: 188504172
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 10 |
1 files changed, 4 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 6821a22226..b91ba1a03d 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" @@ -87,18 +88,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // to carry information from Prepare() to Eval(). auto* data = new OpData; gemm_support::IncrementUsageCounter(context); + eigen_support::IncrementUsageCounter(context); - // TODO(ahentz): This is the gemmlowp context, which really only applies to - // quantized kernels. However, Interpreter::SetNumThreads() should also be - // setting the number of kernel on Eigen, so this works OK as a proxy for - // now. - int num_threads = gemm_support::GetFromContext(context)->max_num_threads(); - data->run_multithreaded_kernel = num_threads != 1; + data->run_multithreaded_kernel = context->recommended_num_threads != 1; return data; } void Free(TfLiteContext* context, void* buffer) { + eigen_support::DecrementUsageCounter(context); gemm_support::DecrementUsageCounter(context); delete reinterpret_cast<OpData*>(buffer); } |