aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-03-26 11:47:50 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-26 11:50:29 -0700
commitd2604f8dcb8a63ca063f712c24ce5aa63403b0aa (patch)
tree6109456e2238ab20f647639aa8f05c0aba5b128d /tensorflow/contrib
parent6d46c21e9f300d07e30a2185671f07d34fac3999 (diff)
Revert to initializing number of threads when SetNumThreads is called. Requiring it
to happen before OpInit() is way too confusing for users. PiperOrigin-RevId: 190499644
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/BUILD2
-rw-r--r--tensorflow/contrib/lite/interpreter.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/conv.cc5
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.cc7
-rw-r--r--tensorflow/contrib/lite/kernels/eigen_support.h3
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.cc6
-rw-r--r--tensorflow/contrib/lite/kernels/gemm_support.h3
7 files changed, 28 insertions, 4 deletions
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index dafe6f136e..18efa64507 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -133,10 +133,10 @@ cc_library(
":schema_fbs_version",
":simple_memory_arena",
":util",
+ "//tensorflow/contrib/lite/kernels:eigen_support",
"//tensorflow/contrib/lite/kernels:gemm_support",
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/schema:schema_fbs",
- "//tensorflow/core:lib_platform",
],
)
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 937c185b0a..4575fe884d 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/error_reporter.h"
#include "tensorflow/contrib/lite/graph_info.h"
+#include "tensorflow/contrib/lite/kernels/eigen_support.h"
#include "tensorflow/contrib/lite/kernels/gemm_support.h"
#include "tensorflow/contrib/lite/memory_planner.h"
#include "tensorflow/contrib/lite/nnapi_delegate.h"
@@ -762,6 +763,11 @@ void Interpreter::UseNNAPI(bool enable) {
void Interpreter::SetNumThreads(int num_threads) {
context_.recommended_num_threads = num_threads;
+
+ // TODO(ahentz): find a way to avoid this. It causes gemmlowp and eigen to
+ // be required in order to compile the framework.
+ gemm_support::SetNumThreads(&context_, num_threads);
+ eigen_support::SetNumThreads(&context_, num_threads);
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate,
diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc
index e0cd12f1b4..18ff33bf9f 100644
--- a/tensorflow/contrib/lite/kernels/conv.cc
+++ b/tensorflow/contrib/lite/kernels/conv.cc
@@ -89,9 +89,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) {
auto* data = new OpData;
gemm_support::IncrementUsageCounter(context);
eigen_support::IncrementUsageCounter(context);
-
- data->run_multithreaded_kernel = context->recommended_num_threads != 1;
-
return data;
}
@@ -176,6 +173,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
auto* params = reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
OpData* data = reinterpret_cast<OpData*>(node->user_data);
+ data->run_multithreaded_kernel = context->recommended_num_threads != 1;
+
TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node));
bool hasBias = node->inputs->size == 3;
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.cc b/tensorflow/contrib/lite/kernels/eigen_support.cc
index 213e465552..f1fdb42624 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.cc
+++ b/tensorflow/contrib/lite/kernels/eigen_support.cc
@@ -46,8 +46,15 @@ void DecrementUsageCounter(TfLiteContext* context) {
}
if (--ptr->num_references == 0) {
delete ptr;
+ context->eigen_context = nullptr;
}
}
+void SetNumThreads(TfLiteContext* context, int num_threads) {
+ IncrementUsageCounter(context);
+ Eigen::setNbThreads(num_threads);
+ DecrementUsageCounter(context);
+}
+
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/eigen_support.h b/tensorflow/contrib/lite/kernels/eigen_support.h
index d47e691123..aa8c351fd8 100644
--- a/tensorflow/contrib/lite/kernels/eigen_support.h
+++ b/tensorflow/contrib/lite/kernels/eigen_support.h
@@ -28,6 +28,9 @@ void IncrementUsageCounter(TfLiteContext* context);
// usages all temporary Eigen objects will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
+// Set the number of threads that can be used by Eigen.
+void SetNumThreads(TfLiteContext* context, int num_threads);
+
} // namespace eigen_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.cc b/tensorflow/contrib/lite/kernels/gemm_support.cc
index 76a5165d14..95f45ea768 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.cc
+++ b/tensorflow/contrib/lite/kernels/gemm_support.cc
@@ -61,5 +61,11 @@ gemmlowp::GemmContext* GetFromContext(TfLiteContext* context) {
return ptr->gemm_context_;
}
+void SetNumThreads(TfLiteContext* context, int num_threads) {
+ IncrementUsageCounter(context);
+ GetFromContext(context)->set_max_num_threads(num_threads);
+ DecrementUsageCounter(context);
+}
+
} // namespace gemm_support
} // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h
index 37af772c68..f033501cb6 100644
--- a/tensorflow/contrib/lite/kernels/gemm_support.h
+++ b/tensorflow/contrib/lite/kernels/gemm_support.h
@@ -45,6 +45,9 @@ void IncrementUsageCounter(TfLiteContext* context);
// 'context'. If there are no more usages the GemmContext will be deleted.
void DecrementUsageCounter(TfLiteContext* context);
+// Set the number of threads that can be used by gemmlowp.
+void SetNumThreads(TfLiteContext* context, int num_threads);
+
} // namespace gemm_support
} // namespace tflite