diff options
author | 2018-03-26 11:47:50 -0700 | |
---|---|---|
committer | 2018-03-26 11:50:29 -0700 | |
commit | d2604f8dcb8a63ca063f712c24ce5aa63403b0aa (patch) | |
tree | 6109456e2238ab20f647639aa8f05c0aba5b128d /tensorflow/contrib | |
parent | 6d46c21e9f300d07e30a2185671f07d34fac3999 (diff) |
Revert to initializing number of threads when SetNumThreads is called. Requiring it
to happen before OpInit() is way too confusing for users.
PiperOrigin-RevId: 190499644
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/interpreter.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/conv.cc | 5 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.cc | 7 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/eigen_support.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/gemm_support.cc | 6 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/gemm_support.h | 3 |
7 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index dafe6f136e..18efa64507 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -133,10 +133,10 @@ cc_library( ":schema_fbs_version", ":simple_memory_arena", ":util", + "//tensorflow/contrib/lite/kernels:eigen_support", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/schema:schema_fbs", - "//tensorflow/core:lib_platform", ], ) diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 937c185b0a..4575fe884d 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/kernels/eigen_support.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" @@ -762,6 +763,11 @@ void Interpreter::UseNNAPI(bool enable) { void Interpreter::SetNumThreads(int num_threads) { context_.recommended_num_threads = num_threads; + + // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to + // be required in order to compile the framework. + gemm_support::SetNumThreads(&context_, num_threads); + eigen_support::SetNumThreads(&context_, num_threads); } TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate, diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index e0cd12f1b4..18ff33bf9f 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -89,9 +89,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { auto* data = new OpData; gemm_support::IncrementUsageCounter(context); eigen_support::IncrementUsageCounter(context); - - data->run_multithreaded_kernel = context->recommended_num_threads != 1; - return data; } @@ -176,6 +173,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data); OpData* data = reinterpret_cast<OpData*>(node->user_data); + data->run_multithreaded_kernel = context->recommended_num_threads != 1; + TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); bool hasBias = node->inputs->size == 3; diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc index 213e465552..f1fdb42624 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.cc +++ b/tensorflow/contrib/lite/kernels/eigen_support.cc @@ -46,8 +46,15 @@ void DecrementUsageCounter(TfLiteContext* context) { } if (--ptr->num_references == 0) { delete ptr; + context->eigen_context = nullptr; } } +void SetNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + Eigen::setNbThreads(num_threads); + DecrementUsageCounter(context); +} + } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h index d47e691123..aa8c351fd8 100644 --- a/tensorflow/contrib/lite/kernels/eigen_support.h +++ b/tensorflow/contrib/lite/kernels/eigen_support.h @@ -28,6 +28,9 @@ void IncrementUsageCounter(TfLiteContext* context); // usages all temporary Eigen objects will be deleted. void DecrementUsageCounter(TfLiteContext* context); +// Set the number of threads that can be used by Eigen. +void SetNumThreads(TfLiteContext* context, int num_threads); + } // namespace eigen_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc index 76a5165d14..95f45ea768 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.cc +++ b/tensorflow/contrib/lite/kernels/gemm_support.cc @@ -61,5 +61,11 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) { return ptr->gemm_context_; } +void SetNumThreads(TfLiteContext* context, int num_threads) { + IncrementUsageCounter(context); + GetFromContext(context)->set_max_num_threads(num_threads); + DecrementUsageCounter(context); +} + } // namespace gemm_support } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index 37af772c68..f033501cb6 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -45,6 +45,9 @@ void IncrementUsageCounter(TfLiteContext* context); // 'context'. If there are no more usages the GemmContext will be deleted. void DecrementUsageCounter(TfLiteContext* context); +// Set the number of threads that can be used by gemmlowp. +void SetNumThreads(TfLiteContext* context, int num_threads); + } // namespace gemm_support } // namespace tflite |