aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/java/src
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/java/src')
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java7
-rw-r--r--tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java6
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc10
-rw-r--r--tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h12
4 files changed, 34 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
index e915e65aa1..e84ee71129 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java
@@ -215,6 +215,13 @@ public final class Interpreter implements AutoCloseable {
}
}
+ public void setNumThreads(int num_threads) {
+ if (wrapper == null) {
+ throw new IllegalStateException("The interpreter has already been closed.");
+ }
+ wrapper.setNumThreads(num_threads);
+ }
+
/** Release resources associated with the {@code Interpreter}. */
@Override
public void close() {
diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
index dfc8ac111a..2fc803715b 100644
--- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
+++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java
@@ -153,6 +153,10 @@ final class NativeInterpreterWrapper implements AutoCloseable {
useNNAPI(interpreterHandle, useNNAPI);
}
+ void setNumThreads(int num_threads) {
+ numThreads(interpreterHandle, num_threads);
+ }
+
/** Gets index of an input given its name. */
int getInputIndex(String name) {
if (inputsIndexes == null) {
@@ -324,6 +328,8 @@ final class NativeInterpreterWrapper implements AutoCloseable {
private static native void useNNAPI(long interpreterHandle, boolean state);
+ private static native void numThreads(long interpreterHandle, int num_threads);
+
private static native long createErrorReporter(int size);
private static native long createModel(String modelPathOrBuffer, long errorHandle);
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
index ccfdfd829b..45f510da1d 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc
@@ -320,6 +320,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
interpreter->UseNNAPI(static_cast<bool>(state));
}
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint num_threads) {
+ tflite::Interpreter* interpreter = convertLongToInterpreter(env, handle);
+ if (interpreter == nullptr) return;
+ interpreter->SetNumThreads(static_cast<int>(num_threads));
+}
+
JNIEXPORT jlong JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_createErrorReporter(
JNIEnv* env, jclass clazz, jint size) {
diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
index 0e28a77fee..eaa765cb34 100644
--- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
+++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h
@@ -61,7 +61,7 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_getOutputNames(JNIEnv* env,
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
- * Signature: (JZ)
+ * Signature: (JZ)V
*/
JNIEXPORT void JNICALL
Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
@@ -72,6 +72,16 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_useNNAPI(JNIEnv* env,
/*
* Class: org_tensorflow_lite_NativeInterpreterWrapper
* Method:
+ * Signature: (JI)V
+ */
+JNIEXPORT void JNICALL
+Java_org_tensorflow_lite_NativeInterpreterWrapper_numThreads(JNIEnv* env,
+ jclass clazz,
+ jlong handle,
+ jint num_threads);
+/*
+ * Class: org_tensorflow_lite_NativeInterpreterWrapper
+ * Method:
* Signature: (I)J
*/
JNIEXPORT jlong JNICALL