aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/conv.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-09 10:39:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-09 10:44:15 -0800
commit0ebfee36ed65f3540c216f10b8ec326b7b52db3a (patch)
tree097a26bbeec3370fefc6f0964f1db98f6d92a366 /tensorflow/contrib/lite/kernels/conv.cc
parent87dab2d8289750c9d34f26d7d5fb18475dff985b (diff)
Make SetNumThreads apply to the eigen threads. (This creates a dependency on eigen!)
PiperOrigin-RevId: 188504172
Diffstat (limited to 'tensorflow/contrib/lite/kernels/conv.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc10
1 files changed, 4 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index 6821a22226..b91ba1a03d 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/builtin_op_data.h"
#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h"
#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h"
@@ -87,18 +88,15 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
// to carry information from Prepare() to Eval().
auto* data = new OpData;
gemm_support::IncrementUsageCounter(context);
+ eigen_support::IncrementUsageCounter(context);
- // TODO(ahentz): This is the gemmlowp context, which really only applies to
- // quantized kernels. However, Interpreter::SetNumThreads() should also be
- // setting the number of kernel on Eigen, so this works OK as a proxy for
- // now.
- int num_threads = gemm_support::GetFromContext(context)->max_num_threads();
- data->run_multithreaded_kernel = num_threads != 1;
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
return data;
}
void Free(TfLiteContext* context, void* buffer) {
+ eigen_support::DecrementUsageCounter(context);
gemm_support::DecrementUsageCounter(context);
delete reinterpret_cast<OpData*>(buffer);
}