diff options
Diffstat (limited to 'tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java')
-rw-r--r-- | tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java | 63 |
1 files changed, 46 insertions, 17 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 587f2941e5..9b7f394258 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -55,23 +55,7 @@ public class TensorFlowInferenceInterface { * @param model The filepath to the GraphDef proto representing the model. */ public TensorFlowInferenceInterface(AssetManager assetManager, String model) { - Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); - try { - // Hack to see if the native libraries have been loaded. - new RunStats(); - Log.i(TAG, "TensorFlow native methods already loaded"); - } catch (UnsatisfiedLinkError e1) { - Log.i( - TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference"); - try { - System.loadLibrary("tensorflow_inference"); - Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"); - } catch (UnsatisfiedLinkError e2) { - throw new RuntimeException( - "Native TF methods not found; check that the correct native" - + " libraries are present in the APK."); - } - } + prepareNativeRuntime(); this.modelName = model; this.g = new Graph(); @@ -102,6 +86,31 @@ public class TensorFlowInferenceInterface { throw new RuntimeException("Failed to load model from '" + model + "'", e); } } + + /* + * Load a TensorFlow model from provided InputStream. + * Note: The InputStream will not be closed after loading model, users need to + * close it themselves. + * + * @param is The InputStream to use to load the model. + */ + public TensorFlowInferenceInterface(InputStream is) { + prepareNativeRuntime(); + + // modelName is redundant for model loading from input stream, here is for + // avoiding error in initialization as modelName is marked final. + this.modelName = ""; + this.g = new Graph(); + this.sess = new Session(g); + this.runner = sess.runner(); + + try { + loadGraph(is, g); + Log.i(TAG, "Successfully loaded model from the input stream"); + } catch (IOException e) { + throw new RuntimeException("Failed to load model from the input stream", e); + } + } /** * Runs inference between the previously registered input nodes (via feed*) and the requested @@ -408,6 +417,26 @@ public class TensorFlowInferenceInterface { public void fetch(String outputName, ByteBuffer dst) { getTensor(outputName).writeTo(dst); } + + private void prepareNativeRuntime() { + Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); + try { + // Hack to see if the native libraries have been loaded. + new RunStats(); + Log.i(TAG, "TensorFlow native methods already loaded"); + } catch (UnsatisfiedLinkError e1) { + Log.i( + TAG, "TensorFlow native methods not found, attempting to load via tensorflow_inference"); + try { + System.loadLibrary("tensorflow_inference"); + Log.i(TAG, "Successfully loaded TensorFlow native methods (RunStats error may be ignored)"); + } catch (UnsatisfiedLinkError e2) { + throw new RuntimeException( + "Native TF methods not found; check that the correct native" + + " libraries are present in the APK."); + } + } + } private void loadGraph(InputStream is, Graph g) throws IOException { final long startMs = System.currentTimeMillis(); |