Diffstat (limited to 'tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java')
-rw-r--r--  tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java  63
1 file changed, 46 insertions(+), 17 deletions(-)
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
index 587f2941e5..9b7f394258 100644
--- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
+++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java
@@ -55,23 +55,7 @@ public class TensorFlowInferenceInterface {
* @param model The filepath to the GraphDef proto representing the model.
*/
public TensorFlowInferenceInterface(AssetManager assetManager, String model) {
- Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
- try {
- // Hack to see if the native libraries have been loaded.
- new RunStats();
- Log.i(TAG, "TensorFlow native methods already loaded");
- } catch (UnsatisfiedLinkError e1) {
- Log.i(
- TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
- try {
- System.loadLibrary("tensorflow_inference");
- Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
- } catch (UnsatisfiedLinkError e2) {
- throw new RuntimeException(
- "Native TF methods not found; check that the correct native"
- + " libraries are present in the APK.");
- }
- }
+ prepareNativeRuntime();
this.modelName = model;
this.g = new Graph();
@@ -102,6 +86,31 @@ public class TensorFlowInferenceInterface {
throw new RuntimeException("Failed to load model from '" + model + "'", e);
}
}
+
+ /**
+ * Load a TensorFlow model from the provided InputStream.
+ * Note: The InputStream is not closed after the model is loaded; the caller is
+ * responsible for closing it.
+ *
+ * @param is The InputStream to use to load the model.
+ */
+ public TensorFlowInferenceInterface(InputStream is) {
+ prepareNativeRuntime();
+
+ // modelName is not used when loading from an InputStream; it is set here only
+ // because modelName is final and must be initialized in every constructor.
+ this.modelName = "";
+ this.g = new Graph();
+ this.sess = new Session(g);
+ this.runner = sess.runner();
+
+ try {
+ loadGraph(is, g);
+ Log.i(TAG, "Successfully loaded model from the input stream");
+ } catch (IOException e) {
+ throw new RuntimeException("Failed to load model from the input stream", e);
+ }
+ }
/**
* Runs inference between the previously registered input nodes (via feed*) and the requested
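For context, a minimal usage sketch of the new InputStream constructor added above; the asset name and helper class are illustrative assumptions, not part of this change. Since the constructor does not close the stream, the caller closes it, for example with try-with-resources:

import android.content.res.AssetManager;
import java.io.IOException;
import java.io.InputStream;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public final class StreamModelLoader {
  // Hypothetical asset name, for illustration only.
  private static final String MODEL_ASSET = "frozen_graph.pb";

  public static TensorFlowInferenceInterface load(AssetManager assets) throws IOException {
    // The constructor reads the GraphDef but does not close the stream,
    // so try-with-resources closes it once loading finishes.
    try (InputStream is = assets.open(MODEL_ASSET)) {
      return new TensorFlowInferenceInterface(is);
    }
  }
}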
@@ -408,6 +417,26 @@ public class TensorFlowInferenceInterface {
public void fetch(String outputName, ByteBuffer dst) {
getTensor(outputName).writeTo(dst);
}
+
+ private void prepareNativeRuntime() {
+ Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded");
+ try {
+ // Hack to see if the native libraries have been loaded.
+ new RunStats();
+ Log.i(TAG, "TensorFlow native methods already loaded");
+ } catch (UnsatisfiedLinkError e1) {
+ Log.i(
+ TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference");
+ try {
+ System.loadLibrary("tensorflow_inference");
+ Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)");
+ } catch (UnsatisfiedLinkError e2) {
+ throw new RuntimeException(
+ "Native TF methods not found; check that the correct native"
+ + " libraries are present in the APK.");
+ }
+ }
+ }
private void loadGraph(InputStream is, Graph g) throws IOException {
final long startMs = System.currentTimeMillis();
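As a usage note for the ByteBuffer fetch overload shown in the last hunk (unchanged context), a hedged sketch of a feed/run/fetch round trip; the node names, shape, and output size are assumptions for illustration and must match the actual GraphDef:

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.tensorflow.contrib.android.TensorFlowInferenceInterface;

public final class FetchSketch {
  public static ByteBuffer runOnce(TensorFlowInferenceInterface inference, float[] input) {
    // "input" and "output" are assumed node names; shape [1, input.length] is illustrative.
    inference.feed("input", input, 1, input.length);
    inference.run(new String[] {"output"});

    // Assume the output tensor holds 10 floats; fetch copies the raw tensor
    // bytes into the buffer, so size it accordingly and use native byte order.
    ByteBuffer dst = ByteBuffer.allocateDirect(10 * 4).order(ByteOrder.nativeOrder());
    inference.fetch("output", dst);
    dst.flip();  // prepare the buffer for reading
    return dst;
  }
}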