diff options
author | Andrew Harp <andrewharp@google.com> | 2017-03-23 14:20:36 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-03-23 15:39:26 -0700 |
commit | 2de2d52164e2b3d0ddb69b7c1037452419de25ce (patch) | |
tree | 01db062c17fe85bcecbd560faa12ec89b58c204e | |
parent | df70d61ea2a76fc226237ef663e94245fb523376 (diff) |
Android: update TensorFlowInferenceInterface: replace error code returns with exceptions, use longs for dimensions, and simplify method signatures where possible.
Note this is a breaking change and will cause the TF demo in tensorflow/examples/android to not work with older jars/.so binaries built from tensorflow/contrib/android
Change: 151061702
5 files changed, 134 insertions, 197 deletions
diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 713f63b806..1f180429b2 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import org.tensorflow.DataType; import org.tensorflow.Graph; +import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; @@ -45,7 +46,13 @@ public class TensorFlowInferenceInterface { private static final String TAG = "TensorFlowInferenceInterface"; private static final String ASSET_FILE_PREFIX = "file:///android_asset/"; - public TensorFlowInferenceInterface() { + /* + * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. + * + * @param assetManager The AssetManager to use to load the model file. + * @param model The filepath to the GraphDef proto representing the model. + */ + public TensorFlowInferenceInterface(AssetManager assetManager, String model) { Log.i(TAG, "Checking to see if TensorFlow native methods are already loaded"); try { // Hack to see if the native libraries have been loaded. @@ -63,16 +70,12 @@ public class TensorFlowInferenceInterface { + " libraries are present in the APK."); } } - } - /** - * Load a TensorFlow model from the AssetManager or from disk if it is not an asset file. - * - * @param assetManager The AssetManager to use to load the model file. - * @param model The filepath to the GraphDef proto representing the model. - * @return 0 on success. - */ - public int initializeTensorFlow(AssetManager assetManager, String model) { + this.modelName = model; + this.g = new Graph(); + this.sess = new Session(g); + this.runner = sess.runner(); + final boolean hasAssetPrefix = model.startsWith(ASSET_FILE_PREFIX); InputStream is = null; try { @@ -80,37 +83,42 @@ public class TensorFlowInferenceInterface { is = assetManager.open(aname); } catch (IOException e) { if (hasAssetPrefix) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } // Perhaps the model file is not an asset but is on disk. try { is = new FileInputStream(model); } catch (IOException e2) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e2.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } } try { - load(is); + loadGraph(is, g); is.close(); Log.i(TAG, "Successfully loaded model from '" + model + "'"); - return 0; } catch (IOException e) { - Log.e(TAG, "Failed to load model from '" + model + "': " + e.toString()); - return 1; + throw new RuntimeException("Failed to load model from '" + model + "'", e); } } /** - * Runs inference between the previously registered input nodes (via fillNode*) and the requested - * output nodes. Output nodes can then be queried with the readNode* methods. + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. * * @param outputNames A list of output nodes which should be filled by the inference pass. - * @return 0 on success. */ - public int runInference(String[] outputNames) { - // Release any Tensors from the previous runInference calls. + public void run(String[] outputNames) { + run(outputNames, false); + } + + /** + * Runs inference between the previously registered input nodes (via feed*) and the requested + * output nodes. Output nodes can then be queried with the fetch* methods. + * + * @param outputNames A list of output nodes which should be filled by the inference pass. + */ + public void run(String[] outputNames, boolean enableStats) { + // Release any Tensors from the previous run calls. closeFetches(); // Add fetches. @@ -125,6 +133,10 @@ public class TensorFlowInferenceInterface { if (enableStats) { Session.Run r = runner.setOptions(RunStats.runOptions()).runAndFetchMetadata(); fetchTensors = r.outputs; + + if (runStats == null) { + runStats = new RunStats(); + } runStats.add(r.metadata); } else { fetchTensors = runner.run(); @@ -139,16 +151,13 @@ public class TensorFlowInferenceInterface { + "], outputs:[" + TextUtils.join(", ", fetchNames) + "]"); - Log.e(TAG, "Inference exception: " + e.toString()); - return -1; + throw e; } finally { - // Always release the feeds (to save resources) and reset the runner, this runInference is + // Always release the feeds (to save resources) and reset the runner, this run is // over. closeFeeds(); runner = sess.runner(); } - - return 0; } /** Returns a reference to the Graph describing the computation run during inference. */ @@ -156,15 +165,13 @@ public class TensorFlowInferenceInterface { return g; } - /** - * Whether to collect stats during inference. This should only be enabled when needed, as it will - * add overhead. - */ - public void enableStatLogging(boolean enabled) { - enableStats = enabled; - if (enableStats && runStats == null) { - runStats = new RunStats(); + public Operation graphOperation(String operationName) { + final Operation operation = g.operation(operationName); + if (operation == null) { + throw new RuntimeException( + "Node '" + operationName + "' does not exist in model '" + modelName + "'"); } + return operation; } /** Returns the last stat summary string if logging is enabled. */ @@ -185,7 +192,15 @@ public class TensorFlowInferenceInterface { runStats.close(); } runStats = null; - enableStats = false; + } + + @Override + protected void finalize() throws Throwable { + try { + close(); + } finally { + super.finalize(); + } } // Methods for taking a native Tensor and filling it with values from Java arrays. @@ -196,8 +211,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFloat(String inputName, int[] dims, float[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), FloatBuffer.wrap(src))); + public void feed(String inputName, float[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, FloatBuffer.wrap(src))); } /** @@ -206,8 +221,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeInt(String inputName, int[] dims, int[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), IntBuffer.wrap(src))); + public void feed(String inputName, int[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, IntBuffer.wrap(src))); } /** @@ -216,8 +231,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeDouble(String inputName, int[] dims, double[] src) { - addFeed(inputName, Tensor.create(mkDims(dims), DoubleBuffer.wrap(src))); + public void feed(String inputName, double[] src, long... dims) { + addFeed(inputName, Tensor.create(dims, DoubleBuffer.wrap(src))); } /** @@ -226,8 +241,8 @@ public class TensorFlowInferenceInterface { * as many elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeByte(String inputName, int[] dims, byte[] src) { - addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), ByteBuffer.wrap(src))); + public void feed(String inputName, byte[] src, long... dims) { + addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src))); } // Methods for taking a native Tensor and filling it with src from Java native IO buffers. @@ -239,8 +254,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromFloatBuffer(String inputName, IntBuffer dims, FloatBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, FloatBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -250,8 +265,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromIntBuffer(String inputName, IntBuffer dims, IntBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, IntBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -261,8 +276,8 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromDoubleBuffer(String inputName, IntBuffer dims, DoubleBuffer src) { - addFeed(inputName, Tensor.create(mkDims(dims), src)); + public void feed(String inputName, DoubleBuffer src, long... dims) { + addFeed(inputName, Tensor.create(dims, src)); } /** @@ -272,52 +287,44 @@ public class TensorFlowInferenceInterface { * elements as that of the destination Tensor. If {@link src} has more elements than the * destination has capacity, the copy is truncated. */ - public void fillNodeFromByteBuffer(String inputName, IntBuffer dims, ByteBuffer src) { - addFeed(inputName, Tensor.create(DataType.UINT8, mkDims(dims), src)); + public void feed(String inputName, ByteBuffer src, long... dims) { + addFeed(inputName, Tensor.create(DataType.UINT8, dims, src)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeFloat(String outputName, float[] dst) { - return readNodeIntoFloatBuffer(outputName, FloatBuffer.wrap(dst)); + public void fetch(String outputName, float[] dst) { + fetch(outputName, FloatBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeInt(String outputName, int[] dst) { - return readNodeIntoIntBuffer(outputName, IntBuffer.wrap(dst)); + public void fetch(String outputName, int[] dst) { + fetch(outputName, IntBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeDouble(String outputName, double[] dst) { - return readNodeIntoDoubleBuffer(outputName, DoubleBuffer.wrap(dst)); + public void fetch(String outputName, double[] dst) { + fetch(outputName, DoubleBuffer.wrap(dst)); } /** * Read from a Tensor named {@link outputName} and copy the contents into a Java array. {@link * dst} must have length greater than or equal to that of the source Tensor. This operation will * not affect dst's content past the source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeByte(String outputName, byte[] dst) { - return readNodeIntoByteBuffer(outputName, ByteBuffer.wrap(dst)); + public void fetch(String outputName, byte[] dst) { + fetch(outputName, ByteBuffer.wrap(dst)); } /** @@ -325,16 +332,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoFloatBuffer(String outputName, FloatBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, FloatBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -342,16 +342,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoIntBuffer(String outputName, IntBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, IntBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -359,16 +352,9 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoDoubleBuffer(String outputName, DoubleBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, DoubleBuffer dst) { + getTensor(outputName).writeTo(dst); } /** @@ -376,22 +362,12 @@ public class TensorFlowInferenceInterface { * <b>native ordered</b> java.nio buffer {@link dst}. {@link dst} must have capacity greater than * or equal to that of the source Tensor. This operation will not affect dst's content past the * source Tensor's size. - * - * @return 0 on success, -1 on failure. */ - public int readNodeIntoByteBuffer(String outputName, ByteBuffer dst) { - Tensor t = getTensor(outputName); - if (t == null) { - return -1; - } - t.writeTo(dst); - return 0; + public void fetch(String outputName, ByteBuffer dst) { + getTensor(outputName).writeTo(dst); } - private void load(InputStream is) throws IOException { - this.g = new Graph(); - this.sess = new Session(g); - this.runner = sess.runner(); + private void loadGraph(InputStream is, Graph g) throws IOException { final long startMs = System.currentTimeMillis(); Trace.beginSection("initializeTensorFlow"); @@ -425,26 +401,6 @@ public class TensorFlowInferenceInterface { "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); } - // The TensorFlowInferenceInterface API used int[] for dims, but the underlying TensorFlow runtime - // allows for 64-bit dimension sizes, so it needs to be converted to a long[] - private long[] mkDims(int[] dims) { - long[] ret = new long[dims.length]; - for (int i = 0; i < dims.length; ++i) { - ret[i] = (long) dims[i]; - } - return ret; - } - - // Similar to mkDims(int[]), with the shape provided in an IntBuffer. - private long[] mkDims(IntBuffer dims) { - if (dims.hasArray()) { - return mkDims(dims.array()); - } - int[] copy = new int[dims.remaining()]; - dims.duplicate().get(copy); - return mkDims(copy); - } - private void addFeed(String inputName, Tensor t) { // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. TensorId tid = TensorId.parse(inputName); @@ -485,9 +441,10 @@ public class TensorFlowInferenceInterface { if (n.equals(outputName)) { return fetchTensors.get(i); } - i++; + ++i; } - return null; + throw new RuntimeException( + "Node '" + outputName + "' was not provided to run(), so it cannot be read"); } private void closeFeeds() { @@ -507,10 +464,11 @@ public class TensorFlowInferenceInterface { } // State immutable between initializeTensorFlow calls. - private Graph g; - private Session sess; + private final String modelName; + private final Graph g; + private final Session sess; - // State reset on every call to runInference. + // State reset on every call to run. private Session.Runner runner; private List<String> feedNames = new ArrayList<String>(); private List<Tensor> feedTensors = new ArrayList<Tensor>(); @@ -518,6 +476,5 @@ public class TensorFlowInferenceInterface { private List<Tensor> fetchTensors = new ArrayList<Tensor>(); // Mutable state. - private boolean enableStats; private RunStats runStats; } diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java index c1a893e9ee..d0df1a4483 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/StylizeActivity.java @@ -369,8 +369,7 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL borderedText = new BorderedText(textSizePx); borderedText.setTypeface(Typeface.MONOSPACE); - inferenceInterface = new TensorFlowInferenceInterface(); - inferenceInterface.initializeTensorFlow(getAssets(), MODEL_FILE); + inferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODEL_FILE); previewWidth = size.getWidth(); previewHeight = size.getHeight(); @@ -585,12 +584,12 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL } // Copy the input data into TensorFlow. - inferenceInterface.fillNodeFloat( - INPUT_NODE, new int[] {1, bitmap.getWidth(), bitmap.getHeight(), 3}, floatValues); - inferenceInterface.fillNodeFloat(STYLE_NODE, new int[] {NUM_STYLES}, styleVals); + inferenceInterface.feed( + INPUT_NODE, floatValues, 1, bitmap.getWidth(), bitmap.getHeight(), 3); + inferenceInterface.feed(STYLE_NODE, styleVals, NUM_STYLES); - inferenceInterface.runInference(new String[] {OUTPUT_NODE}); - inferenceInterface.readNodeFloat(OUTPUT_NODE, floatValues); + inferenceInterface.run(new String[] {OUTPUT_NODE}, isDebug()); + inferenceInterface.fetch(OUTPUT_NODE, floatValues); for (int i = 0; i < intValues.length; ++i) { intValues[i] = @@ -603,11 +602,6 @@ public class StylizeActivity extends CameraActivity implements OnImageAvailableL bitmap.setPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight()); } - @Override - public void onSetDebug(final boolean debug) { - inferenceInterface.enableStatLogging(debug); - } - private void renderDebug(final Canvas canvas) { // TODO(andrewharp): move result display to its own View instead of using debug overlay. final Bitmap texture = textureCopyBitmap; diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java index 9f80665ee6..f660178ebe 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowImageClassifier.java @@ -56,6 +56,8 @@ public class TensorFlowImageClassifier implements Classifier { private float[] outputs; private String[] outputNames; + private boolean logStats = false; + private TensorFlowInferenceInterface inferenceInterface; private TensorFlowImageClassifier() {} @@ -102,16 +104,10 @@ public class TensorFlowImageClassifier implements Classifier { throw new RuntimeException("Problem reading label file!" , e); } - c.inferenceInterface = new TensorFlowInferenceInterface(); - if (c.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) { - throw new RuntimeException("TF initialization failed"); - } + c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); + // The shape of the output is [N, NUM_CLASSES], where N is the batch size. - final Operation operation = c.inferenceInterface.graph().operation(outputName); - if (operation == null) { - throw new RuntimeException("Node '" + outputName + "' does not exist in model '" - + modelFilename + "'"); - } + final Operation operation = c.inferenceInterface.graphOperation(outputName); final int numClasses = (int) operation.output(0).shape().size(1); Log.i(TAG, "Read " + c.labels.size() + " labels, output layer size is " + numClasses); @@ -149,19 +145,18 @@ public class TensorFlowImageClassifier implements Classifier { Trace.endSection(); // Copy the input data into TensorFlow. - Trace.beginSection("fillNodeFloat"); - inferenceInterface.fillNodeFloat( - inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); + Trace.beginSection("feed"); + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); Trace.endSection(); // Run the inference call. - Trace.beginSection("runInference"); - inferenceInterface.runInference(outputNames); + Trace.beginSection("run"); + inferenceInterface.run(outputNames, logStats); Trace.endSection(); // Copy the output Tensor back into the output array. - Trace.beginSection("readNodeFloat"); - inferenceInterface.readNodeFloat(outputName, outputs); + Trace.beginSection("fetch"); + inferenceInterface.fetch(outputName, outputs); Trace.endSection(); // Find the best classifications. @@ -192,8 +187,8 @@ public class TensorFlowImageClassifier implements Classifier { } @Override - public void enableStatLogging(boolean debug) { - inferenceInterface.enableStatLogging(debug); + public void enableStatLogging(boolean logStats) { + this.logStats = logStats; } @Override diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java index 2ec9476ef9..f3e7114335 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowMultiBoxDetector.java @@ -62,6 +62,8 @@ public class TensorFlowMultiBoxDetector implements Classifier { private String[] outputNames; private int numLocations; + private boolean logStats = false; + private TensorFlowInferenceInterface inferenceInterface; private float[] boxPriors; @@ -89,10 +91,7 @@ public class TensorFlowMultiBoxDetector implements Classifier { final String outputScoresName) { final TensorFlowMultiBoxDetector d = new TensorFlowMultiBoxDetector(); - d.inferenceInterface = new TensorFlowInferenceInterface(); - if (d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename) != 0) { - throw new RuntimeException("TF initialization failed"); - } + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); final Graph g = d.inferenceInterface.graph(); @@ -222,22 +221,21 @@ public class TensorFlowMultiBoxDetector implements Classifier { Trace.endSection(); // preprocessBitmap // Copy the input data into TensorFlow. - Trace.beginSection("fillNodeFloat"); - inferenceInterface.fillNodeFloat( - inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); + Trace.beginSection("feed"); + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); Trace.endSection(); // Run the inference call. - Trace.beginSection("runInference"); - inferenceInterface.runInference(outputNames); + Trace.beginSection("run"); + inferenceInterface.run(outputNames, logStats); Trace.endSection(); // Copy the output Tensor back into the output array. - Trace.beginSection("readNodeFloat"); + Trace.beginSection("fetch"); final float[] outputScoresEncoding = new float[numLocations]; final float[] outputLocationsEncoding = new float[numLocations * 4]; - inferenceInterface.readNodeFloat(outputNames[0], outputLocationsEncoding); - inferenceInterface.readNodeFloat(outputNames[1], outputScoresEncoding); + inferenceInterface.fetch(outputNames[0], outputLocationsEncoding); + inferenceInterface.fetch(outputNames[1], outputScoresEncoding); Trace.endSection(); outputLocations = decodeLocationsEncoding(outputLocationsEncoding); @@ -275,8 +273,8 @@ public class TensorFlowMultiBoxDetector implements Classifier { } @Override - public void enableStatLogging(final boolean debug) { - inferenceInterface.enableStatLogging(debug); + public void enableStatLogging(final boolean logStats) { + this.logStats = logStats; } @Override diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java index 062e8cc8d0..174723071d 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/TensorFlowYoloDetector.java @@ -86,6 +86,8 @@ public class TensorFlowYoloDetector implements Classifier { private int blockSize; + private boolean logStats = false; + private TensorFlowInferenceInterface inferenceInterface; /** Initializes a native TensorFlow session for classifying images. */ @@ -106,13 +108,8 @@ public class TensorFlowYoloDetector implements Classifier { d.floatValues = new float[inputSize * inputSize * 3]; d.blockSize = blockSize; - d.inferenceInterface = new TensorFlowInferenceInterface(); + d.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename); - final int status = d.inferenceInterface.initializeTensorFlow(assetManager, modelFilename); - if (status != 0) { - LOGGER.e("TF init status: " + status); - throw new RuntimeException("TF init status (" + status + ") != 0"); - } return d; } @@ -157,30 +154,26 @@ public class TensorFlowYoloDetector implements Classifier { Trace.endSection(); // preprocessBitmap // Copy the input data into TensorFlow. - Trace.beginSection("fillNodeFloat"); - inferenceInterface.fillNodeFloat( - inputName, new int[] {1, inputSize, inputSize, 3}, floatValues); + Trace.beginSection("feed"); + inferenceInterface.feed(inputName, floatValues, 1, inputSize, inputSize, 3); Trace.endSection(); timer.endSplit("ready for inference"); // Run the inference call. - Trace.beginSection("runInference"); - final int resultCode = inferenceInterface.runInference(outputNames); - if (resultCode != 0) { - throw new RuntimeException("Bad result code from inference: " + resultCode); - } + Trace.beginSection("run"); + inferenceInterface.run(outputNames, logStats); Trace.endSection(); timer.endSplit("ran inference"); // Copy the output Tensor back into the output array. - Trace.beginSection("readNodeFloat"); + Trace.beginSection("fetch"); final int gridWidth = bitmap.getWidth() / blockSize; final int gridHeight = bitmap.getHeight() / blockSize; final float[] output = new float[gridWidth * gridHeight * (NUM_CLASSES + 5) * NUM_BOXES_PER_BLOCK]; - inferenceInterface.readNodeFloat(outputNames[0], output); + inferenceInterface.fetch(outputNames[0], output); Trace.endSection(); // Find the best detections. @@ -256,8 +249,8 @@ public class TensorFlowYoloDetector implements Classifier { } @Override - public void enableStatLogging(final boolean debug) { - inferenceInterface.enableStatLogging(debug); + public void enableStatLogging(final boolean logStats) { + this.logStats = logStats; } @Override |